diff --git a/.gitignore b/.gitignore index 9ae0d9c96f188bc6357832f22b4125694302b104..be75938ec401b1d72fa54773c85191aaac7d7f35 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ node_modules /bazel-* /bazel_pip /tools/python_bin_path.sh -/tools/git/gen +/tensorflow/tools/git/gen /pip_test /_python_build *.pyc @@ -22,3 +22,15 @@ Pods Podfile.lock *.pbxproj *.xcworkspacedata +/tensorflow/contrib/lite/downloads/** +/tensorflow/contrib/lite/gen/** +/tensorflow/contrib/lite/examples/ios/simple/data/*.txt +/tensorflow/contrib/lite/examples/ios/simple/data/*.tflite +xcuserdata/** + +# Android +.gradle +.idea +*.iml +local.properties +gradleBuild diff --git a/AUTHORS b/AUTHORS index a46ae7e616ab3a420d9fb2691ee8d8650032a39f..aa4be5169dcc68c579863e8ba6307cd00e9f9a68 100644 --- a/AUTHORS +++ b/AUTHORS @@ -7,4 +7,4 @@ # The email address is not required for organizations. Google Inc. -Yuan Tang terrytangyuan@gmail.com +Yuan Tang diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index cfc45049f7088e95059d2e07d5c8ce98f32def93..ff11d131409b65880f16b80f9fe38dc39ac0e5fa 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -55,14 +55,14 @@ If you are experiencing or witnessing conflict, we ask you to use the following ## Reporting Violations -Violations of the Code of Conduct can be reported to TensorFlow’s Project Steward at conduct@tensorflow.org. The Project Steward will determine whether the Code of Conduct was violated, and will issue an appropriate sanction, possibly including a written warning or expulsion from the project, project sponsored spaces, or project forums. We ask that you make a good-faith effort to resolve your conflict via the conflict resolution policy before submitting a report. +Violations of the Code of Conduct can be reported to TensorFlow’s Project Stewards, Edd Wilder-James (ewj@google.com) and Sarah Novotny (sarahnovotny@google.com). 
The Project Stewards will determine whether the Code of Conduct was violated, and will issue an appropriate sanction, possibly including a written warning or expulsion from the project, project sponsored spaces, or project forums. We ask that you make a good-faith effort to resolve your conflict via the conflict resolution policy before submitting a report. Violations of the Code of Conduct can occur in any setting, even those unrelated to the project. We will only consider complaints about conduct that has occurred within one year of the report. ## Enforcement -If the Project Steward receives a report alleging a violation of the Code of Conduct, the Project Steward will notify the accused of the report, and provide them an opportunity to discuss the report before a sanction is issued. The Project Steward will do their utmost to keep the reporter anonymous. If the act is ongoing (such as someone engaging in harassment), or involves a threat to anyone's safety (e.g. threats of violence), the Project Steward may issue sanctions without notice. +If the Project Stewards receive a report alleging a violation of the Code of Conduct, the Project Stewards will notify the accused of the report, and provide them an opportunity to discuss the report before a sanction is issued. The Project Stewards will do their utmost to keep the reporter anonymous. If the act is ongoing (such as someone engaging in harassment), or involves a threat to anyone's safety (e.g. threats of violence), the Project Stewards may issue sanctions without notice. 
## Attribution diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 43abdaafbf45379430920cd027b26299cd62553b..1b537ca73cc94e992e7537fe69c8d0cc8fd13102 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -114,6 +114,7 @@ pylint --rcfile=/tmp/pylintrc myfile.py * [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html) * [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html) * [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml) +* [Google Objective-C Style Guide](http://google.github.io/styleguide/objcguide.html) #### Running sanity check diff --git a/RELEASE.md b/RELEASE.md index d8db1f72004b5d944e3035a0f33dfc34a674b7ee..e04bd3fc505d51ade9e9fa12c822cb695e90b4f3 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -494,7 +494,7 @@ answered questions, and were part of inspiring discussions. This release contains contributions from many people at Google, as well as: A. Besir Kurtulmus, Adal Chiriliuc, @akash, Alec-Desouza, Alex Rothberg, Alex -Sergeev, Alexander Heinecke, Allen Guo, Andreas Madsen, Ankesh Anand, Anton +Sergeev, Alexander Heinecke, Allen Guo, Andreas Madsen, Ankesh Anand, Anton Loss, @Aravind, @Arie, Ashutosh Das, AuréLien Geron, Bairen Yi, @bakunyo, Ben Visser, Brady Zhou, Calpa Liu, Changming Sun, Chih Cheng Liang, Christopher Berner, Clark Zinzow, @Conchylicultor, Dan Ellis, Dan J, Dan Jarvis, Daniel diff --git a/configure.py b/configure.py index 83ee01c630fa90fb752f8c5ea163976d21a7b183..9afdd318dfd78c128708b8ae0d3bda8cb4ad56fc 100644 --- a/configure.py +++ b/configure.py @@ -34,8 +34,10 @@ except ImportError: _TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.tf_configure.bazelrc') -_DEFAULT_CUDA_VERSION = '8.0' -_DEFAULT_CUDNN_VERSION = '6' +_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'WORKSPACE') +_DEFAULT_CUDA_VERSION = '9.0' +_DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' _DEFAULT_CUDA_PATH = 
# NDK revisions that Bazel officially supports; other revisions only trigger a
# warning in write_android_ndk_workspace_rule.
_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15]

# Upper bound on how many times a single setting is re-requested before the
# configuration run is aborted.
_DEFAULT_PROMPT_ASK_ATTEMPTS = 10


class UserInputError(Exception):
  """Raised when a setting keeps receiving invalid input."""


def prompt_loop_or_load_from_env(
    environ_cp,
    var_name,
    var_default,
    ask_for_var,
    check_success,
    error_msg,
    suppress_default_error=False,
    n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS
):
  """Loop over user prompts for an ENV param until receiving a valid response.

  For the env param var_name, read from the environment or verify user input
  until receiving valid input. When done, set var_name in the environ_cp to its
  new value.

  Args:
    environ_cp: (Dict) copy of the os.environ.
    var_name: (String) string for name of environment variable, e.g. "TF_MYVAR".
    var_default: (String) default value string.
    ask_for_var: (String) string for how to ask for user input.
    check_success: (Function) function that takes one argument and returns a
      boolean. Should return True if the value provided is considered valid.
      May print its own, more detailed error message if error_msg does not
      provide enough information. In that case, set suppress_default_error to
      True.
    error_msg: (String) String with one and only one '%s'. Formatted with each
      invalid response upon check_success(input) failure.
    suppress_default_error: (Bool) Suppress the above error message in favor of
      one from the check_success function.
    n_ask_attempts: (Integer) Number of times to query for valid input before
      raising an error and quitting.

  Returns:
    [String] The value of var_name after querying for input.

  Raises:
    UserInputError: if a query has been attempted n_ask_attempts times without
      success, assume that the user has made a scripting error, and will
      continue to provide invalid input. Raise the error to avoid infinitely
      looping.
  """
  default = environ_cp.get(var_name) or var_default
  full_query = '%s [Default is %s]: ' % (ask_for_var, default)

  for _ in range(n_ask_attempts):
    val = get_from_env_or_user_or_default(environ_cp, var_name, full_query,
                                          default)
    if check_success(val):
      break
    if not suppress_default_error:
      print(error_msg % val)
    # Reset and retry.
    environ_cp[var_name] = ''
  else:
    # The loop ran out of attempts without a single accepted value.
    raise UserInputError('Invalid %s setting was provided %d times in a row. '
                         'Assuming to be a scripting mistake.' %
                         (var_name, n_ask_attempts))

  environ_cp[var_name] = val
  return val


def create_android_ndk_rule(environ_cp):
  """Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule."""
  if is_windows() or is_cygwin():
    default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' %
                               environ_cp['APPDATA'])
  elif is_macos():
    # The per-user directory on macOS is "Library" (capital L); the lowercase
    # spelling only resolves on case-insensitive volumes.
    default_ndk_path = '%s/Library/Android/Sdk/ndk-bundle' % environ_cp['HOME']
  else:
    default_ndk_path = '%s/Android/Sdk/ndk-bundle' % environ_cp['HOME']

  def valid_ndk_path(path):
    """An NDK root is accepted when it contains a source.properties file."""
    return (os.path.exists(path) and
            os.path.exists(os.path.join(path, 'source.properties')))

  android_ndk_home_path = prompt_loop_or_load_from_env(
      environ_cp,
      var_name='ANDROID_NDK_HOME',
      var_default=default_ndk_path,
      ask_for_var='Please specify the home path of the Android NDK to use.',
      check_success=valid_ndk_path,
      error_msg=('The path %s or its child file "source.properties" '
                 'does not exist.')
  )

  write_android_ndk_workspace_rule(android_ndk_home_path)


def _natural_sort_key(name):
  """Return a sort key that orders embedded integers numerically.

  re.split with a capturing group alternates non-digit and digit chunks, so
  odd positions are always digit runs. Converting only those keeps every
  position type-consistent and makes '9' sort before '26'.
  """
  chunks = re.split(r'(\d+)', name)
  return [int(chunk) if index % 2 else chunk
          for index, chunk in enumerate(chunks)]


def create_android_sdk_rule(environ_cp):
  """Set Android variables and write Android SDK WORKSPACE rule.

  Prompts for (or loads from environ_cp) ANDROID_SDK_HOME, ANDROID_API_LEVEL
  and ANDROID_BUILD_TOOLS_VERSION, then appends an android_sdk_repository rule
  to the WORKSPACE file.

  Raises:
    UserInputError: if the chosen SDK has no installed platforms or build
      tools, or if any prompt keeps receiving invalid input.
  """
  if is_windows() or is_cygwin():
    default_sdk_path = cygpath('%s/Android/Sdk' % environ_cp['APPDATA'])
  elif is_macos():
    # Point at the SDK root, not at the NDK bundle nested inside it.
    default_sdk_path = '%s/Library/Android/Sdk' % environ_cp['HOME']
  else:
    default_sdk_path = '%s/Android/Sdk' % environ_cp['HOME']

  def valid_sdk_path(path):
    """An SDK root must contain "platforms" and "build-tools"."""
    return (os.path.exists(path) and
            os.path.exists(os.path.join(path, 'platforms')) and
            os.path.exists(os.path.join(path, 'build-tools')))

  android_sdk_home_path = prompt_loop_or_load_from_env(
      environ_cp,
      var_name='ANDROID_SDK_HOME',
      var_default=default_sdk_path,
      ask_for_var='Please specify the home path of the Android SDK to use.',
      check_success=valid_sdk_path,
      error_msg=('Either %s does not exist, or it does not contain the '
                 'subdirectories "platforms" and "build-tools".'))

  platforms = os.path.join(android_sdk_home_path, 'platforms')
  # Sort numerically so that the last entry (the default) is the highest level
  # rather than the lexicographically greatest string.
  api_levels = sorted(
      (x.replace('android-', '') for x in os.listdir(platforms)),
      key=_natural_sort_key)
  if not api_levels:
    raise UserInputError('No Android platforms are installed in %s.' %
                         platforms)

  def valid_api_level(api_level):
    return os.path.exists(os.path.join(android_sdk_home_path,
                                       'platforms',
                                       'android-' + api_level))

  android_api_level = prompt_loop_or_load_from_env(
      environ_cp,
      var_name='ANDROID_API_LEVEL',
      var_default=api_levels[-1],
      ask_for_var=('Please specify the Android SDK API level to use. '
                   '[Available levels: %s]') % api_levels,
      check_success=valid_api_level,
      error_msg='Android-%s is not present in the SDK path.')

  build_tools = os.path.join(android_sdk_home_path, 'build-tools')
  versions = sorted(os.listdir(build_tools), key=_natural_sort_key)
  if not versions:
    raise UserInputError('No Android build tools are installed in %s.' %
                         build_tools)

  def valid_build_tools(version):
    return os.path.exists(os.path.join(android_sdk_home_path,
                                       'build-tools',
                                       version))

  android_build_tools_version = prompt_loop_or_load_from_env(
      environ_cp,
      var_name='ANDROID_BUILD_TOOLS_VERSION',
      var_default=versions[-1],
      ask_for_var=('Please specify an Android build tools version to use. '
                   '[Available versions: %s]') % versions,
      check_success=valid_build_tools,
      error_msg=('The selected SDK does not have build-tools version %s '
                 'available.'))

  write_android_sdk_workspace_rule(android_sdk_home_path,
                                   android_build_tools_version,
                                   android_api_level)


def write_android_sdk_workspace_rule(android_sdk_home_path,
                                     android_build_tools_version,
                                     android_api_level):
  """Append an android_sdk_repository rule to the WORKSPACE file."""
  print('Writing android_sdk_workspace rule.\n')
  with open(_TF_WORKSPACE, 'a') as f:
    f.write("""
android_sdk_repository(
    name="androidsdk",
    api_level=%s,
    path="%s",
    build_tools_version="%s")\n
""" % (android_api_level, android_sdk_home_path, android_build_tools_version))


def write_android_ndk_workspace_rule(android_ndk_home_path):
  """Append an android_ndk_repository rule to the WORKSPACE file.

  Warns when the NDK revision is outside _SUPPORTED_ANDROID_NDK_VERSIONS.

  Raises:
    UserInputError: if no revision can be read from the NDK's
      source.properties file.
  """
  print('Writing android_ndk_workspace rule.')
  ndk_api_level = check_ndk_level(android_ndk_home_path)
  if ndk_api_level is None:
    # Fail with an actionable message instead of a TypeError from int(None).
    raise UserInputError(
        'Unable to read "Pkg.Revision" from %s/source.properties.' %
        android_ndk_home_path)
  if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
    print('WARNING: The API level of the NDK in %s is %s, which is not '
          'supported by Bazel (officially supported versions: %s). Please use '
          'another version. Compiling Android targets may result in confusing '
          'errors.\n' % (android_ndk_home_path, ndk_api_level,
                         _SUPPORTED_ANDROID_NDK_VERSIONS))
  with open(_TF_WORKSPACE, 'a') as f:
    f.write("""
android_ndk_repository(
    name="androidndk",
    path="%s",
    api_level=%s)\n
""" % (android_ndk_home_path, ndk_api_level))


def check_ndk_level(android_ndk_home_path):
  """Check the revision number of an Android NDK path.

  Returns:
    The major revision as a string of digits, or None when source.properties
    has no "Pkg.Revision = <number>" entry.
  """
  properties_path = '%s/source.properties' % android_ndk_home_path
  if is_windows() or is_cygwin():
    properties_path = cygpath(properties_path)
  with open(properties_path, 'r') as f:
    filedata = f.read()

  # The dot is escaped so that only the literal "Pkg.Revision" key matches.
  revision = re.search(r'Pkg\.Revision = (\d+)', filedata)
  if revision:
    return revision.group(1)
  return None


def workspace_has_any_android_rule():
  """Check the WORKSPACE for existing android_*_repository rules.

  Returns:
    A truthy match object when a line of the WORKSPACE file starts with
    android_sdk_repository or android_ndk_repository, otherwise None.
  """
  with open(_TF_WORKSPACE, 'r') as f:
    workspace = f.read()
  has_any_rule = re.search(r'^android_[ns]dk_repository',
                           workspace,
                           re.MULTILINE)
  return has_any_rule
%s cannot be found' % gcc_host_compiler_path) - environ_cp['GCC_HOST_COMPILER_PATH'] = '' + gcc_host_compiler_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='GCC_HOST_COMPILER_PATH', + var_default=default_gcc_host_compiler_path, + ask_for_var= + 'Please specify which gcc should be used by nvcc as the host compiler.', + check_success=os.path.exists, + error_msg='Invalid gcc path. %s cannot be found.', + ) - # Set GCC_HOST_COMPILER_PATH - environ_cp['GCC_HOST_COMPILER_PATH'] = gcc_host_compiler_path write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path) @@ -596,7 +806,7 @@ def set_tf_cuda_version(environ_cp): 'Please specify the CUDA SDK version you want to use, ' 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION - while True: + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): # Configure the Cuda SDK version to use. tf_cuda_version = get_from_env_or_user_or_default( environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION) @@ -634,6 +844,11 @@ def set_tf_cuda_version(environ_cp): environ_cp['TF_CUDA_VERSION'] = '' environ_cp['CUDA_TOOLKIT_PATH'] = '' + else: + raise UserInputError('Invalid TF_CUDA_SETTING setting was provided %d ' + 'times in a row. Assuming to be a scripting mistake.' % + _DEFAULT_PROMPT_ASK_ATTEMPTS) + # Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION environ_cp['CUDA_TOOLKIT_PATH'] = cuda_toolkit_path write_action_env_to_bazelrc('CUDA_TOOLKIT_PATH', cuda_toolkit_path) @@ -647,7 +862,7 @@ def set_tf_cudnn_version(environ_cp): 'Please specify the cuDNN version you want to use. 
def set_host_cxx_compiler(environ_cp):
  """Set HOST_CXX_COMPILER.

  Prompts for (or loads from environ_cp) the host C++ compiler path, which
  must exist, and records it as a bazelrc action env.
  """
  default_cxx_host_compiler = which('g++') or ''

  host_cxx_compiler = prompt_loop_or_load_from_env(
      environ_cp,
      var_name='HOST_CXX_COMPILER',
      var_default=default_cxx_host_compiler,
      ask_for_var=('Please specify which C++ compiler should be used as the '
                   'host C++ compiler.'),
      check_success=os.path.exists,
      error_msg='Invalid C++ compiler path. %s cannot be found.',
  )

  write_action_env_to_bazelrc('HOST_CXX_COMPILER', host_cxx_compiler)


def set_host_c_compiler(environ_cp):
  """Set HOST_C_COMPILER.

  Prompts for (or loads from environ_cp) the host C compiler path, which must
  exist, and records it as a bazelrc action env.
  """
  default_c_host_compiler = which('gcc') or ''

  host_c_compiler = prompt_loop_or_load_from_env(
      environ_cp,
      var_name='HOST_C_COMPILER',
      var_default=default_c_host_compiler,
      # The trailing space keeps the two adjacent literals from fusing into
      # "hostC compiler.".
      ask_for_var=('Please specify which C compiler should be used as the host '
                   'C compiler.'),
      check_success=os.path.exists,
      error_msg='Invalid C compiler path. %s cannot be found.',
  )

  write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler)


def set_computecpp_toolkit_path(environ_cp):
  """Set COMPUTECPP_TOOLKIT_PATH."""

  def toolkit_exists(toolkit_path):
    """Check if a computecpp toolkit path is valid."""
    if is_linux():
      sycl_rt_lib_path = 'lib/libComputeCpp.so'
    else:
      sycl_rt_lib_path = ''

    sycl_rt_lib_path_full = os.path.join(toolkit_path,
                                         sycl_rt_lib_path)
    exists = os.path.exists(sycl_rt_lib_path_full)
    if not exists:
      # Report the full library path that was probed; the generic error_msg
      # below is suppressed in favour of this message.
      print('Invalid SYCL %s library path. %s cannot be found' %
            (_TF_OPENCL_VERSION, sycl_rt_lib_path_full))
    return exists

  computecpp_toolkit_path = prompt_loop_or_load_from_env(
      environ_cp,
      var_name='COMPUTECPP_TOOLKIT_PATH',
      var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH,
      ask_for_var=(
          'Please specify the location where ComputeCpp for SYCL %s is '
          'installed.' % _TF_OPENCL_VERSION),
      check_success=toolkit_exists,
      error_msg='Invalid SYCL compiler path. %s cannot be found.',
      suppress_default_error=True)

  write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH',
                              computecpp_toolkit_path)


def set_trisycl_include_dir(environ_cp):
  """Set TRISYCL_INCLUDE_DIR.

  Prompts for (or loads from environ_cp) the triSYCL include directory, which
  must exist, and records it as a bazelrc action env.
  """
  # os.path.exists prints nothing on failure, so the default error message is
  # left enabled; otherwise an invalid path would re-prompt without feedback.
  trisycl_include_dir = prompt_loop_or_load_from_env(
      environ_cp,
      var_name='TRISYCL_INCLUDE_DIR',
      var_default=_DEFAULT_TRISYCL_INCLUDE_DIR,
      ask_for_var=('Please specify the location of the triSYCL include '
                   'directory. (Use --config=sycl_trisycl when building with '
                   'Bazel)'),
      check_success=os.path.exists,
      error_msg='Invalid triSYCL include directory. %s cannot be found.')

  write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)


def set_mpi_home(environ_cp):
  """Set MPI_HOME.

  The default is two directory levels above the mpirun/mpiexec binary found
  on the path. The accepted folder must contain "include" and "lib".
  """

  default_mpi_home = which('mpirun') or which('mpiexec') or ''
  default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home))

  def valid_mpi_path(mpi_home):
    """Return True when mpi_home has both include/ and lib/ subfolders."""
    include_dir = os.path.join(mpi_home, 'include')
    lib_dir = os.path.join(mpi_home, 'lib')
    exists = os.path.exists(include_dir) and os.path.exists(lib_dir)
    if not exists:
      # Name both probed paths (not the boolean result of the lib check).
      print('Invalid path to the MPI Toolkit. %s or %s cannot be found' %
            (include_dir, lib_dir))
    return exists

  # The helper stores the accepted value in environ_cp['MPI_HOME']; the return
  # value itself is not needed here.
  _ = prompt_loop_or_load_from_env(
      environ_cp,
      var_name='MPI_HOME',
      var_default=default_mpi_home,
      ask_for_var='Please specify the MPI toolkit folder.',
      check_success=valid_mpi_path,
      error_msg='',
      suppress_default_error=True)


def set_mkl():
  """Write the "mkl" bazelrc config and print how to enable it."""
  write_to_bazelrc('build:mkl --define using_mkl=true')
  write_to_bazelrc('build:mkl -c opt')
  print(
      'Add "--config=mkl" to your bazel command to build with MKL '
      'support.\nPlease note that MKL on MacOS or windows is still not '
      'supported.\nIf you would like to use a local MKL instead of '
      'downloading, please set the environment variable \"TF_MKL_ROOT\" every '
      'time before build.\n')


def set_grpc_build_flags():
  """Write the bazelrc define that sets grpc_no_ares=true."""
  write_to_bazelrc('build --define grpc_no_ares=true')


def set_windows_build_flags():
  """Write Windows-only bazelrc build flags; does nothing elsewhere."""
  if is_windows():
    # The non-monolithic build is not supported yet
    write_to_bazelrc('build --config monolithic')
    # Suppress warning messages
    write_to_bazelrc('build --copt=-w --host_copt=-w')
    # Output more verbose information when something goes wrong
    write_to_bazelrc('build --verbose_failures')
and getting and setting # environment variables. @@ -1023,7 +1284,6 @@ def main(): environ_cp['TF_NEED_OPENCL_SYCL'] = '0' environ_cp['TF_NEED_COMPUTECPP'] = '0' environ_cp['TF_NEED_OPENCL'] = '0' - environ_cp['TF_NEED_S3'] = '0' environ_cp['TF_CUDA_CLANG'] = '0' if is_macos(): @@ -1077,10 +1337,29 @@ def main(): set_mpi_home(environ_cp) set_other_mpi_vars(environ_cp) + set_grpc_build_flags() set_cc_opt_flags(environ_cp) set_mkl() set_monolithic() + set_windows_build_flags() create_android_bazelrc_configs() + if workspace_has_any_android_rule(): + print('The WORKSPACE file has at least one of ["android_sdk_repository", ' + '"android_ndk_repository"] already set. Will not ask to help ' + 'configure the WORKSPACE. Please delete the existing rules to ' + 'activate the helper.\n') + else: + if get_var( + environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', + False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), + 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): + create_android_ndk_rule(environ_cp) + create_android_sdk_rule(environ_cp) + + if __name__ == '__main__': main() diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 9874f95ea3268dfce0158d3ddcdefea77136cad8..646137c0c81a6de11fe06905dbf44d9656033d29 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -119,7 +119,7 @@ config_setting( config_setting( name = "no_tensorflow_py_deps", - values = {"define": "no_tensorflow_py_deps=true"}, + define_values = {"no_tensorflow_py_deps": "true"}, visibility = ["//visibility:public"], ) @@ -175,55 +175,122 @@ config_setting( # TODO(jhseu): Enable on other platforms other than Linux. 
config_setting( name = "with_jemalloc_linux_x86_64", - values = { - "cpu": "k8", - "define": "with_jemalloc=true", - }, + define_values = {"with_jemalloc": "true"}, + values = {"cpu": "k8"}, visibility = ["//visibility:public"], ) config_setting( name = "with_jemalloc_linux_ppc64le", - values = { - "cpu": "ppc", - "define": "with_jemalloc=true", - }, + define_values = {"with_jemalloc": "true"}, + values = {"cpu": "ppc"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_default_optimizations", + define_values = {"with_default_optimizations": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_gcp_support", - values = {"define": "with_gcp_support=true"}, + define_values = {"with_gcp_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_hdfs_support", - values = {"define": "with_hdfs_support=true"}, + define_values = {"with_hdfs_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_s3_support", - values = {"define": "with_s3_support=true"}, + define_values = {"with_s3_support": "true"}, + visibility = ["//visibility:public"], +) + +# Crosses between platforms and file system libraries not supported on those +# platforms due to limitations in nested select() statements. 
+config_setting( + name = "with_gcp_support_windows_override", + define_values = {"with_gcp_support": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_hdfs_support_windows_override", + define_values = {"with_hdfs_support": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_s3_support_windows_override", + define_values = {"with_s3_support": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_gcp_support_android_override", + define_values = {"with_gcp_support": "true"}, + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_hdfs_support_android_override", + define_values = {"with_hdfs_support": "true"}, + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_s3_support_android_override", + define_values = {"with_s3_support": "true"}, + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_gcp_support_ios_override", + define_values = {"with_gcp_support": "true"}, + values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_hdfs_support_ios_override", + define_values = {"with_hdfs_support": "true"}, + values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_s3_support_ios_override", + define_values = {"with_s3_support": "true"}, + values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, visibility = ["//visibility:public"], ) config_setting( name = "with_xla_support", - values = {"define": "with_xla_support=true"}, + define_values = 
{"with_xla_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_gdr_support", - values = {"define": "with_gdr_support=true"}, + define_values = {"with_gdr_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_verbs_support", - values = {"define": "with_verbs_support=true"}, + define_values = {"with_verbs_support": "true"}, visibility = ["//visibility:public"], ) @@ -297,17 +364,10 @@ config_setting( visibility = ["//visibility:public"], ) -# Make a dummy rule that we can chaqnge "default" in select statements to. -# to disable dependencies in copybara. -config_setting( - name = "dummy_disabled_internal", - values = {"define": "with_dummy_disabled_internal=true"}, - visibility = ["//visibility:public"], -) - package_group( name = "internal", packages = [ + "//learning/meta_rank/...", "//tensorflow/...", "//tensorflow_fold/llgtm/...", ], @@ -353,6 +413,7 @@ filegroup( "//tensorflow/compiler/tf2xla:all_files", "//tensorflow/compiler/tf2xla/cc:all_files", "//tensorflow/compiler/tf2xla/kernels:all_files", + "//tensorflow/compiler/tf2xla/lib:all_files", "//tensorflow/compiler/tf2xla/ops:all_files", "//tensorflow/compiler/xla:all_files", "//tensorflow/compiler/xla/client:all_files", @@ -393,6 +454,7 @@ filegroup( "//tensorflow/contrib/data/python/ops:all_files", "//tensorflow/contrib/decision_trees/proto:all_files", "//tensorflow/contrib/distributions:all_files", + "//tensorflow/contrib/eager/proto:all_files", "//tensorflow/contrib/eager/python:all_files", "//tensorflow/contrib/estimator:all_files", "//tensorflow/contrib/factorization:all_files", @@ -425,6 +487,25 @@ filegroup( "//tensorflow/contrib/learn/python/learn/datasets:all_files", "//tensorflow/contrib/linalg:all_files", "//tensorflow/contrib/linear_optimizer:all_files", + "//tensorflow/contrib/lite:all_files", + "//tensorflow/contrib/lite/java:all_files", + "//tensorflow/contrib/lite/java/demo/app/src/main:all_files", + 
"//tensorflow/contrib/lite/java/demo/app/src/main/assets:all_files", + "//tensorflow/contrib/lite/java/src/main/native:all_files", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:all_files", + "//tensorflow/contrib/lite/kernels:all_files", + "//tensorflow/contrib/lite/kernels/internal:all_files", + "//tensorflow/contrib/lite/models/smartreply:all_files", + "//tensorflow/contrib/lite/nnapi:all_files", + "//tensorflow/contrib/lite/python:all_files", + "//tensorflow/contrib/lite/schema:all_files", + "//tensorflow/contrib/lite/testing:all_files", + "//tensorflow/contrib/lite/toco:all_files", + "//tensorflow/contrib/lite/toco/graph_transformations/tests:all_files", + "//tensorflow/contrib/lite/toco/python:all_files", + "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:all_files", + "//tensorflow/contrib/lite/toco/tflite:all_files", + "//tensorflow/contrib/lite/tools:all_files", "//tensorflow/contrib/lookup:all_files", "//tensorflow/contrib/losses:all_files", "//tensorflow/contrib/makefile:all_files", @@ -474,6 +555,7 @@ filegroup( "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:all_files", "//tensorflow/contrib/tpu:all_files", "//tensorflow/contrib/tpu/profiler:all_files", + "//tensorflow/contrib/tpu/proto:all_files", "//tensorflow/contrib/training:all_files", "//tensorflow/contrib/util:all_files", "//tensorflow/contrib/verbs:all_files", @@ -521,6 +603,7 @@ filegroup( "//tensorflow/java/src/main/native:all_files", "//tensorflow/python:all_files", "//tensorflow/python/data:all_files", + "//tensorflow/python/data/kernel_tests:all_files", "//tensorflow/python/data/ops:all_files", "//tensorflow/python/data/util:all_files", "//tensorflow/python/debug:all_files", @@ -557,6 +640,7 @@ filegroup( "//tensorflow/tools/test:all_files", "//tensorflow/user_ops:all_files", "//third_party/hadoop:all_files", + "//third_party/mpi:all_files", "//third_party/sycl:all_files", "//third_party/sycl/sycl:all_files", ], @@ -687,3 +771,10 @@ 
tf_cc_shared_object( "//tensorflow/core:tensorflow", ], ) + +exports_files( + [ + "tf_version_script.lds", + "tf_exported_symbols.lds", + ], +) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 6dd1b999102d0135720b6ab3a43cbe61255acbc1..9b5704702841081d7dde78ac019305140066f688 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -383,12 +383,11 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, // be less than the total node count. Status ValidateNoCycles(const Graph& g) { // TODO(nolivia): check this on a subset of the graph instead of all of it. - int total_num_nodes = g.num_node_ids(); // A node is ready when all of its inputs have been visited. std::vector ready; - std::vector pending_count(total_num_nodes, 0); + std::vector pending_count(g.num_node_ids(), 0); - for (int i = 0; i < total_num_nodes; ++i) { + for (int i = 0; i < g.num_node_ids(); ++i) { const Node* n = g.FindNodeId(i); if (n == nullptr) continue; pending_count[i] = n->in_edges().size(); @@ -421,7 +420,7 @@ Status ValidateNoCycles(const Graph& g) { } } - if (processed < total_num_nodes) { + if (processed < g.num_nodes()) { std::vector nodes_in_cycle; for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; ++i) { @@ -430,7 +429,7 @@ Status ValidateNoCycles(const Graph& g) { } } return errors::InvalidArgument( - "Graph is invalid, contains a cycle with ", total_num_nodes - processed, + "Graph is invalid, contains a cycle with ", g.num_nodes() - processed, " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); } return Status::OK(); @@ -580,6 +579,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, status->status = InvalidArgument( "invalid string tensor encoding (string #", i, " of ", srcarray.size(), "): ", status->status.error_message()); + delete[] base; return nullptr; } dst += consumed; @@ -589,6 +589,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, status->status = InvalidArgument( 
"invalid string tensor encoding (decoded ", (dst - base), " bytes, but the tensor is encoded in ", size, " bytes"); + delete[] base; return nullptr; } @@ -625,6 +626,23 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, return Status::OK(); } +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type) + EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + // If any session has already run this node_id, mark this session as + // unrunnable. + for (auto it : graph->sessions) { + if (it.first->last_num_graph_nodes > op.node.id()) { + it.second = FailedPrecondition( + "Operation '", op.node.DebugString(), "' was changed by ", + mutation_type, + " after it was run by a session. Nodes can be mutated " + "only before they are executed by a session. Either don't modify " + "nodes after running them or create a new session."); + } + } +} + // Helpers for loading a TensorFlow plugin (a .so file). Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len); @@ -890,8 +908,8 @@ const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, TF_Status* status) { const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name); if (attr == nullptr) { - status->status = - InvalidArgument("Operation has no attr named '", attr_name, "'."); + status->status = InvalidArgument("Operation '", oper->node.name(), + "' has no attr named '", attr_name, "'."); } return attr; } @@ -939,13 +957,17 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, return; } - std::vector dim_vec; - dim_vec.reserve(num_dims); - for (int i = 0; i < num_dims; ++i) { - dim_vec.push_back(ic->MakeDim(dims[i])); + tensorflow::shape_inference::ShapeHandle new_shape; + if (num_dims != -1) { + std::vector dim_vec; + dim_vec.reserve(num_dims); + for (int i = 0; i < num_dims; ++i) { + dim_vec.push_back(ic->MakeDim(dims[i])); + } + new_shape = ic->MakeShape(dim_vec); + } else { + new_shape = ic->UnknownShape(); } - - 
tensorflow::shape_inference::ShapeHandle new_shape = ic->MakeShape(dim_vec); status->status = graph->refiner.SetShape(node, output.index, new_shape); } @@ -1741,7 +1763,6 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, TF_Graph::TF_Graph() : graph(tensorflow::OpRegistry::Global()), refiner(graph.versions().producer(), graph.op_registry()), - num_sessions(0), delete_requested(false), parent(nullptr), parent_inputs(nullptr) {} @@ -1751,7 +1772,7 @@ TF_Graph* TF_NewGraph() { return new TF_Graph; } void TF_DeleteGraph(TF_Graph* g) { g->mu.lock(); g->delete_requested = true; - const bool del = g->num_sessions == 0; + const bool del = g->sessions.empty(); g->mu.unlock(); if (del) delete g; } @@ -1831,6 +1852,16 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, opts->opts.prefix = prefix; } +void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts, + unsigned char uniquify_names) { + opts->opts.uniquify_names = uniquify_names; +} + +void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts, + unsigned char uniquify_prefix) { + opts->opts.uniquify_prefix = uniquify_prefix; +} + void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst) { @@ -1888,12 +1919,12 @@ void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results, *opers = results->return_nodes.data(); } -void TF_ImportGraphDefResultsUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_unused_input_mappings, +void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, const char*** src_names, int** src_indexes) { - *num_unused_input_mappings = results->unused_key_names.size(); - *src_names = results->unused_key_names.data(); - *src_indexes = results->unused_key_indexes.data(); + *num_missing_unused_input_mappings = 
results->missing_unused_key_names.size(); + *src_names = results->missing_unused_key_names.data(); + *src_indexes = results->missing_unused_key_indexes.data(); } void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) { @@ -1933,18 +1964,21 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); } - // Populate unused map keys - DCHECK(tf_results->unused_key_names.empty()); - DCHECK(tf_results->unused_key_indexes.empty()); - DCHECK(tf_results->unused_key_names_data.empty()); - tf_results->unused_key_names.resize(results.unused_input_map_keys.size()); - tf_results->unused_key_indexes.resize(results.unused_input_map_keys.size()); - for (int i = 0; i < results.unused_input_map_keys.size(); ++i) { - TensorId id = results.unused_input_map_keys[i]; - tf_results->unused_key_names_data.push_back(id.first.ToString()); - tf_results->unused_key_names[i] = - tf_results->unused_key_names_data.back().c_str(); - tf_results->unused_key_indexes[i] = id.second; + // Populate missing unused map keys + DCHECK(tf_results->missing_unused_key_names.empty()); + DCHECK(tf_results->missing_unused_key_indexes.empty()); + DCHECK(tf_results->missing_unused_key_names_data.empty()); + + size_t size = results.missing_unused_input_map_keys.size(); + tf_results->missing_unused_key_names.resize(size); + tf_results->missing_unused_key_indexes.resize(size); + + for (int i = 0; i < size; ++i) { + TensorId id = results.missing_unused_input_map_keys[i]; + tf_results->missing_unused_key_names_data.push_back(id.first.ToString()); + tf_results->missing_unused_key_names[i] = + tf_results->missing_unused_key_names_data.back().c_str(); + tf_results->missing_unused_key_indexes[i] = id.second; } } @@ -2321,11 +2355,12 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, Session* session; status->status = NewSession(opt->options, &session); if (status->status.ok()) { + TF_Session* 
new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); - graph->num_sessions += 1; + graph->sessions[new_session] = Status::OK(); } - return new TF_Session(session, graph); + return new_session; } else { DCHECK_EQ(nullptr, session); return nullptr; @@ -2389,7 +2424,7 @@ TF_Session* TF_LoadSessionFromSavedModel( TF_Session* session = new TF_Session(bundle.session.release(), graph); - graph->num_sessions += 1; + graph->sessions[session] = Status::OK(); session->last_num_graph_nodes = graph->graph.num_node_ids(); return session; #endif // __ANDROID__ @@ -2404,8 +2439,8 @@ void TF_DeleteSession(TF_Session* s, TF_Status* status) { TF_Graph* const graph = s->graph; if (graph != nullptr) { graph->mu.lock(); - graph->num_sessions -= 1; - const bool del = graph->delete_requested && graph->num_sessions == 0; + graph->sessions.erase(s); + const bool del = graph->delete_requested && graph->sessions.empty(); graph->mu.unlock(); if (del) delete graph; } @@ -2421,6 +2456,13 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { mutex_lock session_lock(session->mu); session->graph->mu.lock(); const Graph& graph = session->graph->graph; + + status->status = session->graph->sessions[session]; + if (!status->status.ok()) { + session->graph->mu.unlock(); + return false; + } + const auto num_nodes = graph.num_node_ids(); if (session->last_num_graph_nodes < num_nodes) { status->status = tensorflow::ValidateNoCycles(session->graph->graph); diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index bb569d67fcbcec29e9494236abd79b3e40db91cd..de9527f86d1f48846b160230c592a398e00e10c5 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -889,6 +889,20 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( TF_ImportGraphDefOptions* opts, const char* prefix); +// Set whether to uniquify imported operation names. 
If true, imported operation +// names will be modified if their name already exists in the graph. If false, +// conflicting names will be treated as an error. Note that this option has no +// effect if a prefix is set, since the prefix will guarantee all names are +// unique. Defaults to false. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyNames( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_names); + +// If true, the specified prefix will be modified if it already exists as an +// operation name or prefix in the graph. If false, a conflicting prefix will be +// treated as an error. This option has no effect if no prefix is specified. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_prefix); + // Set any imported nodes with input `src_name:src_index` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references a node already existing in the graph being imported into. @@ -948,16 +962,16 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations( TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers); // Fetches any input mappings requested via -// TF_ImportGraphDefOptionsAddInputMapping() that weren't used as input to any -// node in the imported graph def. The number of fetched mappings is returned in -// `num_unused_input_mappings`. The array of each mapping's source node name is -// returned in `src_names`, and the array of each mapping's source index is -// returned in `src_indexes`. +// TF_ImportGraphDefOptionsAddInputMapping() that didn't appear in the GraphDef +// and weren't used as input to any node in the imported graph def. The number +// of fetched mappings is returned in `num_missing_unused_input_mappings`. The +// array of each mapping's source node name is returned in `src_names`, and the +// array of each mapping's source index is returned in `src_indexes`. 
// // `*src_names`, `*src_indexes`, and the memory backing each string in // `src_names` are owned by and have the lifetime of `results`. -TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_unused_input_mappings, +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, const char*** src_names, int** src_indexes); // Deletes a results object returned by TF_GraphImportGraphDefWithResults(). diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index dcb818b88b6fca460852beb6e948d2eb6964f663..d60d1de315ed37a327bd036ddb914a3c32413f65 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -68,7 +68,7 @@ class NodeNameMapping { // This is a superset of values in name_mapping_. std::unordered_set used_names_; // Mapping from original node name from the graph to the normalized - // and uniqified version of it. + // and uniquified version of it. std::unordered_map name_mapping_; }; @@ -226,12 +226,17 @@ Status FillFunctionBody( } node_def->add_input(strings::StrCat("^", normalized)); } + + // A function is stateful if any of its nodes are stateful. + if (node->op_def().is_stateful()) { + fdef->mutable_signature()->set_is_stateful(true); + } } return Status::OK(); } // Graph to FunctionDef conversion. This code is closely modeled on the Python -// code in third_party/tensorflow/python/framework/function.py. +// code in tensorflow/python/framework/function.py. 
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, bool append_hash_to_fn_name, const std::vector& body_nodes, diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index d5580b658992413ae6f9cb79ef88751ee28ce465..2e2293ca85175009d1bcc8db5c830789e4701c1d 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1482,6 +1482,51 @@ TEST_F(CApiFunctionTest, GetOpDef) { EXPECT_EQ(op_def.name(), func_name_); EXPECT_EQ(op_def.input_arg_size(), 1); EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_FALSE(op_def.is_stateful()); + + TF_DeleteBuffer(buffer); +} + +void DefineStatefulFunction(const char* name, TF_Function** func) { + std::unique_ptr func_graph( + TF_NewGraph(), TF_DeleteGraph); + std::unique_ptr s(TF_NewStatus(), + TF_DeleteStatus); + + TF_Tensor* tensor_shape = Int32Tensor({37, 1}); + TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape"); + TF_Operation* random = + RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get()); + + TF_Output inputs[] = {}; + TF_Output outputs[] = {{random, 0}}; + *func = TF_GraphToFunction(func_graph.get(), name, /*append_hash=*/false, -1, + /*opers=*/nullptr, 0, inputs, 1, outputs, + /*output_names=*/nullptr, + /*opts=*/nullptr, "", s.get()); + ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get()); + ASSERT_NE(*func, nullptr); + TF_DeleteTensor(tensor_shape); +} + +TEST_F(CApiFunctionTest, StatefulOpDef) { + DefineStatefulFunction(func_name_, &func_); + TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Test we can retrieve function OpDef from graph + TF_Buffer* buffer = TF_NewBuffer(); + TF_GraphGetOpDef(host_graph_, func_name_, buffer, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Sanity check returned OpDef + string data(static_cast(buffer->data), buffer->length); + OpDef op_def; + op_def.ParseFromString(data); + 
EXPECT_EQ(op_def.name(), func_name_); + EXPECT_EQ(op_def.input_arg_size(), 0); + EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_TRUE(op_def.is_stateful()); TF_DeleteBuffer(buffer); } diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index bb04e01beec931a8ea66d0855eec9625d3a6a5ab..6df77a7f9baed999a2f2cb5e9404cb63451b6212 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -81,12 +81,20 @@ struct TF_Graph { std::unordered_map name_map GUARDED_BY(mu); - // TF_Graph may only / must be deleted when - // num_sessions == 0 && delete_requested == true - - // num_sessions incremented by TF_NewSession, and decremented by + // The keys of this map are all the active sessions using this graph. + // Each value is the current "runnability" status of the corresponding + // session. Under normal conditions all statuses are Status::OK(), but + // if some operation is mutated after it was run by a session (this + // is detected in RecordMutation function), that session is no longer + // safe to run. Its status will contain the error that will be returned + // to the user, should she try running this session. + // + // Sessions are added to this map in TF_NewSession, and removed in // TF_DeleteSession. - int num_sessions GUARDED_BY(mu); + // TF_Graph may only / must be deleted when + // sessions.size() == 0 && delete_requested == true + tensorflow::gtl::FlatMap sessions + GUARDED_BY(mu); bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph // Used to link graphs contained in TF_WhileParams to the parent graph that @@ -135,11 +143,11 @@ struct TF_ImportGraphDefOptions { struct TF_ImportGraphDefResults { std::vector return_tensors; std::vector return_nodes; - std::vector unused_key_names; - std::vector unused_key_indexes; + std::vector missing_unused_key_names; + std::vector missing_unused_key_indexes; - // Backing memory for unused_key_names values. 
- std::list unused_key_names_data; + // Backing memory for missing_unused_key_names values. + std::list missing_unused_key_names_data; }; struct TF_DeviceList { @@ -167,6 +175,9 @@ TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type); + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 05881e619ba232de99e78f315cfa8ab9294e5137..4e89b4fc43973e4cc9a6c64f50e288d81bf22033 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -287,6 +287,13 @@ TEST(CAPI, SetShape) { ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); EXPECT_EQ(-1, num_dims); + // Set the shape to be unknown, expect no change. + TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(-1, num_dims); + // Set the shape to be 2 x Unknown int64_t dims[] = {2, -1}; TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); @@ -315,7 +322,17 @@ TEST(CAPI, SetShape) { EXPECT_EQ(dims[0], returned_dims[0]); EXPECT_EQ(dims[1], returned_dims[1]); - // Try to set 'unknown' on the shape and see that + // Try to set 'unknown' with unknown rank on the shape and see that + // it doesn't change. + TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(2, num_dims); + EXPECT_EQ(2, returned_dims[0]); + EXPECT_EQ(3, returned_dims[1]); + + // Try to set 'unknown' with same rank on the shape and see that // it doesn't change. 
dims[0] = -1; dims[1] = -1; @@ -383,7 +400,7 @@ TEST(CAPI, Graph) { EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)); ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s)); - EXPECT_EQ(string("Operation has no attr named 'missing'."), + EXPECT_EQ(string("Operation 'feed' has no attr named 'missing'."), string(TF_Message(s))); // Make a constant oper with the scalar "3". @@ -756,7 +773,7 @@ TEST(CAPI, ImportGraphDef_WithReturnOutputs) { TF_DeleteStatus(s); } -TEST(CAPI, ImportGraphDef_UnusedInputMappings) { +TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -799,7 +816,7 @@ TEST(CAPI, ImportGraphDef_UnusedInputMappings) { int num_unused_input_mappings; const char** src_names; int* src_indexes; - TF_ImportGraphDefResultsUnusedInputMappings( + TF_ImportGraphDefResultsMissingUnusedInputMappings( results, &num_unused_input_mappings, &src_names, &src_indexes); ASSERT_EQ(1, num_unused_input_mappings); EXPECT_EQ(string("fake"), string(src_names[0])); @@ -1054,7 +1071,7 @@ class CApiColocationTest : public ::testing::Test { TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_); if (expected.empty()) { ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); - EXPECT_EQ(std::string("Operation has no attr named '_class'."), + EXPECT_EQ(std::string("Operation 'add' has no attr named '_class'."), std::string(TF_Message(s_))); return; } diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index c291a2e440a8515e968b0ce0395b289080f04e8b..37439ff0beac5a5220460465e954b6c093ee1ba9 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -193,6 +193,15 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, return TF_FinishOperation(desc, s); } +TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype, + TF_Graph* graph, TF_Status* s) { + TF_OperationDescription* desc = + TF_NewOperation(graph, "RandomUniform", 
"random_uniform"); + TF_AddInput(desc, {shape, 0}); + TF_SetAttrType(desc, "dtype", dtype); + return TF_FinishOperation(desc, s); +} + void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s, const char* name, TF_Operation** op) { TF_Operation* zero = ScalarConst( diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index bc44a7b840bd15f42a3a469d3964e1e1e8d48cdc..3429009a71a863ae6b69b5cd29ace3c7fd078f4c 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -74,6 +74,9 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s, TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s); +TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype, + TF_Graph* graph, TF_Status* s); + // Split `input` along the first dimension into 3 tensors TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s, const char* name = "split3"); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index c77896b80b478cd34d3502e1061a7e76204ba021..d533758e360bc44a6f52f57eaae5b222e0482860 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -39,6 +39,7 @@ tf_cuda_library( tf_cuda_library( name = "c_api_internal", hdrs = ["c_api_internal.h"], + visibility = ["//tensorflow:internal"], deps = [ ":c_api", ":runtime", @@ -105,7 +106,6 @@ tf_cc_test( cc_library( name = "tape", - srcs = ["tape.cc"], hdrs = ["tape.h"], visibility = ["//tensorflow:internal"], deps = [ diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 8359de62b7ff690fec9f6a0e3280f947c62f8b6e..706c89536db019c7f7389af576815746b2425520 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -571,6 +571,12 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, status->status = ctx->func_lib_def.AddFunctionDef(function_def); } +void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, + TF_Status* status) { + tensorflow::mutex_lock 
l(ctx->functions_mu); + status->status = ctx->func_lib_def.AddFunctionDef(function->fdef); +} + } // extern "C" TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 865580c5f3a823d9cf49fe460bd007e3b3b88767..ca105962df0d6655946304159937621022e7fcba 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -200,6 +200,13 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef(TFE_Context* ctx, const char* serialized_function_def, size_t size, TF_Status* status); +// Adds a function (created from TF_GraphToFunction or +// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with +// TFE_Execute by creating an op with the same name as the function. +TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx, + TF_Function* function, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 4af91b8853d0e85570bad136752a9d0a04b87da5..3fe0b7efa11bc619ed98bf9a1634ade5b6ed0a7c 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -295,6 +295,67 @@ TEST(CAPI, Execute) { TF_DeleteStatus(status); } +TEST(CAPI, Function) { + // First create a simple identity function. 
+ TF_Graph* function_graph = TF_NewGraph(); + TF_OperationDescription* arg_descr = + TF_NewOperation(function_graph, "Placeholder", "arg"); + TF_SetAttrType(arg_descr, "dtype", TF_INT32); + TF_Status* status = TF_NewStatus(); + TF_Operation* arg = TF_FinishOperation(arg_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_OperationDescription* id_descr = + TF_NewOperation(function_graph, "Identity", "id"); + TF_SetAttrType(id_descr, "T", TF_INT32); + TF_AddInput(id_descr, {arg, 0}); + TF_Operation* id = TF_FinishOperation(id_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_Output input{arg, 0}; + TF_Output output{id, 0}; + TF_Function* fn = + TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, + &output, nullptr, nullptr, "test", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteGraph(function_graph); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_DeleteContextOptions(opts); + TFE_ContextAddFunction(ctx, fn, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteFunction(fn); + + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); + + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + 
ASSERT_EQ(num_retvals, 1); + + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + TFE_DeleteContext(ctx, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteStatus(status); +} + string MatMulFunction() { tensorflow::FunctionDef def; CHECK(tensorflow::protobuf::TextFormat::ParseFromString( diff --git a/tensorflow/c/eager/tape.cc b/tensorflow/c/eager/tape.cc deleted file mode 100644 index 464612a81ebda428f5582b6927f3a3b00a5aa6f5..0000000000000000000000000000000000000000 --- a/tensorflow/c/eager/tape.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/c/eager/tape.h" - -namespace tensorflow { -namespace eager { - -bool GradientTape::ShouldRecord(gtl::ArraySlice tensor_ids) { - for (int64 i : tensor_ids) { - if (tensor_tape_.find(i) != tensor_tape_.end()) { - return true; - } - } - return false; -} - -void GradientTape::Watch(int64 tensor_id) { - tensor_tape_.emplace(tensor_id, -1); -} - -void GradientTape::RecordOperation( - const string& op_type, gtl::ArraySlice output_tensors, - gtl::ArraySlice input_tensor_id, void* backward_function, - const std::function& backward_function_deleter) { - if (!ShouldRecord(input_tensor_id)) { - backward_function_deleter(); - return; - } - std::vector ids; - ids.reserve(input_tensor_id.size()); - for (int64 i : input_tensor_id) { - tensor_usage_[i]++; - ids.push_back(i); - } - const int64 op_id = next_op_id_++; - std::vector tensors; - tensors.reserve(output_tensors.size()); - for (const TapeTensor& o : output_tensors) { - // Note: the tensor can have already been watched and hence be in the tape, - // so we cannot check that we're inserting it here. - tensor_tape_[o.id] = op_id; - tensor_usage_[o.id] = 1; - tensors.push_back(o); - } - op_tape_[op_id] = OpTapeEntry{op_type, tensors, ids, backward_function, - backward_function_deleter}; -} - -void GradientTape::DeleteTrace(int64 tensor_id) { - auto it = tensor_usage_.find(tensor_id); - if (it == tensor_usage_.end()) { - return; - } - it->second--; - if (it->second != 0) { - return; - } - tensor_usage_.erase(it); - auto tensor_op_it = tensor_tape_.find(tensor_id); - if (tensor_op_it == tensor_tape_.end()) { - return; - } - const int64 op_id = tensor_op_it->second; - if (op_id == -1) { - // Do not delete watched tensors. 
- return; - } - tensor_tape_.erase(tensor_op_it); - auto op_it = op_tape_.find(op_id); - CHECK(op_it != op_tape_.end()); - for (const auto& output : op_it->second.output_tensor_info) { - if (tensor_usage_.find(output.id) != tensor_usage_.end()) { - // Found a usage for an output, so cannot delete the op. - return; - } - } - for (int64 id : op_it->second.input_tensor_id) { - DeleteTrace(id); - } - op_it->second.backward_function_deleter(); - op_tape_.erase(op_it); -} - -std::pair GradientTape::Export() { - return {std::move(tensor_tape_), std::move(op_tape_)}; -} - -} // namespace eager -} // namespace tensorflow diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index df51f300eb61d54cb1e06d5a58a9b10e834f73c4..20ed037c52f34bc7a8aa39243c0b85e58fee1d46 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -19,6 +19,7 @@ limitations under the License. // maintains the data structures required to do so. #include +#include #include #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" @@ -36,13 +37,14 @@ struct TapeTensor { }; // Represents an entry in the tape. +template struct OpTapeEntry { string op_type; std::vector output_tensor_info; std::vector input_tensor_id; // TODO(apassos) consider narrowing down this interface. - void* backward_function; + BackwardFunction* backward_function; // Should be called before deleting the backward function. TODO(apassos) use // unique_ptrs to ensure this happens. @@ -55,13 +57,78 @@ struct OpTapeEntry { using TensorTape = std::unordered_map; // Map from operation-id to tape entry. -using OpTape = std::unordered_map; +template +using OpTape = std::unordered_map>; + +// Operations the tape needs to perform on tensors to do backpropagation. Named +// "vspace" because a subset of these are related to a vector space, such as +// adding gradients, getting zeroes, etc. 
Currently cannot be implemented +// without using tensorflow python code, hence left unspecified here. +// +// Gradient is the type returned by gradient functions. In Python TF it's either +// Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need +// to allow their size to be computed and they need to be passable to a backward +// function and deleted (as the backprop code creates lots of gradients the user +// is not interested in). +// +// BackwardFunction needs to be a closure which stores intermediate activations +// from the forward computation and calls a vector-jacobian product function +// (also known as adjoint function) to compute, given downstream gradients, +// upstream gradients. +// +// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle +// specialization, which is blocked by quite a few things needing to loop back +// into python now. +template +class VSpace { + public: + virtual ~VSpace() {} + + // Returns the number of elements in the gradient tensor. + virtual int64 NumElements(Gradient* tensor) const = 0; + + // Consumes references to the tensors in the gradient_tensors list and returns + // a tensor with the result. + virtual Gradient* AggregateGradients( + gtl::ArraySlice gradient_tensors) const = 0; + + // Returns a tensor of the right shape and dtype filled with zeros. + virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0; + + // Returns a Tensor which is filled with ones and like the input. + virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0; + + // Calls the passed-in backward function. + virtual Status CallBackwardFunction( + BackwardFunction* backward_function, + gtl::ArraySlice output_gradients, + std::vector* result) const = 0; + + // Deletes the input tensor. + virtual void DeleteGradient(Gradient* gradient) const = 0; + + // Lets this VSpace know that it can release resources held by the + // `backward_function`, It will not be called again. 
+ // `backward_function` must not be null. + virtual void ReleaseBackwardFunction( + BackwardFunction* backward_function) const = 0; +}; // Traces the execution of operations, doing eager garbage collection, and // exporting a full trace so other code can do backpropagation. Not thread-safe. +template class GradientTape { public: - GradientTape() {} + // If `persistent` is true, GradientTape will not eagerly delete backward + // functions (and hence the tensors they keep alive). Instead, everything + // is deleted in ~GradientTape. Persistent GradientTapes are useful when + // users want to compute multiple gradients over the same tape. + GradientTape(bool persistent) : persistent_(persistent) {} + ~GradientTape() { + for (const auto& pair : op_tape_) { + pair.second.backward_function_deleter(); + } + } bool ShouldRecord(gtl::ArraySlice tensor_ids); @@ -70,26 +137,486 @@ class GradientTape { void RecordOperation(const string& op_type, gtl::ArraySlice output_tensors, gtl::ArraySlice input_tensor_id, - void* backward_function, + BackwardFunction* backward_function, const std::function& backward_function_deleter); void DeleteTrace(int64 tensor_id); - // Note: it is only valid to call Export once per tape, and after calling - // export the tape is no longer valid (i.e. calls to ShouldRecord, Watch, - // Record, and Delete have undefined behavior). - std::pair Export(); + // Consumes the internal state of the tape (so cannot be called more than + // once) and produces the gradient of the target tensors with respect to the + // source tensors. The output gradients are used if not empty and not + // null. The result is populated with one tensor per target element. 
+ Status ComputeGradient(const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice source_tensor_id, + gtl::ArraySlice output_gradients, + std::vector* result); private: TensorTape tensor_tape_; - OpTape op_tape_; + OpTape op_tape_; int64 next_op_id_{0}; // Map from tensor id to number of remaining usages (i.e. how many entries in // the tape refer to it); to aid in tape garbage collection. std::unordered_map tensor_usage_; + + // If false, all activations are deleted in the first call to ComputeGradient. + // Else, only when this is destructed. + bool persistent_; +}; + +// Template instantiations here + +template +bool GradientTape::ShouldRecord( + gtl::ArraySlice tensor_ids) { + for (int64 i : tensor_ids) { + if (tensor_tape_.find(i) != tensor_tape_.end()) { + return true; + } + } + return false; +} + +template +void GradientTape::Watch(int64 tensor_id) { + tensor_tape_.emplace(tensor_id, -1); +} + +template +void GradientTape::RecordOperation( + const string& op_type, gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, BackwardFunction* backward_function, + const std::function& backward_function_deleter) { + if (!ShouldRecord(input_tensor_id)) { + backward_function_deleter(); + return; + } + std::vector ids; + ids.reserve(input_tensor_id.size()); + for (int64 i : input_tensor_id) { + tensor_usage_[i]++; + ids.push_back(i); + } + const int64 op_id = next_op_id_++; + std::vector tensors; + tensors.reserve(output_tensors.size()); + for (const TapeTensor& o : output_tensors) { + // Note: the tensor can have already been watched and hence be in the tape, + // so we cannot check that we're inserting it here. 
+ tensor_tape_[o.id] = op_id; + tensor_usage_[o.id] = 1; + tensors.push_back(o); + } + op_tape_[op_id] = OpTapeEntry{ + op_type, tensors, ids, backward_function, backward_function_deleter}; +} + +template +void GradientTape::DeleteTrace(int64 tensor_id) { + auto it = tensor_usage_.find(tensor_id); + if (it == tensor_usage_.end()) { + return; + } + it->second--; + if (it->second != 0) { + return; + } + tensor_usage_.erase(it); + auto tensor_op_it = tensor_tape_.find(tensor_id); + if (tensor_op_it == tensor_tape_.end()) { + return; + } + const int64 op_id = tensor_op_it->second; + if (op_id == -1) { + // Do not delete watched tensors. + return; + } + tensor_tape_.erase(tensor_op_it); + auto op_it = op_tape_.find(op_id); + CHECK(op_it != op_tape_.end()); + for (const auto& output : op_it->second.output_tensor_info) { + if (tensor_usage_.find(output.id) != tensor_usage_.end()) { + // Found a usage for an output, so cannot delete the op. + return; + } + } + for (int64 id : op_it->second.input_tensor_id) { + DeleteTrace(id); + } + op_it->second.backward_function_deleter(); + op_tape_.erase(op_it); +} + +// Terminology: +// +// - op: a possibly composite operation, which has an entry in the tape +// - target: dy in dx/dy +// - source: dx in dx/dy +// - tensor: one of the many inputs or outputs of an operation +// +// Below here we do the gradient algorithm. It works as follows: +// +// First we filter the tape to just the subset of operations we want to +// differentiate. In the process of doing so we count how many times each Tensor +// is used as an input to an op (so we know when we're done computing gradients +// for that Tensor). We also count, for each tape entry, how many of its output +// Tensors need gradients to be computed (Tensors which are not used do not need +// any gradients to be computed). +// +// Finally, we start a backprop stack with a set of tape entries for which we +// have all gradients available. 
This set usually is a subset of the set of +// targets (not all since targets which have outputs in the tape will not have +// gradients available initially). +// +// Then we repeatedly pop an entry from the stack, run its backprop, and update +// the gradients of its inputs. Once we have computed all gradients for a single +// input we can mark this input as done, and this can trigger adding an entry to +// the stack if all outputs of that entry are now done. +// +// When the stack is empty we have gradients for all tensors we're interested +// in. + +namespace { + +template +struct BackpropInitialState { + OpTape op_tape; + + // Map from tensor ID to how many references still exist for this tensor in + // the tape. + std::unordered_map tensor_usage_counts; + + // Maps from op ID to how many output tensors of this op still need to have + // their gradients computed. + std::unordered_map op_missing_tensor; }; +// If `persistent_tape` is true, op_tape is not changed and none of the +// backwards functions are deleted. +// If `persistent_tape` is false, op_tape is cleared and backwards functions +// not needed for gradient computation are deleted. Backwards functions that +// are needed, are copied and returned in BackpropInitialState. 
+template +BackpropInitialState PrepareBackprop( + gtl::ArraySlice target, const TensorTape& tensor_tape, + OpTape* op_tape, + const std::unordered_set& sources_set, bool persistent_tape) { + std::vector tensor_stack; + tensor_stack.reserve(target.size()); + for (auto t : target) { + tensor_stack.push_back(t); + } + BackpropInitialState result; + while (!tensor_stack.empty()) { + int64 tensor_id = tensor_stack.back(); + tensor_stack.pop_back(); + auto op_id_it = tensor_tape.find(tensor_id); + if (op_id_it == tensor_tape.end()) { + continue; + } + int64 op_id = op_id_it->second; + auto op_it = op_tape->find(op_id); + auto result_op_it = result.op_tape.find(op_id); + if (op_id == -1 || op_it == op_tape->end() || + result_op_it != result.op_tape.end()) { + continue; + } + CHECK(result.op_tape.emplace(op_id, op_it->second).second); + for (auto it : op_it->second.input_tensor_id) { + auto count_it = result.tensor_usage_counts.find(it); + if (count_it != result.tensor_usage_counts.end()) { + count_it->second++; + } else { + result.tensor_usage_counts[it] = 1; + if (sources_set.find(it) == sources_set.end() && + tensor_tape.find(it) != tensor_tape.end()) { + tensor_stack.push_back(it); + } + } + } + if (!persistent_tape) { + op_tape->erase(op_it); + } + } + for (auto& pair : result.tensor_usage_counts) { + auto it = tensor_tape.find(pair.first); + if (it != tensor_tape.end() && it->second != -1) { + result.op_missing_tensor[it->second] += 1; + } + } + if (!persistent_tape) { + // Call destructors for all unneeded gradient functions and + // clear the op_tape. We can clear the tape because ownership of + // backward functions that will be used for gradient computation + // has been transfered to `result`. 
+ for (const auto& op_pair : *op_tape) { + op_pair.second.backward_function_deleter(); + } + op_tape->clear(); + } + return result; +} + +template +std::vector InitialStack( + const OpTape& op_tape, + const std::unordered_map& op_missing_tensor) { + std::vector result; + for (auto& op_entry : op_tape) { + if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) { + result.push_back(op_entry.first); + } + } + return result; +} + +template +Status InitialGradients( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, + const OpTape& op_tape, + const std::unordered_map& tensor_usage_counts, + std::unordered_map>* result) { + for (int i = 0; i < target_tensor_ids.size(); ++i) { + const int64 id = target_tensor_ids[i]; + if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) { + if (!output_gradients.empty() && output_gradients[i] != nullptr) { + // TODO(apassos) figure out how to print debugging information here. 
+ return errors::InvalidArgument( + "A gradient was provided for a tensor which is used as part of the " + "computation."); + } + } else { + if (output_gradients.empty() || output_gradients[i] == nullptr) { + auto tensor_it = tensor_tape.find(id); + if (tensor_it != tensor_tape.end() && tensor_it->second != -1) { + auto op_it = op_tape.find(tensor_it->second); + if (op_it == op_tape.end()) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "failed to find operation producing a tensor"); + } + bool found = false; + for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { + if (op_it->second.output_tensor_info[j].id == id) { + found = true; + (*result)[id].push_back( + vspace.Ones(op_it->second.output_tensor_info[j].shape, + op_it->second.output_tensor_info[j].dtype)); + break; + } + } + if (!found) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "none of operations outputs match expected tensor"); + } + } else { + // No record of the target tensor found on the tape, so no gradient + // needs to be computed from it. Do nothing. + } + } else { + (*result)[id].push_back(output_gradients[i]); + } + } + } + return Status::OK(); +} + +} // namespace + +// If over kMinAggregateCount gradients are accumulated and the total +// memory consumption is over kMinAggregateBytes, do an early aggregation +// so as to release the gradient tensor to save memory. 
+constexpr int kMinAggregateCount = 4; +constexpr int kMinAggregateBytes = 128 * 1024 * 1024; + +template +Status GradientTape::ComputeGradient( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice source_tensor_ids, + gtl::ArraySlice output_gradients, + std::vector* result) { + std::unordered_set sources_set(source_tensor_ids.begin(), + source_tensor_ids.end()); + BackpropInitialState state = PrepareBackprop( + target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); + std::vector op_stack = + InitialStack(state.op_tape, state.op_missing_tensor); + std::unordered_map> gradients; + Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, + tensor_tape_, state.op_tape, + state.tensor_usage_counts, &gradients); + auto cleanup = [this, &state]() { + if (!persistent_) { + // Release all backprop functions + for (const auto& pair : state.op_tape) { + pair.second.backward_function_deleter(); + } + } + }; + if (!s.ok()) { + cleanup(); + return s; + } + std::unordered_map gradients_size; + // TODO(apassos) multiple threads could be dequeuing from op_stack at the same + // time, for better CPU backprop performance. + VLOG(1) << "Initial stack:"; + if (VLOG_IS_ON(1)) { + for (auto t : op_stack) { + VLOG(1) << " " << t; + } + } + std::unordered_map> + functions_accept_none_for_indices({ + {"SoftmaxCrossEntropyWithLogits", {1}}, + {"FusedBatchNorm", {1, 2, 3, 4}}, + }); + while (!op_stack.empty()) { + const int64 op = op_stack.back(); + VLOG(1) << "Popped " << op; + op_stack.pop_back(); + auto op_it = state.op_tape.find(op); + if (op_it == state.op_tape.end()) { + // It is possible for ops to end up on the stack if they are unrelated to + // the target; we should just skip them. 
+ continue; + } + auto trace = std::move(op_it->second); + state.op_tape.erase(op_it); + std::vector out_gradients; + out_gradients.reserve(trace.output_tensor_info.size()); + bool any_gradient_nonzero = false; + for (int i = 0; i < trace.output_tensor_info.size(); ++i) { + const int64 id = trace.output_tensor_info[i].id; + auto grad_it = gradients.find(id); + if (grad_it == gradients.end()) { + auto func_name_it = + functions_accept_none_for_indices.find(trace.op_type); + if (func_name_it != functions_accept_none_for_indices.end() && + func_name_it->second.find(i) != func_name_it->second.end()) { + out_gradients.push_back(nullptr); + } else { + out_gradients.push_back( + vspace.Zeros(trace.output_tensor_info[i].shape, + trace.output_tensor_info[i].dtype)); + } + } else { + any_gradient_nonzero = true; + out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); + if (sources_set.find(grad_it->first) == sources_set.end()) { + gradients.erase(grad_it); + } + } + } + std::vector in_gradients; + if (any_gradient_nonzero) { + Status s = vspace.CallBackwardFunction(trace.backward_function, + out_gradients, &in_gradients); + if (!persistent_) { + vspace.ReleaseBackwardFunction(trace.backward_function); + } + if (!s.ok()) { + cleanup(); + return s; + } + } else { + in_gradients.resize(trace.input_tensor_id.size()); + if (!persistent_) { + vspace.ReleaseBackwardFunction(trace.backward_function); + } + } + VLOG(1) << "Got " << in_gradients.size() << " in_gradients for " + << trace.input_tensor_id.size() << " sources"; + for (int i = 0; i < in_gradients.size(); ++i) { + const int64 id = trace.input_tensor_id[i]; + if (in_gradients[i] != nullptr) { + auto& unaggregated_grads = gradients[id]; + unaggregated_grads.push_back(in_gradients[i]); + if (unaggregated_grads.size() > kMinAggregateCount) { + auto size_it = gradients_size.find(id); + int64 size; + if (size_it == gradients_size.end()) { + size = vspace.NumElements(unaggregated_grads[0]); + 
gradients_size.emplace(id, size); + } else { + size = size_it->second; + } + if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) { + Gradient* grad = vspace.AggregateGradients(unaggregated_grads); + unaggregated_grads.clear(); + unaggregated_grads.push_back(grad); + } + } + } + auto usage_count_it = state.tensor_usage_counts.find(id); + if (usage_count_it == state.tensor_usage_counts.end()) { + VLOG(1) << "Tensor " << id << " not used"; + continue; + } + usage_count_it->second--; + if (usage_count_it->second > 0) { + VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second; + continue; + } + auto tape_it = tensor_tape_.find(id); + if (tape_it == tensor_tape_.end()) { + VLOG(1) << "Tensor " << id + << " has no associated op. Deleting gradient"; + auto grad_it = gradients.find(id); + if (grad_it != gradients.end()) { + for (auto g : grad_it->second) { + vspace.DeleteGradient(g); + } + gradients.erase(grad_it); + } + continue; + } + const int64 op_id = tape_it->second; + if (op_id == -1) { + VLOG(1) << "Tensor " << id << " is source"; + continue; + } + auto missing_it = state.op_missing_tensor.find(op_id); + if (missing_it != state.op_missing_tensor.end()) { + missing_it->second--; + VLOG(1) << "Op " << op_id << " missing " << missing_it->second + << " output gradients"; + if (missing_it->second == 0) { + op_stack.push_back(op_id); + } + } + } + } + CHECK(state.op_tape.empty()); + result->reserve(source_tensor_ids.size()); + for (auto is : source_tensor_ids) { + auto grad_it = gradients.find(is); + if (grad_it == gradients.end()) { + result->push_back(nullptr); + } else { + if (grad_it->second.size() == 1) { + result->push_back(grad_it->second[0]); + } else { + result->push_back(vspace.AggregateGradients(grad_it->second)); + } + gradients.erase(grad_it); + } + } + VLOG(1) << "Final gradients size: " << gradients.size(); + for (auto grad_pair : gradients) { + for (const auto& g : grad_pair.second) { + vspace.DeleteGradient(g); + } + } + return 
Status::OK(); +} + } // namespace eager } // namespace tensorflow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index c67007dca0a2d3e97d367ef0eae2335e5683d087..6e37cdb5f4beea53d4a2ded0705ae482d0bc2d68 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -22,6 +22,7 @@ namespace tensorflow { void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { mutex_lock l(graph->mu); graph->graph.AddControlEdge(&input->node, &op->node); + RecordMutation(graph, *op, "adding control input"); } void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, @@ -36,18 +37,66 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, mutex_lock l(graph->mu); op->node.AddAttr(attr_name, attr_val); + RecordMutation(graph, *op, "setting attribute"); } void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { mutex_lock l(graph->mu); op->node.set_requested_device(device); + RecordMutation(graph, *op, "setting device"); } void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status) { mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&new_src.oper->node); + + if (ic->num_outputs() <= new_src.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Output index [", new_src.index, + "] is greater than the number of total outputs [", ic->num_outputs(), + "]."); + return; + } + tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index); + + tensorflow::shape_inference::InferenceContext* ic_dst = + graph->refiner.GetContext(&dst.oper->node); + if (ic_dst->num_inputs() <= dst.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. 
Input index [", dst.index, + "] is greater than the number of total inputs [", ic_dst->num_inputs(), + "]."); + return; + } + if (!ic_dst->MergeInput(dst.index, shape)) { + status->status = tensorflow::errors::InvalidArgument( + "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape), + " and ", ic_dst->DebugString(ic_dst->input(dst.index)), "."); + return; + } status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, &dst.oper->node, dst.index); + + if (status->status.ok()) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst.oper, "updating input tensor"); + } +} + +void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { + mutex_lock l(graph->mu); + std::vector control_edges; + for (const Edge* edge : op->node.in_edges()) { + if (!edge->IsControlEdge()) continue; + control_edges.push_back(edge); + } + for (const Edge* edge : control_edges) { + graph->graph.RemoveControlEdge(edge); + } } } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index f54585b0a1034ff108202272a11416e34985959e..b51ef2b53122802fef598a26bd6f1843976f11b0 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -35,6 +35,8 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status); +void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 80112f9b44b1d5fd65a7d47788b072dc47a2b29a..e354831d7d25af83c068a68a4f844056263a598c 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -421,6 +421,7 @@ tf_cc_test( tf_gen_op_wrappers_cc( name = "cc_ops", + api_def_srcs = 
["//tensorflow/core:base_api_def"], op_lib_names = [ "array_ops", "audio_ops", @@ -525,6 +526,30 @@ cc_library_with_android_deps( "//tensorflow/core:android_tensorflow_lib", ], copts = tf_copts(), + data = [ + "//tensorflow/core:base_api_def", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:op_gen_lib", + "//tensorflow/core:op_gen_overrides_proto_cc", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "cc_op_gen_test", + srcs = [ + "framework/cc_op_gen.cc", + "framework/cc_op_gen.h", + "framework/cc_op_gen_test.cc", + ], + data = [ + "//tensorflow/cc:ops/op_gen_overrides.pbtxt", + ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -533,6 +558,8 @@ cc_library_with_android_deps( "//tensorflow/core:op_gen_overrides_proto_cc", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 38a17598b8e4161f96ab8134823de033d3284440..d889c518f9c38a9f070970b37a2ad4b1fc26671b 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -18,10 +18,11 @@ limitations under the License. #include #include "tensorflow/cc/framework/cc_op_gen.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/op_gen_lib.h" -#include "tensorflow/core/framework/op_gen_overrides.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb_text.h" @@ -35,7 +36,6 @@ limitations under the License. 
#include "tensorflow/core/public/version.h" namespace tensorflow { - namespace { const int kRightMargin = 79; @@ -297,7 +297,7 @@ string ToCamelCase(const string& str) { // argument to a function. std::pair AttrTypeName(StringPiece attr_type) { static const std::unordered_map, - StringPiece::Hasher> + StringPieceHasher> attr_type_map{ {"string", {"StringPiece", false}}, {"list(string)", {"gtl::ArraySlice", true}}, @@ -325,29 +325,112 @@ std::pair AttrTypeName(StringPiece attr_type) { } bool IsCPPKeyword(StringPiece name) { - static const std::unordered_set + static const std::unordered_set // Keywords obtained from http://en.cppreference.com/w/cpp/keyword kCPPReserved{ - "alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", - "atomic_commit", "atomic_noexcept", "auto", "bitand", "bitor", "bool", - "break", "case", "catch", "char", "char16_t", "char32_t", "class", - "compl", "concept", "const", "const_cast", "constexpr", "continue", - "decltype", "default", "delete", "do", "double", "dynamic_cast", - "else", "enum", "explicit", "export", "extern", "false", "final", - "float", "for", "friend", "goto", "if", "import", "inline", "int", - "long", "module", "mutable", "namespace", "new", "noexcept", "not", - "not_eq", "nullptr", "operator", "or", "or_eq", "override", "private", - "protected", "public", "register", "reinterpret_cast", "requires", - "return", "short", "signed", "sizeof", "static", "static_assert", - "static_cast", "struct", "switch", "synchronized", "template", "this", - "thread_local", "throw", "true", "try", "typedef", "typeid", - "typename", "union", "unsigned", "using", "virtual", "void", - "volatile", "wchar_t", "while", "xor", "xor_eq", + "alignas", + "alignof", + "and", + "and_eq", + "asm", + "atomic_cancel", + "atomic_commit", + "atomic_noexcept", + "auto", + "bitand", + "bitor", + "bool", + "break", + "case", + "catch", + "char", + "char16_t", + "char32_t", + "class", + "compl", + "concept", + "const", + "const_cast", + "constexpr", + 
"continue", + "decltype", + "default", + "delete", + "do", + "double", + "dynamic_cast", + "else", + "enum", + "explicit", + "export", + "extern", + "false", + "final", + "float", + "for", + "friend", + "goto", + "if", + "import", + "inline", + "int", + "long", + "module", + "mutable", + "namespace", + "new", + "noexcept", + "not", + "not_eq", + "nullptr", + "operator", + "or", + "or_eq", + "override", + "private", + "protected", + "public", + "register", + "reinterpret_cast", + "requires", + "return", + "short", + "signed", + "sizeof", + "static", + "static_assert", + "static_cast", + "struct", + "switch", + "synchronized", + "template", + "this", + "thread_local", + "throw", + "true", + "try", + "typedef", + "typeid", + "typename", + "union", + "unsigned", + "using", + "virtual", + "void", + "volatile", + "wchar_t", + "while", + "xor", + "xor_eq", // The following are not C++ keywords, but names of local variables // and parameters used in the op constructor. Treating them as // keywords, so that other parameter names don't conflict with these. 
- "builder", "node", "ret", "scope", "unique_name", + "builder", + "node", + "ret", + "scope", + "unique_name", }; return kCPPReserved.count(name) > 0; } @@ -385,10 +468,10 @@ bool ArgIsList(const OpDef::ArgDef& arg) { } bool HasOptionalAttrs( - const OpDef& op_def, + const ApiDef& api_def, const std::unordered_map& inferred_input_attrs) { - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < api_def.attr_size(); ++i) { + const auto& attr(api_def.attr(i)); if ((inferred_input_attrs.find(attr.name()) == inferred_input_attrs.end()) && attr.has_default_value()) { @@ -398,12 +481,21 @@ bool HasOptionalAttrs( return false; } +const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { + for (int i = 0; i < api_def.in_arg_size(); ++i) { + if (api_def.in_arg(i).name() == name) { + return &api_def.in_arg(i); + } + } + return nullptr; +} + struct OpInfo { // graph_op_def: The OpDef used by the runtime, has the names that // must be used when calling NodeBuilder. // interface_op_def: The OpDef used in the interface in the generated // code, with possibly overridden names and defaults. - explicit OpInfo(const OpDef& graph_op_def, const OpDef& inteface_op_def, + explicit OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, const std::vector& aliases); string GetOpAttrStruct() const; string GetConstructorDecl(StringPiece op_name_prefix, @@ -423,74 +515,81 @@ struct OpInfo { string comment; const OpDef& graph_op_def; - const OpDef& op_def; + const ApiDef& api_def; const std::vector& aliases; + // Map from type attribute to corresponding original argument name. 
std::unordered_map inferred_input_attrs; }; -OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def, - const std::vector& a) - : graph_op_def(g_op_def), op_def(i_op_def), aliases(a) { - op_name = op_def.name(); - InferOpAttributes(op_def, &inferred_input_attrs); - has_optional_attrs = HasOptionalAttrs(op_def, inferred_input_attrs); +OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, + const std::vector& aliases) + : graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) { + op_name = api_def.endpoint(0).name(); + InferOpAttributes(graph_op_def, &inferred_input_attrs); + has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs); arg_types.push_back("const ::tensorflow::Scope&"); arg_names.push_back("scope"); - if (op_def.has_deprecation()) { - if (!op_def.summary().empty()) { - comment = strings::StrCat(op_def.summary(), "\n"); + if (graph_op_def.has_deprecation()) { + if (!api_def.summary().empty()) { + comment = strings::StrCat(api_def.summary(), "\n"); } strings::StrAppend(&comment, "DEPRECATED at GraphDef version ", - op_def.deprecation().version(), ":\n", - op_def.deprecation().explanation(), ".\n"); - } else if (op_def.summary().empty()) { + graph_op_def.deprecation().version(), ":\n", + graph_op_def.deprecation().explanation(), ".\n"); + } else if (api_def.summary().empty()) { comment = "TODO: add doc.\n"; } else { - comment = strings::StrCat(op_def.summary(), "\n"); + comment = strings::StrCat(api_def.summary(), "\n"); } - if (!op_def.description().empty()) { - strings::StrAppend(&comment, "\n", op_def.description(), "\n"); + if (!api_def.description().empty()) { + strings::StrAppend(&comment, "\n", api_def.description(), "\n"); } strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n"); // Process inputs - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); + for (int i = 0; i < api_def.arg_order_size(); ++i) { + const auto& arg = *FindInputArg(api_def.arg_order(i), 
graph_op_def); + const auto& api_def_arg = *FindInputArg(api_def.arg_order(i), api_def); arg_types.push_back(strings::StrCat( "::tensorflow::", ArgIsList(arg) ? "InputList" : "Input")); - arg_names.push_back(AvoidCPPKeywords(arg.name())); + arg_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to())); // TODO(keveman): Include input type information. - StringPiece description = arg.description(); + StringPiece description = api_def_arg.description(); if (!description.empty()) { ConsumeEquals(&description); - strings::StrAppend(&comment, "* ", AvoidCPPKeywords(arg.name()), ": ", - arg.description(), "\n"); + strings::StrAppend(&comment, "* ", + AvoidCPPKeywords(api_def_arg.rename_to()), ": ", + api_def_arg.description(), "\n"); } } // Process attrs string required_attrs_comment; string optional_attrs_comment; - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < graph_op_def.attr_size(); ++i) { + // ApiDef attributes must be in the same order as in OpDef since + // we initialize ApiDef based on OpDef. + const auto& attr(graph_op_def.attr(i)); + const auto& api_def_attr(api_def.attr(i)); + CHECK_EQ(attr.name(), api_def_attr.name()); // Skip inferred arguments if (inferred_input_attrs.count(attr.name()) > 0) continue; const auto entry = AttrTypeName(attr.type()); const auto attr_type_name = entry.first; const bool use_const = entry.second; - string attr_name = AvoidCPPKeywords(attr.name()); + string attr_name = AvoidCPPKeywords(api_def_attr.rename_to()); string attr_comment; - if (!attr.description().empty()) { + if (!api_def_attr.description().empty()) { // TODO(keveman): Word wrap and indent this, to handle multi-line // descriptions. 
strings::StrAppend(&attr_comment, "* ", attr_name, ": ", - attr.description(), "\n"); + api_def_attr.description(), "\n"); } - if (attr.has_default_value()) { + if (api_def_attr.has_default_value()) { strings::StrAppend(&optional_attrs_comment, attr_comment); } else { strings::StrAppend(&required_attrs_comment, attr_comment); @@ -508,44 +607,49 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def, } // Process outputs - for (int i = 0; i < op_def.output_arg_size(); ++i) { - const auto& arg = op_def.output_arg(i); + for (int i = 0; i < graph_op_def.output_arg_size(); ++i) { + // ApiDef arguments must be in the same order as in OpDef since + // we initialize ApiDef based on OpDef. + const auto& arg = graph_op_def.output_arg(i); + const auto& api_def_arg(api_def.out_arg(i)); + CHECK_EQ(arg.name(), api_def_arg.name()); + bool is_list = ArgIsList(arg); output_types.push_back( strings::StrCat("::tensorflow::", is_list ? "OutputList" : "Output")); - output_names.push_back(AvoidCPPKeywords(arg.name())); + output_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to())); is_list_output.push_back(is_list); } strings::StrAppend(&comment, "\nReturns:\n"); - if (op_def.output_arg_size() == 0) { // No outputs. + if (graph_op_def.output_arg_size() == 0) { // No outputs. strings::StrAppend(&comment, "* the created `Operation`\n"); - } else if (op_def.output_arg_size() == 1) { // One output + } else if (graph_op_def.output_arg_size() == 1) { // One output if (is_list_output[0]) { strings::StrAppend(&comment, "* `OutputList`: "); } else { strings::StrAppend(&comment, "* `Output`: "); } - if (op_def.output_arg(0).description().empty()) { - strings::StrAppend(&comment, "The ", op_def.output_arg(0).name(), + if (api_def.out_arg(0).description().empty()) { + strings::StrAppend(&comment, "The ", api_def.out_arg(0).name(), " tensor.\n"); } else { // TODO(josh11b): Word wrap this. 
- strings::StrAppend(&comment, op_def.output_arg(0).description(), "\n"); + strings::StrAppend(&comment, api_def.out_arg(0).description(), "\n"); } } else { // Multiple outputs. - for (int i = 0; i < op_def.output_arg_size(); ++i) { + for (int i = 0; i < graph_op_def.output_arg_size(); ++i) { if (is_list_output[i]) { strings::StrAppend(&comment, "* `OutputList`"); } else { strings::StrAppend(&comment, "* `Output`"); } strings::StrAppend(&comment, " ", output_names[i]); - if (op_def.output_arg(i).description().empty()) { + if (api_def.out_arg(i).description().empty()) { strings::StrAppend(&comment, "\n"); } else { // TODO(josh11b): Word wrap this. - strings::StrAppend(&comment, ": ", op_def.output_arg(i).description(), + strings::StrAppend(&comment, ": ", api_def.out_arg(i).description(), "\n"); } } @@ -564,19 +668,20 @@ string OpInfo::GetOpAttrStruct() const { string struct_fields; string setters; - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < graph_op_def.attr_size(); ++i) { + const auto& attr(graph_op_def.attr(i)); + const auto& api_def_attr(api_def.attr(i)); // If attr will be inferred or it doesn't have a default value, don't // add it to the struct. if ((inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) || - !attr.has_default_value()) { + !api_def_attr.has_default_value()) { continue; } const auto entry = AttrTypeName(attr.type()); const auto attr_type_name = entry.first; const bool use_const = entry.second; - const string camel_case_name = ToCamelCase(attr.name()); + const string camel_case_name = ToCamelCase(api_def_attr.rename_to()); const string suffix = (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : ""; const string attr_func_def = @@ -584,22 +689,25 @@ string OpInfo::GetOpAttrStruct() const { attr_type_name, use_const ? 
"&" : ""); string attr_comment; - if (!attr.description().empty()) { - strings::StrAppend(&attr_comment, attr.description(), "\n\n"); + if (!api_def_attr.description().empty()) { + strings::StrAppend(&attr_comment, api_def_attr.description(), "\n\n"); } strings::StrAppend(&attr_comment, "Defaults to ", - SummarizeAttrValue(attr.default_value()), "\n"); + SummarizeAttrValue(api_def_attr.default_value()), "\n"); attr_comment = MakeComment(attr_comment, " "); strings::StrAppend(&setters, attr_comment); strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n"); strings::StrAppend(&setters, " Attrs ret = *this;\n"); - strings::StrAppend(&setters, " ret.", attr.name(), "_ = x;\n"); + strings::StrAppend(&setters, " ret.", api_def_attr.rename_to(), + "_ = x;\n"); strings::StrAppend(&setters, " return ret;\n }\n\n"); strings::StrAppend( - &struct_fields, " ", attr_type_name, " ", attr.name(), "_ = ", - PrintAttrValue(op_def.name(), attr.default_value()), ";\n"); + &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(), + "_ = ", + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), + ";\n"); } if (struct_fields.empty()) { @@ -676,17 +784,18 @@ void OpInfo::WriteClassDecl(WritableFile* h) const { // Add the static functions to set optional attrs if (has_optional_attrs) { strings::StrAppend(&class_decl, "\n"); - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < graph_op_def.attr_size(); ++i) { + const auto& attr(graph_op_def.attr(i)); + const auto& api_def_attr(api_def.attr(i)); if ((inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) || - !attr.has_default_value()) { + !api_def_attr.has_default_value()) { continue; } const auto entry = AttrTypeName(attr.type()); const auto attr_type_name = entry.first; const bool use_const = entry.second; - const string camel_case_name = ToCamelCase(attr.name()); + const string camel_case_name = ToCamelCase(api_def_attr.rename_to()); 
const string suffix = (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : ""; const string attr_func_def = strings::StrCat( @@ -726,11 +835,11 @@ void OpInfo::GetOutput(string* out) const { strings::StrCat("if (!", scope_str, ".ok()) return;"); // No outputs. - if (op_def.output_arg_size() == 0) { + if (graph_op_def.output_arg_size() == 0) { strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n"); return; } - if (op_def.output_arg_size() == 1) { + if (graph_op_def.output_arg_size() == 1) { // One output, no need for NameRangeMap if (is_list_output[0]) { strings::StrAppend(out, @@ -752,7 +861,7 @@ void OpInfo::GetOutput(string* out) const { ".UpdateStatus(_status_);\n", " return;\n"); strings::StrAppend(out, " }\n\n"); - for (int i = 0; i < op_def.output_arg_size(); ++i) { + for (int i = 0; i < graph_op_def.output_arg_size(); ++i) { const string arg_range = strings::StrCat( "_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]"); if (is_list_output[i]) { @@ -776,11 +885,13 @@ string OpInfo::GetConstructorBody() const { strings::StrAppend(&body, " ", return_on_error, "\n"); - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); - strings::StrAppend(&body, " auto _", arg.name(), " = ::tensorflow::ops::", - ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", - scope_str, ", ", AvoidCPPKeywords(arg.name()), ");\n"); + for (int i = 0; i < graph_op_def.input_arg_size(); ++i) { + const auto& arg(graph_op_def.input_arg(i)); + const auto& api_def_arg(api_def.in_arg(i)); + strings::StrAppend( + &body, " auto _", api_def_arg.rename_to(), " = ::tensorflow::ops::", + ArgIsList(arg) ? 
"AsNodeOutList" : "AsNodeOut", "(", scope_str, ", ", + AvoidCPPKeywords(api_def_arg.rename_to()), ");\n"); strings::StrAppend(&body, " ", return_on_error, "\n"); } @@ -791,19 +902,21 @@ string OpInfo::GetConstructorBody() const { &body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"", graph_op_def.name(), "\")\n"); const string spaces = " "; - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); - strings::StrAppend(&body, spaces, ".Input(_", arg.name(), ")\n"); + for (int i = 0; i < api_def.in_arg_size(); ++i) { + const auto& arg(api_def.in_arg(i)); + strings::StrAppend(&body, spaces, ".Input(_", arg.rename_to(), ")\n"); } - for (int i = 0; i < op_def.attr_size(); ++i) { + for (int i = 0; i < api_def.attr_size(); ++i) { const auto& graph_attr(graph_op_def.attr(i)); - const auto& attr(op_def.attr(i)); - if (inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) { + const auto& api_def_attr(api_def.attr(i)); + if (inferred_input_attrs.find(api_def_attr.name()) != + inferred_input_attrs.end()) { continue; } - const string attr_name = attr.has_default_value() - ? strings::StrCat("attrs.", attr.name(), "_") - : AvoidCPPKeywords(attr.name()); + const string attr_name = + api_def_attr.has_default_value() + ? 
strings::StrCat("attrs.", api_def_attr.rename_to(), "_") + : AvoidCPPKeywords(api_def_attr.rename_to()); strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ", attr_name, ")\n"); } @@ -845,10 +958,10 @@ void OpInfo::WriteClassDef(WritableFile* cc) const { TF_CHECK_OK(cc->Append(class_def)); } -void WriteCCOp(const OpDef& graph_op_def, const OpDef& interface_op_def, +void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def, const std::vector& aliases, WritableFile* h, WritableFile* cc) { - OpInfo op_info(graph_op_def, interface_op_def, aliases); + OpInfo op_info(graph_op_def, api_def, aliases); op_info.WriteClassDecl(h); op_info.WriteClassDef(cc); @@ -943,8 +1056,9 @@ string MakeInternal(const string& fname) { } // namespace -void WriteCCOps(const OpList& ops, const string& dot_h_fname, - const string& dot_cc_fname, const string& overrides_fnames) { +void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, + const string& dot_h_fname, const string& dot_cc_fname, + const string& overrides_fnames) { Env* env = Env::Default(); // Load the override map. @@ -984,24 +1098,23 @@ void WriteCCOps(const OpList& ops, const string& dot_h_fname, // code depends on it. if (graph_op_def.name() == "Const") continue; - // Incorporate overrides from override_map. - OpDef interface_op_def = graph_op_def; - const OpGenOverride* op_override = - override_map.ApplyOverride(&interface_op_def); + const auto* api_def = api_def_map.GetApiDef(graph_op_def.name()); + std::vector aliases; - if (op_override) { - if (op_override->skip()) continue; - aliases.assign(op_override->alias().begin(), op_override->alias().end()); - if (op_override->hide()) { - // Write hidden ops to _internal.h and _internal.cc. - WriteCCOp(graph_op_def, interface_op_def, aliases, internal_h.get(), - internal_cc.get()); - continue; - } + if (api_def->visibility() == ApiDef::SKIP) continue; + // First endpoint is canonical, the rest are aliases. 
+ for (int endpoint_i = 1; endpoint_i < api_def->endpoint_size(); + ++endpoint_i) { + aliases.push_back(api_def->endpoint(endpoint_i).name()); + } + if (api_def->visibility() == ApiDef::HIDDEN) { + // Write hidden ops to _internal.h and _internal.cc. + WriteCCOp(graph_op_def, *api_def, aliases, internal_h.get(), + internal_cc.get()); + continue; } - // This isn't a hidden op, write it to the main files. - WriteCCOp(graph_op_def, interface_op_def, aliases, h.get(), cc.get()); + WriteCCOp(graph_op_def, *api_def, aliases, h.get(), cc.get()); } FinishFiles(false, h.get(), cc.get(), op_header_guard); diff --git a/tensorflow/cc/framework/cc_op_gen.h b/tensorflow/cc/framework/cc_op_gen.h index fa5e004f0317d046d82bee005bdf9f17773a45f3..cea28990144b9371e8009ce13f912b44044f9aac 100644 --- a/tensorflow/cc/framework/cc_op_gen.h +++ b/tensorflow/cc/framework/cc_op_gen.h @@ -17,13 +17,15 @@ limitations under the License. #define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { /// Result is written to files dot_h and dot_cc. -void WriteCCOps(const OpList& ops, const string& dot_h_fname, - const string& dot_cc_fname, const string& overrides_fnames); +void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, + const string& dot_h_fname, const string& dot_cc_fname, + const string& overrides_fnames); } // namespace tensorflow diff --git a/tensorflow/cc/framework/cc_op_gen_main.cc b/tensorflow/cc/framework/cc_op_gen_main.cc index 3b80cf993eb9a5d5f4c41687577414e7216dd174..326d5668b8803ee39ffe24900c92e1db87b93601 100644 --- a/tensorflow/cc/framework/cc_op_gen_main.cc +++ b/tensorflow/cc/framework/cc_op_gen_main.cc @@ -16,7 +16,11 @@ limitations under the License. 
#include "tensorflow/cc/framework/cc_op_gen.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/types.h" @@ -24,10 +28,28 @@ namespace tensorflow { namespace { void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, - const std::string& overrides_fnames, bool include_internal) { + const std::string& overrides_fnames, bool include_internal, + const std::vector& api_def_dirs) { OpList ops; OpRegistry::Global()->Export(include_internal, &ops); - WriteCCOps(ops, dot_h, dot_cc, overrides_fnames); + ApiDefMap api_def_map(ops); + if (!api_def_dirs.empty()) { + Env* env = Env::Default(); + // Only load files that correspond to "ops". + for (const auto& op : ops.op()) { + for (const auto& api_def_dir : api_def_dirs) { + const std::string api_def_file_pattern = + io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt"); + if (env->FileExists(api_def_file_pattern).ok()) { + TF_CHECK_OK(api_def_map.LoadFile(env, api_def_file_pattern)); + } + } + } + } + + api_def_map.UpdateDocs(); + + WriteCCOps(ops, api_def_map, dot_h, dot_cc, overrides_fnames); } } // namespace @@ -35,18 +57,24 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, int main(int argc, char* argv[]) { tensorflow::port::InitMain(argv[0], &argc, &argv); - if (argc != 5) { + // TODO(annarev): Update this file to no longer take op_gen_overrides.pbtxt + // as an argument. 
+ if (argc != 6) { for (int i = 1; i < argc; ++i) { fprintf(stderr, "Arg %d = %s\n", i, argv[i]); } fprintf(stderr, - "Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal\n" + "Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal " + "api_def_dirs1,api_def_dir2 ...\n" " include_internal: 1 means include internal ops\n", argv[0]); exit(1); } bool include_internal = tensorflow::StringPiece("1") == argv[4]; - tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal); + std::vector api_def_dirs = tensorflow::str_util::Split( + argv[5], ",", tensorflow::str_util::SkipEmpty()); + tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal, + api_def_dirs); return 0; } diff --git a/tensorflow/cc/framework/cc_op_gen_test.cc b/tensorflow/cc/framework/cc_op_gen_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b7e720a5c7b343415eee1aa157b8de755a1e1a5 --- /dev/null +++ b/tensorflow/cc/framework/cc_op_gen_test.cc @@ -0,0 +1,195 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/cc/framework/cc_op_gen.h" + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +// TODO(annarev): Remove this op_gen_overrides.pbtxt reference. +// It is needed only because WriteCCOps takes it as an argument. +constexpr char kOverridesFnames[] = + "tensorflow/cc/ops/op_gen_overrides.pbtxt"; +constexpr char kBaseOpDef[] = R"( +op { + name: "Foo" + input_arg { + name: "images" + description: "Images to process." + } + input_arg { + name: "dim" + description: "Description for dim." + type: DT_FLOAT + } + output_arg { + name: "output" + description: "Description for output." + type: DT_FLOAT + } + attr { + name: "T" + type: "type" + description: "Type for images" + allowed_values { + list { + type: DT_UINT8 + type: DT_INT8 + } + } + default_value { + i: 1 + } + } + summary: "Summary for op Foo." + description: "Description for op Foo." +} +)"; + +void ExpectHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(s.contains(expected)) + << "'" << s << "' does not contain '" << expected << "'"; +} + +void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) { + EXPECT_FALSE(s.contains(expected)) + << "'" << s << "' contains '" << expected << "'"; +} + +void ExpectSubstrOrder(const string& s, const string& before, + const string& after) { + int before_pos = s.find(before); + int after_pos = s.find(after); + ASSERT_NE(std::string::npos, before_pos); + ASSERT_NE(std::string::npos, after_pos); + EXPECT_LT(before_pos, after_pos) + << before << " is not before " << after << " in " << s; +} + +// Runs WriteCCOps and stores output in (internal_)cc_file_path and +// (internal_)h_file_path. 
+void GenerateCcOpFiles(Env* env, const OpList& ops, + const ApiDefMap& api_def_map, string* h_file_text, + string* internal_h_file_text) { + const string& tmpdir = testing::TmpDir(); + + const auto h_file_path = io::JoinPath(tmpdir, "test.h"); + const auto cc_file_path = io::JoinPath(tmpdir, "test.cc"); + const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h"); + const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc"); + + WriteCCOps(ops, api_def_map, h_file_path, cc_file_path, kOverridesFnames); + + TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text)); + TF_ASSERT_OK( + ReadFileToString(env, internal_h_file_path, internal_h_file_text)); +} + +TEST(CcOpGenTest, TestVisibilityChangedToHidden) { + const string api_def = R"( +op { + graph_op_name: "Foo" + visibility: HIDDEN +} +)"; + Env* env = Env::Default(); + OpList op_defs; + protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT + ApiDefMap api_def_map(op_defs); + + string h_file_text, internal_h_file_text; + // Without ApiDef + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectHasSubstr(h_file_text, "class Foo"); + ExpectDoesNotHaveSubstr(internal_h_file_text, "class Foo"); + + // With ApiDef + TF_ASSERT_OK(api_def_map.LoadApiDef(api_def)); + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectHasSubstr(internal_h_file_text, "class Foo"); + ExpectDoesNotHaveSubstr(h_file_text, "class Foo"); +} + +TEST(CcOpGenTest, TestArgNameChanges) { + const string api_def = R"( +op { + graph_op_name: "Foo" + arg_order: "dim" + arg_order: "images" +} +)"; + Env* env = Env::Default(); + OpList op_defs; + protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT + + ApiDefMap api_def_map(op_defs); + string cc_file_text, h_file_text; + string internal_cc_file_text, internal_h_file_text; + // Without ApiDef + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + 
&internal_h_file_text); + ExpectSubstrOrder(h_file_text, "Input images", "Input dim"); + + // With ApiDef + TF_ASSERT_OK(api_def_map.LoadApiDef(api_def)); + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectSubstrOrder(h_file_text, "Input dim", "Input images"); +} + +TEST(CcOpGenTest, TestEndpoints) { + const string api_def = R"( +op { + graph_op_name: "Foo" + endpoint { + name: "Foo1" + } + endpoint { + name: "Foo2" + } +} +)"; + Env* env = Env::Default(); + OpList op_defs; + protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT + + ApiDefMap api_def_map(op_defs); + string cc_file_text, h_file_text; + string internal_cc_file_text, internal_h_file_text; + // Without ApiDef + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectHasSubstr(h_file_text, "class Foo {"); + ExpectDoesNotHaveSubstr(h_file_text, "class Foo1"); + ExpectDoesNotHaveSubstr(h_file_text, "class Foo2"); + + // With ApiDef + TF_ASSERT_OK(api_def_map.LoadApiDef(api_def)); + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectHasSubstr(h_file_text, "class Foo1"); + ExpectHasSubstr(h_file_text, "typedef Foo1 Foo2"); + ExpectDoesNotHaveSubstr(h_file_text, "class Foo {"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 09fadfcab51575798286876f9a4e0ee9a60940ac..13a3bba5e6d5ca19ff3f0eca76665ba7d3ab628d 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -196,6 +196,18 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper); +Status LRNGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs){ + internal::LRNGrad::Attrs grad_attrs; + + auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0), + 
grad_attrs); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("LRN", LRNGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index ac66f51cf01911957722e94ca28e8e78dc6de2ed..f9063e836509669d81d03b1d2f0d32d1166b6eca 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -191,5 +191,12 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { RunTest(x, x_init_value, y, y_shape); } +TEST_F(NNGradTest, LRN){ + TensorShape x_shape({1, 1, 2, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = LRN(scope_, x); + RunTest(x, x_shape, y, x_shape); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/tag_constants.h b/tensorflow/cc/saved_model/tag_constants.h index 2b0b2d5c7fb33768494c1781669c1adcb875a579..b71cb263ca42dab7e830c1880ec4b311bc272f82 100644 --- a/tensorflow/cc/saved_model/tag_constants.h +++ b/tensorflow/cc/saved_model/tag_constants.h @@ -21,6 +21,9 @@ namespace tensorflow { /// Tag for the `gpu` graph. constexpr char kSavedModelTagGpu[] = "gpu"; +/// Tag for the `tpu` graph. +constexpr char kSavedModelTagTpu[] = "tpu"; + /// Tag for the `serving` graph. 
constexpr char kSavedModelTagServe[] = "serve"; diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index a9a6ea84319a18a8fbce648391bf5918ff6d9a08..5740c040e309bad8d7e3bdc468c09a3323fb99e0 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -24,7 +24,6 @@ tf_cc_test( srcs = ["runtime_test.cc"], deps = [ ":runtime", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -111,6 +110,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index ae22f7edc423247b34895411d19d7a3c21f86d4f..53da2881b60db9ad39565567623eb86f754559af 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -101,21 +101,8 @@ Status ComputeArgSizes(const CompileResult& compile_result, std::vector* arg_sizes) { const xla::ProgramShape& ps = compile_result.program_shape; for (int i = 0; i < ps.parameters_size(); ++i) { - if (i == ps.parameters_size() - 1 && compile_result.has_context_arg) { - // If the compiled function needs a XlaLocalRuntimeContext* arg, it's - // always last, and must be represented as an opaque type. 
- const xla::PrimitiveType type = ps.parameters(i).element_type(); - if (type != xla::OPAQUE) { - return errors::InvalidArgument( - "expected final context arg to be opaque, but got type: ", - xla::PrimitiveType_Name(type), ", from program shape: ", - xla::ShapeUtil::HumanString(ps)); - } - arg_sizes->push_back(-1); - } else { - arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( - ps.parameters(i), compile_result.pointer_size)); - } + arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( + ps.parameters(i), compile_result.pointer_size)); } return Status::OK(); } @@ -165,11 +152,6 @@ string RewriteWithName(const string& name, string code, Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, const CompileResult& compile_result, string* methods) { size_t num_args = ps.parameters_size(); - if (compile_result.has_context_arg) { - // If the compiled function needs a XlaLocalRuntimeContext* arg, it's - // always last, and is set in the class constructor. - num_args--; - } if (config.feed_size() != num_args) { return errors::InvalidArgument("mismatch between feed_size(", config.feed_size(), ") and num_args(", @@ -418,7 +400,7 @@ namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. 
extern "C" void {{ENTRY}}( void* result, const xla::ExecutableRunOptions* run_options, - const void** args, void** temps); + const void** args, void** temps, tensorflow::int64* profile_counters); {{NS_START}} // {{CLASS}} represents a computation previously specified in a @@ -474,7 +456,6 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { data->temp_sizes = TempSizes(); data->num_temps = kNumTemps; data->result_index = kResultIndex; - data->requires_runtime_context = {{HAS_CONTEXT_ARG}}; data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); @@ -483,7 +464,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} {{CLASS}}(const {{CLASS}}&) = delete; @@ -496,8 +477,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in // any AllocMode. Must be called before Run to have an affect. Must be - // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, - // to set the argument buffers. + // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional + // argument, to set the argument buffers. // // T* argN_data() // Returns the buffer of type T for positional argument N. @@ -560,8 +541,6 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, {"{{CLASS}}", opts.class_name}, {"{{ENTRY}}", compile_result.entry_point}, - {"{{HAS_CONTEXT_ARG}}", - compile_result.has_context_arg ? 
"true" : "false"}, {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 0f6114666fcc89c631434527d2ae8c92c039ffea..75026c57c04a64186a1e5be6c41e4dd7de8520b7 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -145,11 +145,9 @@ TEST(GenerateHeader, Golden) { { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), - xla::ShapeUtil::MakeOpaqueShape(), }, xla::ShapeUtil::MakeTupleShape( {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); - compile_result.has_context_arg = true; compile_result.entry_point = "entry_point"; compile_result.pointer_size = 8; string header; diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 65f342ce27ef09092f252f791973f245a8cdd6f3..35e50433d63a549bc6fb6a2be9015d7c471509d0 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -19,7 +19,7 @@ namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( void* result, const xla::ExecutableRunOptions* run_options, - const void** args, void** temps); + const void** args, void** temps, tensorflow::int64* profile_counters); namespace foo { namespace bar { @@ -48,7 +48,7 @@ namespace bar { // is guaranteed that no thread may call a non-const method. // // The logical function signature is: -// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> (u32[5,6]) +// ((unknown): f32[1,2], (unknown): s64[3,4]) -> (u32[5,6]) // // Memory stats: // arg bytes total: 104 @@ -58,11 +58,11 @@ namespace bar { class MyClass : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. 
- static constexpr size_t kNumArgs = 3; + static constexpr size_t kNumArgs = 2; // Byte size of each argument buffer. There are kNumArgs entries. static const intptr_t* ArgSizes() { - static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96, -1}; + static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96}; return kArgSizes; } @@ -77,7 +77,6 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { data->temp_sizes = TempSizes(); data->num_temps = kNumTemps; data->result_index = kResultIndex; - data->requires_runtime_context = true; data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); @@ -86,7 +85,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} MyClass(const MyClass&) = delete; @@ -99,8 +98,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in // any AllocMode. Must be called before Run to have an affect. Must be - // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, - // to set the argument buffers. + // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional + // argument, to set the argument buffers. // // T* argN_data() // Returns the buffer of type T for positional argument N. @@ -236,8 +235,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // Shape of the args and results. 
static const xla::ProgramShape* StaticProgramShape() { static const xla::ProgramShape* kShape = []() { - static const char kProto[] = {10,12,16,11,26,2,1,2,42,4,10,2,1,0,10,12,16,5,26,2,3,4,42,4,10,2,1,0,10,2,16,14,18,16,16,13,34,12,16,8,26,2,5,6,42,4,10,2,1,0}; - static constexpr int kProtoSize = 50; + static const char kProto[] = {10,12,16,11,26,2,1,2,42,4,10,2,1,0,10,12,16,5,26,2,3,4,42,4,10,2,1,0,18,16,16,13,34,12,16,8,26,2,5,6,42,4,10,2,1,0}; + static constexpr int kProtoSize = 46; xla::ProgramShape* shape = new xla::ProgramShape; shape->ParseFromArray(kProto, kProtoSize); return shape; diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 2b8cc6024cb85e4f6269313927ff66d1d9a1cf79..c87f2b75dfa18ad5c3eda4bd6fcbcb3083ef73fd 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -94,9 +94,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform) .ValueOrDie(); xla::Computation computation; - TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client, - &computation, - &compile_result->has_context_arg)); + TF_RETURN_IF_ERROR( + ConvertGraphDefToXla(graph_def, config, client, &computation)); if (!flags.out_session_module.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr module, computation.Snapshot()); diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index 965c2960816b3acc8d2209e6824d88647de0ce14..e03c5b1aa77c1262ed903aae3072ef65f34d80a2 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -34,7 +34,6 @@ struct CompileResult { // Contains object file and meta-info. std::unique_ptr aot; xla::ProgramShape program_shape; // Static shape of args and results. - bool has_context_arg = false; // Is last arg XlaLocalRuntimeContext? string entry_point; // Name of generated function. int pointer_size = 0; // Size of a pointer in bytes. 
}; diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/aot/runtime_test.cc index ac79c278c1fdf8b6aedcb52121c767b8ba0ad358..6d603a02eb4ceade6832ba67b2981814ee25327a 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/aot/runtime_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/aot/runtime.h" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 6b037f276ad1d6771b904bb970f45f32ae9531b8..413efd9cea3b6f71574615ad9ca92471ff925781 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -70,7 +70,7 @@ TEST(TFCompileTest, Add) { // Run tests that use set_argN_data separately, to avoid accidentally re-using // non-existent buffers. TEST(TFCompileTest, Add_SetArg) { - AddComp add(AddComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); int32 arg_x = 10; int32 arg_y = 32; @@ -258,7 +258,7 @@ TEST(TFCompileTest, MatMul2_SetArg) { Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); foo::bar::MatMulComp matmul( - foo::bar::MatMulComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + foo::bar::MatMulComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); matmul.set_thread_pool(&device); // Test using the set_argN_data() methods. 
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 1e22b760b8a4189165a59ac307374277474bbc31..542451ed2d14fbceca00c6ccb6e28c1c3a0d4321 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -152,7 +152,7 @@ def tf_library(name, graph, config, " --target_triple=" + target_llvm_triple() + " --out_header=$(@D)/" + header_file + " --out_object=$(@D)/" + object_file + - flags), + " " + flags), tools=[tfcompile_tool], visibility=visibility, testonly=testonly, @@ -189,7 +189,7 @@ def tf_library(name, graph, config, " --cpp_class=" + cpp_class + " --target_triple=" + target_llvm_triple() + " --out_session_module=$(@D)/" + session_module_pb + - flags), + " " + flags), tools=[tfcompile_tool], visibility=visibility, testonly=testonly, @@ -267,7 +267,6 @@ def tf_library(name, graph, config, srcs=[test_file], deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/tf2xla:xla_local_runtime_context", "@org_tensorflow//tensorflow/compiler/aot:runtime", "@org_tensorflow//tensorflow/compiler/aot:tf_library_test_main", "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", @@ -313,7 +312,6 @@ def tf_library(name, graph, config, linkopts = if_android(["-pie", "-s"]), deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/tf2xla:xla_local_runtime_context", "@org_tensorflow//tensorflow/compiler/aot:benchmark", "@org_tensorflow//tensorflow/compiler/aot:runtime", "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index bf7d9cf14d10f41aa48ea594a8d63db97b9973e1..026a1bf879d373fd0f5f4444b3ce10d01702f82b 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -251,6 +251,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", 
"//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 22899ebeebc929055518893b358f7950d380d6f6..dc06b7a4025ddc83bf766b702036297203c16e55 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -48,6 +49,52 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; namespace { +bool AreAllParentsConst(const Node& n, + const gtl::FlatSet& runtime_const_nodes) { + if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") { + // If the current node is itself a cast-to-const, no need + // to look at the incoming edges. + return true; + } + + bool all_parents_const = true; + bool atleast_one_non_control_edge = false; + for (const Edge* in : n.in_edges()) { + atleast_one_non_control_edge = + atleast_one_non_control_edge || !in->IsControlEdge(); + if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) { + all_parents_const = false; + break; + } + } + return all_parents_const && atleast_one_non_control_edge; +} + +void MarkGuaranteedConstants( + const Graph& graph, + const std::vector>& src_arg_pairs) { + gtl::FlatSet guaranteed_const_nodes; + std::vector srcs; + srcs.reserve(src_arg_pairs.size()); + for (const auto& src_arg : src_arg_pairs) { + srcs.push_back(src_arg.first); + } + ReverseDFSFrom(graph, srcs, /*enter=*/nullptr, + /*leave=*/[&guaranteed_const_nodes](Node* n) { + // TODO(vinuraja): Doesn't work in the presence of loops. 
+ if (AreAllParentsConst(*n, guaranteed_const_nodes)) { + guaranteed_const_nodes.insert(n); + } + }); + + for (auto& src_arg : src_arg_pairs) { + if (guaranteed_const_nodes.count(src_arg.first) != 0) { + VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString(); + src_arg.second->AddAttr("_is_guaranteed_constant", true); + } + } +} + // A node/slot pair. // TODO(phawkins): is there a common definition of this? struct NodeSlot { @@ -175,9 +222,11 @@ Status Encapsulator::SplitIntoSubgraphs() { // Map from input graph nodes to subgraph nodes. std::unordered_map node_images; + std::vector> src_arg_pairs; // Copy all marked nodes to a subgraph. Do nothing for unmarked nodes. for (Node* node : graph_in_->op_nodes()) { string func_id = GetFunctionNameAttr(node); + if (func_id.empty()) continue; Subgraph& subgraph = subgraphs_[func_id]; @@ -276,11 +325,13 @@ Status Encapsulator::SplitIntoSubgraphs() { kArgOp); builder.Attr("T", dtype); builder.Attr("index", arg_index); + s = builder.Finalize(&arg_def); if (!s.ok()) return s; Node* arg = dst_subgraph.graph->AddNode(arg_def, &s); if (!s.ok()) return s; + src_arg_pairs.push_back({edge->src(), arg}); dst_subgraph.args.push_back(arg); } @@ -292,6 +343,8 @@ Status Encapsulator::SplitIntoSubgraphs() { } } + MarkGuaranteedConstants(*graph_in_, src_arg_pairs); + for (auto& entry : subgraphs_) { FixupSourceAndSinkEdges(entry.second.graph.get()); } diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 4a1dbaf05dc7824835f3567c6abcf48222720230..717efb360185f1ce26ee1e9adb0ee5bf7f4799f8 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -398,5 +398,109 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) { EXPECT_EQ(expected_edges, GraphEdges(*graph)); } +const Node* FindNodeByName(const Graph& graph, const string& name) { + for (const Node* node : 
graph.nodes()) { + if (node->name() == name) return node; + } + return nullptr; +} + +bool HasGuaranteeConstAttr(const Node& n) { + bool is_guaranteed_constant = false; + if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", + &is_guaranteed_constant) + .ok()) { + return false; + } + return is_guaranteed_constant; +} + +TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { + Scope root = Scope::NewRootScope().ExitOnError().WithDevice( + "/job:localhost/replica:0/task:0/cpu:0"); + auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); + auto const_x2 = ops::Const(root.WithOpName("const_x2"), 10.0f); + auto const_guarantee_x1 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); + auto add1 = ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_x2); + add1.node()->AddAttr("_encapsulate", "encapsulate1"); + + Graph graph_before(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph_before)); + + std::unique_ptr graph_after; + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + int guaranteed_consts = 0; + TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( + "_encapsulate", graph_before, + /*rewrite_subgraph_fn=*/ + [&guaranteed_consts](std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + for (const Node* n : graph->nodes()) { + if (n->type_string() == "_Arg" && + StringPiece(n->name()).starts_with("const")) { + ++guaranteed_consts; + EXPECT_TRUE(HasGuaranteeConstAttr(*n)); + } else { + EXPECT_FALSE(HasGuaranteeConstAttr(*n)); + } + } + return Status::OK(); + }, + /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph_after, &library)); + EXPECT_EQ(2, guaranteed_consts); +} + +TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { + Scope root = Scope::NewRootScope().ExitOnError().WithDevice( + "/job:localhost/replica:0/task:0/cpu:0"); + auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); + 
auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); + auto const_guarantee_x1 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); + auto const_guarantee_x2 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2); + auto const_guarantee_add1 = ops::Add(root.WithOpName("const_guarantee_add1"), + const_guarantee_x1, const_guarantee_x2); + auto add2 = ops::Add(root.WithOpName("add2"), const_guarantee_x1, x2); + auto mul1 = ops::Mul(root.WithOpName("mul1"), const_guarantee_add1, add2); + mul1.node()->AddAttr("_encapsulate", "encapsulate1"); + + Graph graph_before(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph_before)); + + std::unique_ptr graph_after; + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + int guaranteed_consts = 0; + TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( + "_encapsulate", graph_before, + /*rewrite_subgraph_fn=*/ + [&guaranteed_consts](std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + for (const Node* n : graph->nodes()) { + if (n->type_string() == "_Arg" && + StringPiece(n->name()).starts_with("const")) { + ++guaranteed_consts; + EXPECT_TRUE(HasGuaranteeConstAttr(*n)); + } else { + EXPECT_FALSE(HasGuaranteeConstAttr(*n)); + } + } + return Status::OK(); + }, + /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph_after, &library)); + // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const + // and another non-const, so overall non-const. 
+ EXPECT_EQ(1, guaranteed_consts); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 459a582e157f5ddc63997ca93e7c0294293517d3..9bea5663319c8a25249fdc265cee0191556a7c04 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -16,7 +16,6 @@ cc_library( "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 27c5da08c112664d361b5f969d100eed7b9df65c..39a770ab7b9ae56bd24865b86c69331b0a38ccec 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -103,7 +102,6 @@ xla::StatusOr XlaAllocator::Allocate( } void* data = reinterpret_cast(const_cast(t.tensor_data().data())); - TF_RET_CHECK(data != nullptr); tensors_[data] = t; return gpu::DeviceMemoryBase(data, size); } @@ -111,7 +109,6 @@ xla::StatusOr XlaAllocator::Allocate( Status XlaAllocator::RegisterArgument(const Tensor* t) { void* data = reinterpret_cast(const_cast(t->tensor_data().data())); - TF_RET_CHECK(data != nullptr); tensors_[data] = *t; return Status::OK(); } @@ -257,7 +254,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId); - options.local_executable_has_hybrid_result = true; const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; @@ -268,7 +264,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // Builds an XLA allocator for the device. XlaAllocator xla_allocator(client->platform(), ctx); - XlaLocalRuntimeContext local_runtime_context; std::unique_ptr output; // Build xla::ShapedBuffers that point directly to the Tensor buffers. @@ -301,18 +296,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t)); } - // Make the final parameter point at local_runtime_context. 
- if (kernel->requires_runtime_context) { - gpu::DeviceMemoryBase local_runtime_context_dmem( - &local_runtime_context, sizeof(local_runtime_context)); - arg_buffers.push_back( - xla::ShapedBuffer::MakeArrayShapedBuffer( - xla::ShapeUtil::MakeOpaqueShape(), client->platform(), - client->default_device_ordinal(), local_runtime_context_dmem) - .ConsumeValueOrDie()); - arg_ptrs.push_back(arg_buffers.back().get()); - } - // Execute the computation. VLOG(2) << "Executing computation."; xla::ExecutableRunOptions run_options; @@ -324,12 +307,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { auto run_result = executable->Run(arg_ptrs, run_options); OP_REQUIRES(ctx, run_result.ok(), run_result.status()); - if (local_runtime_context.error) { - ctx->CtxFailure(errors::InvalidArgument("Compiled kernel returned error: ", - local_runtime_context.error_msg)); - return; - } - output = run_result.ConsumeValueOrDie()->release(); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc index 09aee39d8cd0e910320674fcfd8a7884ce2fdd04..4bc209b7ecf499d82e7567f7eff12b17cefa9863 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc @@ -39,21 +39,23 @@ static void AllocateFlags() { flags->tf_xla_min_cluster_size = 2; flags->tf_xla_max_cluster_size = std::numeric_limits::max(); flags->tf_xla_clustering_debug = false; - flag_list = new std::vector({ - Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, - "Control compilation of operators into XLA computations on CPU and " - "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " - "things very likely to be improved; 2 = on for everything. 
" - "Experimental."), - Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size, - "Minimum number of operators in an XLA compilation. Ignored for " - "operators placed on an XLA device or operators explicitly marked " - "for compilation."), - Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size, - "Maximum number of operators in an XLA compilation."), - Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, - "Dump graphs during XLA compilation."), - }); + flags->tf_xla_cpu_global_jit = false; + flag_list = new std::vector( + {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, + "Control compilation of operators into XLA computations on CPU and " + "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " + "things very likely to be improved; 2 = on for everything. " + "Experimental."), + Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size, + "Minimum number of operators in an XLA compilation. Ignored for " + "operators placed on an XLA device or operators explicitly marked " + "for compilation."), + Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size, + "Maximum number of operators in an XLA compilation."), + Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, + "Dump graphs during XLA compilation."), + Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit, + "Enables global JIT compilation for CPU via SessionOptions.")}); xla::legacy_flags::ParseFlagsFromEnv(*flag_list); } diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h index 24f80507428b6742c64d3d7e96e4b1c540eda01b..e1ccd7ddb8706ca445b6811ca1fec369af7cd5d5 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h @@ -46,6 +46,8 @@ typedef struct { int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA // 
compilation. bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. + bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU + // via SessionOptions. } MarkForCompilationPassFlags; // Return a pointer to the MarkForCompilationPassFlags struct; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 78d0aa86a8fae9a0c6035bdc579ef800337df917..1f311a3aedbf7711ce6a081671f5848a81f2bd85 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -172,10 +172,15 @@ bool HasResourceInputOrOutput(const Node& node) { DT_RESOURCE) != node.output_types().end(); } +struct NodeCompare { + bool operator()(const Node* a, const Node* b) { return a->id() < b->id(); } +}; +using OrderedNodeSet = std::set; + Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function& is_compilable_fn, - std::unordered_set* candidates) { + OrderedNodeSet* candidates) { OptimizerOptions opts; std::unique_ptr pflr( new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION, @@ -210,6 +215,13 @@ Status FindCompilationCandidates( !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) { continue; } + // _Retval nodes in a top-level function represent fetches. + // Do not compile them. 
+ if (node->type_string() == "_Retval") { + VLOG(2) << "Compilation rejected node: return value " << node->name() + << ": " << node->type_string(); + continue; + } candidates->insert(node); } return Status::OK(); @@ -290,9 +302,11 @@ Status MarkForCompilationPass::Run( global_jit_level = static_cast(flags->tf_xla_auto_jit); } + bool cpu_global_jit = flags->tf_xla_cpu_global_jit; const FunctionLibraryDefinition* fld = options.flib_def; - auto is_compilable = [global_jit_level, fld](const Node* node, - const DeviceType& device_type) { + + auto is_compilable = [global_jit_level, cpu_global_jit, fld]( + const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { @@ -315,7 +329,11 @@ Status MarkForCompilationPass::Run( if (status.ok()) return compile; // Otherwise use the value of global_jit_level. - return registration->enable_jit_by_default && global_jit_level > 0; + // Ignore enable_jit_by_default if global jit compilation for CPU + // is explicitly requested via tf_xla_cpu_global_jit flag + bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; + return (ignore_registration || registration->enable_jit_by_default) && + global_jit_level > 0; }; return RunImpl(options, is_compilable); } @@ -341,7 +359,7 @@ Status MarkForCompilationPass::RunImpl( Graph* graph = options.graph->get(); - std::unordered_set compilation_candidates; + OrderedNodeSet compilation_candidates; TF_RETURN_IF_ERROR(FindCompilationCandidates( *graph, options.flib_def, (options.session_options != nullptr) ? 
options.session_options->env @@ -556,6 +574,7 @@ Status MarkForCompilationPass::RunImpl( if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation || registration->requires_compilation) { string& name = cluster_names[cluster]; + if (name.empty()) { name = strings::StrCat("cluster_", cluster_sequence_num++); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index b3d258aea177fbefa4bae51d8156da2ff86c9032..454f0aeae98d7afd51f12b2cfb1810de275a57f7 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -525,5 +525,32 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { "+-- c\n")); } +TEST(XlaCompilationTest, Retval) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::UnaryOp("_Retval", b, + builder.opts() + .WithName("R") + .WithAttr("T", DT_FLOAT) + .WithAttr("index", 0)); + + TF_EXPECT_OK(builder.ToGraph(graph.get())); + } + + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + + EXPECT_EQ(2, clusters.size()); + EXPECT_TRUE(clusters.find("R") == clusters.cend()); + EXPECT_EQ(clusters["A"], clusters["B"]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 23368b6c76a363882956577a20c1bd041211d234..3717c2cc24283e0b218f92ec820d16893cbe0c35 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -214,23 +214,15 @@ Status XlaCompilationCache::BuildExecutable( const 
XlaCompiler::CompilationResult& result, std::unique_ptr* executable) { VLOG(2) << "Compiling to local executable"; - xla::Shape opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); std::vector argument_layouts( result.xla_input_shapes.size()); for (int i = 0; i < result.xla_input_shapes.size(); ++i) { argument_layouts[i] = &result.xla_input_shapes[i]; } - if (result.requires_runtime_context) { - // The final arg is the XlaLocalRuntimeContext*. - argument_layouts.push_back(&opaque_shape); - } xla::ExecutableBuildOptions build_options; build_options.set_device_ordinal(client_->default_device_ordinal()); - build_options.set_platform(client_->platform()); build_options.set_result_layout(result.xla_output_shape); - build_options.set_has_hybrid_result( - options.local_executable_has_hybrid_result); auto compile_result = client_->Compile(*result.computation, argument_layouts, build_options); diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index fed2c92d763c33aad3c5b3f07c1f33364c797793..c936222f32056e92efced82d5adb3a96c8041a17 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -71,12 +71,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, void* dst_ptr = DMAHelper::base(device_tensor); se::DeviceMemoryBase dev_dst_ptr(dst_ptr, total_bytes); - Status status = Status::OK(); + Status status; stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. 
- if (!stream_->BlockHostUntilDone()) { + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { status = xla::InternalError( - "Failed to complete data transfer on stream %p", stream_); + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); } done(status); @@ -105,12 +107,14 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, se::DeviceMemoryBase dev_src_ptr(src_ptr, total_bytes); void* dst_ptr = DMAHelper::base(cpu_tensor); - Status status = Status::OK(); + Status status; stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. - if (!stream_->BlockHostUntilDone()) { + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { status = xla::InternalError( - "Failed to complete data transfer on stream %p", stream_); + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); } done(status); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 284ecbf97d3234f0159ba2bb807c976e2b5c2ac2..4f458ecff8f6523a23ca59e0cecb485a7988efad 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -129,6 +129,21 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "cholesky_op_test", + size = "small", + srcs = ["cholesky_op_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "clustering_test", size = "small", @@ -264,6 +279,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "image_ops_test", + size = "small", + srcs = ["image_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + 
"//tensorflow/python:image_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "lrn_ops_test", size = "medium", @@ -352,7 +380,15 @@ tf_xla_py_test( size = "small", srcs = ["random_ops_test.py"], # TODO(b/31361304): enable RNG ops on GPU when parallelized. - disabled_backends = ["gpu"], + disabled_backends = [ + "gpu", + "cpu", + ], + tags = [ + "manual", + "no_oss", + "notap", + ], deps = [ ":xla_test", "//tensorflow/python:framework_for_generated_wrappers", @@ -401,6 +437,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "scan_ops_test", + size = "small", + srcs = ["scan_ops_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "segment_reduction_ops_test", size = "medium", @@ -442,6 +492,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "stateless_random_ops_test", + size = "small", + srcs = ["stateless_random_ops_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/contrib/stateless", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "tensor_array_ops_test", size = "small", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index d412c572ae16b84c2434819aa0a2d881defef5f9..654dc15e86b21c7742d49281d53c1a75e6a45d3b 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -366,16 +366,52 @@ class BinaryOpsTest(XLATestCase): self._testBinary( gen_math_ops._real_div, - np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype), - np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype), + np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j], dtype=dtype), + np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j], dtype=dtype), + 
expected=np.array( + [1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2], + dtype=dtype)) + + # Test inf/nan scenarios. + self._testBinary( + gen_math_ops._real_div, + np.array([4 + 3j, 4, 3j, -4, -4j, 2 - 3j], dtype=dtype), + np.array([0, 0, 0, 0, 0, 0], dtype=dtype), expected=np.array( [ - 1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2, - float("inf") + dtype(1 + 1j) / 0, + dtype(1) / 0, + dtype(1j) / 0, + dtype(-1) / 0, + dtype(-1j) / 0, + dtype(1 - 1j) / 0 ], dtype=dtype)) - # TODO(b/65408531): support+test pow for cplx + atan2_supported = self.device == "XLA_GPU" + if atan2_supported: + self._testBinary( + math_ops.pow, + dtype(3 + 2j), + dtype(4 - 5j), + expected=np.power(dtype(3 + 2j), dtype(4 - 5j))) + self._testBinary( # empty rhs + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[0, 2], dtype=dtype), + expected=np.zeros(shape=[0, 2], dtype=dtype)) + self._testBinary( # to zero power + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[1, 2], dtype=dtype), + expected=np.ones(shape=[1, 2], dtype=dtype)) + lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype) + rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype) + scalar = dtype(2 + 2j) + self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs)) + self._testBinary( + math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs)) + self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar)) lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) @@ -385,7 +421,9 @@ class BinaryOpsTest(XLATestCase): self._testBinary( gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) - # TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow) + if atan2_supported: + self._testBinary( + gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) self._testBinary( gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) diff --git 
a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index 5e06f9a72401935b9681c35a164b51f50a8538ae..035cdea1786d39f3d21bb63be5c8ccffe1608bdf 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -35,6 +35,9 @@ from tensorflow.python.platform import googletest class CategoricalTest(XLATestCase): """Test cases for random-number generating operators.""" + def output_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + def _chi2(self, expected, actual): """Returns Chi2 GOF statistic.""" actual = np.asarray(actual) @@ -55,7 +58,8 @@ class CategoricalTest(XLATestCase): """ with self.test_session() as sess, self.test_scope(): random_seed.set_random_seed(1618) - op = random_ops.multinomial(logits, num_samples) + op = random_ops.multinomial(logits, num_samples, + output_dtype=dtypes.int32) d = sess.run(op) batch_size, num_classes = logits.shape @@ -73,11 +77,11 @@ class CategoricalTest(XLATestCase): return freqs_mat - def _testRngIsNotConstant(self, rng, dtype): + def _testRngIsNotConstant(self, rng, dtype, output_dtype): # Tests that 'rng' does not always return the same value. with self.test_session() as sess: with self.test_scope(): - x = rng(dtype) + x = rng(dtype, output_dtype) # The random-number generator, if working correctly, should produce the # same output multiple times with low probability. 
@@ -92,21 +96,25 @@ class CategoricalTest(XLATestCase): (not np.array_equal(y, w))) def testCategoricalIsNotConstant(self): - def rng(unused_dtype): - return random_ops.multinomial([[1., 1., 1.]], 10) + def rng(dtype, output_dtype): + return random_ops.multinomial(np.array([[1., 1., 1.]], dtype=dtype), 10, + output_dtype=output_dtype) - dtype = dtypes.float32 - self._testRngIsNotConstant(rng, dtype) + dtype = np.float32 + for output_dtype in self.output_dtypes(): + self._testRngIsNotConstant(rng, dtype, output_dtype) def testCategoricalIsInRange(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session() as sess: - with self.test_scope(): - x = random_ops.multinomial( - array_ops.ones(shape=[1, 20], dtype=dtype), 1000) - y = sess.run(x) - self.assertTrue((y >= 0).sum() == 1000) - self.assertTrue((y < 20).sum() == 1000) + for dtype in self.float_types: + for output_dtype in self.output_dtypes(): + with self.test_session() as sess: + with self.test_scope(): + x = random_ops.multinomial( + array_ops.ones(shape=[1, 20], dtype=dtype), 1000, + output_dtype=output_dtype) + y = sess.run(x) + self.assertTrue((y >= 0).sum() == 1000) + self.assertTrue((y < 20).sum() == 1000) def testSamplingCorrectness(self): np.random.seed(1618) # Make it reproducible. diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5010fe5e21d0782e68d4e6d5bf6b4df1b44793a3 --- /dev/null +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -0,0 +1,126 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.tf.Cholesky.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class CholeskyOpTest(XLATestCase): + + def _verifyCholeskyBase(self, sess, placeholder, x, chol, verification, atol): + chol_np, verification_np = sess.run([chol, verification], {placeholder: x}) + self.assertAllClose(x, verification_np, atol=atol) + self.assertShapeEqual(x, chol) + # Check that the cholesky is lower triangular, and has positive diagonal + # elements. + if chol_np.shape[-1] > 0: + chol_reshaped = np.reshape(chol_np, (-1, chol_np.shape[-2], + chol_np.shape[-1])) + for chol_matrix in chol_reshaped: + self.assertAllClose(chol_matrix, np.tril(chol_matrix), atol=atol) + self.assertTrue((np.diag(chol_matrix) > 0.0).all()) + + def _verifyCholesky(self, x, atol=1e-6): + # Verify that LL^T == x. 
+ with self.test_session() as sess: + placeholder = array_ops.placeholder( + dtypes.as_dtype(x.dtype), shape=x.shape) + with self.test_scope(): + chol = linalg_ops.cholesky(placeholder) + verification = math_ops.matmul(chol, chol, adjoint_b=True) + self._verifyCholeskyBase(sess, placeholder, x, chol, verification, atol) + + def testBasic(self): + data = np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]) + for dtype in self.float_types: + self._verifyCholesky(data.astype(dtype)) + + def testBatch(self): + for dtype in self.float_types: + simple_array = np.array( + [[[1., 0.], [0., 5.]]], dtype=dtype) # shape (1, 2, 2) + self._verifyCholesky(simple_array) + self._verifyCholesky(np.vstack((simple_array, simple_array))) + odd_sized_array = np.array( + [[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]], dtype=dtype) + self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array))) + + # Generate random positive-definite matrices. + matrices = np.random.rand(10, 5, 5).astype(dtype) + for i in xrange(10): + matrices[i] = np.dot(matrices[i].T, matrices[i]) + self._verifyCholesky(matrices, atol=1e-4) + + def testNonSquareMatrix(self): + for dtype in self.float_types: + with self.assertRaises(ValueError): + linalg_ops.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]], dtype=dtype)) + with self.assertRaises(ValueError): + linalg_ops.cholesky( + np.array( + [[[1., 2., 3.], [3., 4., 5.]], [[1., 2., 3.], [3., 4., 5.]]], + dtype=dtype)) + + def testWrongDimensions(self): + for dtype in self.float_types: + tensor3 = constant_op.constant([1., 2.], dtype=dtype) + with self.assertRaises(ValueError): + linalg_ops.cholesky(tensor3) + with self.assertRaises(ValueError): + linalg_ops.cholesky(tensor3) + + @unittest.skip("Test is slow") + def testLarge(self): + n = 200 + shape = (n, n) + data = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag( + np.ones(n).astype(np.float32)) + self._verifyCholesky(data, atol=1e-4) + + def testMatrixConditionNumbers(self): + for dtype in 
self.float_types: + condition_number = 1000 + size = 20 + + # Generate random positive-definite symmetric matrices, and take their + # Eigendecomposition. + matrix = np.random.rand(size, size) + matrix = np.dot(matrix.T, matrix) + _, w = np.linalg.eigh(matrix) + + # Build new Eigenvalues exponentially distributed between 1 and + # 1/condition_number + v = np.exp(-np.log(condition_number) * np.linspace(0, size, size) / size) + matrix = np.dot(np.dot(w, np.diag(v)), w.T).astype(dtype) + self._verifyCholesky(matrix, atol=1e-4) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index cbe2888696c87c6c2f50c3de71e8531977ea395a..11d8a99ffe1a136a54b16e20f1792062203f7969 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -24,10 +24,12 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest +@test_util.with_c_api class FunctionTest(XLATestCase): def testFunction(self): diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index a773b5a94742062511bc8bdc6a202b513ce98db3..a80d69fa5f5099b8a8b67df0da9c92b957e9d194 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -76,7 +76,8 @@ class FusedBatchNormTest(XLATestCase): # To avoid constant folding t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") - offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset") + offset = array_ops.placeholder( + np.float32, 
shape=scale_shape, name="offset") epsilon = 0.001 y_ref, mean_ref, var_ref = self._reference_training( x_val, scale_val, offset_val, epsilon, data_format) @@ -112,7 +113,8 @@ class FusedBatchNormTest(XLATestCase): # To avoid constant folding t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") - offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset") + offset = array_ops.placeholder( + np.float32, shape=scale_shape, name="offset") epsilon = 0.001 y, mean, var = nn.fused_batch_norm( t_val, @@ -153,7 +155,7 @@ class FusedBatchNormTest(XLATestCase): def testLearningWithGradientChecker(self): self._testLearning(True) - def testGradient(self): + def testGradientTraining(self): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. channel = 3 @@ -173,7 +175,7 @@ class FusedBatchNormTest(XLATestCase): var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC") + grad, x, scale, mean, var, data_format="NHWC", is_training=True) grad_x_val, grad_scale_val, grad_offset_val = sess.run( [grad_x, grad_scale, grad_offset], { @@ -191,6 +193,53 @@ class FusedBatchNormTest(XLATestCase): self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) + def testGradientInference(self): + # TODO(b/64270657): Use gradient_checker here in addition to comparing with + # this reference implementation. 
+ channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] + grad_val = np.random.random_sample(x_shape).astype(np.float32) + x_val = np.random.random_sample(x_shape).astype(np.float32) + scale_val = np.random.random_sample(scale_shape).astype(np.float32) + mean_val = np.random.random_sample(scale_shape).astype(np.float32) + var_val = np.random.random_sample(scale_shape).astype(np.float32) + + with self.test_session() as sess, self.test_scope(): + grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad") + x = array_ops.placeholder(np.float32, shape=x_shape, name="x") + mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") + var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") + scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") + with self.test_scope(): + out = gen_nn_ops.fused_batch_norm_grad( + grad, x, scale, mean, var, data_format="NHWC", is_training=False) + grad_x, grad_scale, grad_offset, _, _ = out + + ref_x, ref_scale, ref_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( + grad, x, scale, mean, var, data_format="NHWC", is_training=False) + + grad_x_val, grad_scale_val, grad_offset_val, = sess.run( + [grad_x, grad_scale, grad_offset], { + grad: grad_val, + x: x_val, + mean: mean_val, + var: var_val, + scale: scale_val + }) + grad_x_ref, grad_scale_ref, grad_offset_ref, = sess.run( + [ref_x, ref_scale, ref_offset], { + grad: grad_val, + x: x_val, + mean: mean_val, + var: var_val, + scale: scale_val + }) + + self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2) + self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) + self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a04f376ebf6092fd9b6e879796454b1a5c648c96 --- /dev/null +++ 
b/tensorflow/compiler/tests/image_ops_test.py @@ -0,0 +1,142 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for image ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_image_ops +from tensorflow.python.platform import test + + +class ResizeBilinearTest(XLATestCase): + + def _assertForwardOpMatchesExpected(self, + image_np, + target_shape, + expected=None): + if expected is None: + self.fail("expected must be specified") + with self.test_session() as sess, self.test_scope(): + image = array_ops.placeholder(image_np.dtype) + resized = gen_image_ops.resize_bilinear( + image, target_shape, align_corners=True) + out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + + def _assertBackwardOpMatchesExpected(self, + grads_np, + input_shape=None, + dtype=None, + expected=None): + if input_shape is None: + self.fail("input_shape must be specified") + if expected is None: + self.fail("expected must be specified") + with self.test_session() as sess, self.test_scope(): + dtype = dtype or np.float32 + grads 
= array_ops.placeholder(np.float32) + resized = gen_image_ops._resize_bilinear_grad( + grads, + np.zeros([1, input_shape[0], input_shape[1], 1], dtype=dtype), + align_corners=True) + out = sess.run(resized, {grads: grads_np[np.newaxis, :, :, np.newaxis]}) + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + + def testAlignCorners1x2To3x2(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2]], dtype=dtype), [3, 3], + expected=np.array( + [[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], dtype=np.float32)) + + def testAlignCorners1x2To3x2Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), + input_shape=[1, 2], + dtype=dtype, + expected=np.array([[9, 12]], dtype=np.float32)) + + def testAlignCorners2x2To1x1(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=dtype), [1, 1], + expected=np.array([[1]], dtype=np.float32)) + + def testAlignCorners2x2To1x1Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[7]], dtype=np.float32), + input_shape=[2, 2], + dtype=dtype, + expected=np.array([[7, 0], [0, 0]], dtype=np.float32)) + + def testAlignCorners2x2To3x3(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=dtype), [3, 3], + expected=np.array( + [[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]], dtype=np.float32)) + + def testAlignCorners2x2To3x3Grad(self): + self._assertBackwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), + input_shape=[2, 2], + expected=np.array([[5.25, 8.25], [14.25, 17.25]], dtype=np.float32)) + + def testAlignCorners3x3To2x2(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), [2, 2], + expected=np.array([[1, 3], [7, 9]], 
dtype=np.float32)) + + def testAlignCorners3x3To2x2Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[7, 13], [22, 4]], dtype=np.float32), + input_shape=[3, 3], + dtype=dtype, + expected=np.array( + [[7, 0, 13], [0, 0, 0], [22, 0, 4]], dtype=np.float32)) + + def testAlignCorners4x4To3x3(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=dtype), [3, 3], + expected=np.array( + [[1, 2.5, 4], [7, 8.5, 10], [13, 14.5, 16]], dtype=np.float32)) + + def testAlignCorners4x4To3x3Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), + input_shape=[4, 4], + dtype=dtype, + expected=np.array( + [[1, 1, 1, 3], [2, 1.25, 1.25, 3], [2, 1.25, 1.25, 3], + [7, 4, 4, 9]], + dtype=np.float32)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 6a8c3bcd55a6e454a19b6249cf4eb48739c8657f..798daaadbc5be50ef9cf7e1205f6d5a0bde59640 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -2460,6 +2460,36 @@ TEST_F(OpTest, Reshape) { }); } +TEST_F(OpTest, ResizeBilinear) { + Repeatedly([this]() { + std::vector in_dims = RandomDims(4, 4); + std::vector out_dims = RandomDims(2, 2); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ResizeBilinear") + .RandomInput(DT_FLOAT, in_dims) + .Input(test::AsTensor( + std::vector(out_dims.begin(), out_dims.end()))) + .Attr("T", DT_FLOAT) + .Attr("align_corners", true)); + }); +} + +TEST_F(OpTest, ResizeBilinearGrad) { + Repeatedly([this]() { + std::vector in_dims = RandomDims(4, 4); + std::vector out_dims = RandomDims(2, 2); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ResizeBilinearGrad") + .RandomInput(DT_FLOAT, 
in_dims) + .RandomInput(DT_FLOAT, + {in_dims[0], out_dims[0], out_dims[1], in_dims[3]}) + .Attr("T", DT_FLOAT) + .Attr("align_corners", true)); + }); +} + TEST_F(OpTest, Reverse) { Repeatedly([this]() { std::vector dims = RandomDims(1); diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index efda2cc207b2ab56774d193117a2237f3afbfb55..965fdf684b973498d0b3c3cde17711cca7279705 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -67,25 +67,37 @@ class ReduceOpsTest(XLATestCase): np.arange(-10, -4).reshape(2, 3), np.arange(-4, 2).reshape(2, 3), ] - NONEMPTY_FLOAT_DATA = [ - np.arange(1, 7).reshape(2, 3), - np.arange(-10, -4).reshape(2, 3), - np.arange(-4, 2).reshape(2, 3), + COMPLEX_DATA = [ + np.zeros(shape=(2, 0)).astype(np.complex64), + np.zeros(shape=(0, 30)).astype(np.complex64), + np.arange(1, 13, dtype=np.float32).view(np.complex64).reshape(2, 3), + np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3), + np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3), ] + NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0] + NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0] BOOL_DATA = [ np.array([], dtype=np.bool).reshape(2, 0), np.array([], dtype=np.bool).reshape(0, 3), np.array([[False, True, False], [True, True, False]]), ] - def testReduceSum(self): + def testReduceSumF32(self): self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.FLOAT_DATA) - def testReduceProd(self): + def testReduceSumC64(self): + self._testReduction(math_ops.reduce_sum, np.sum, np.complex64, + self.COMPLEX_DATA) + + def testReduceProdF32(self): self._testReduction(math_ops.reduce_prod, np.prod, np.float32, self.FLOAT_DATA) + def testReduceProdC64(self): + self._testReduction(math_ops.reduce_prod, np.prod, np.complex64, + self.COMPLEX_DATA) + def testReduceMin(self): def reference_min(inp, axis): @@ -108,12 
+120,16 @@ class ReduceOpsTest(XLATestCase): self._testReduction(math_ops.reduce_max, reference_max, np.float32, self.FLOAT_DATA) - def testReduceMean(self): + def testReduceMeanF32(self): # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when # reducing across zero inputs. self._testReduction(math_ops.reduce_mean, np.mean, np.float32, self.NONEMPTY_FLOAT_DATA) + def testReduceMeanC64(self): + self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, + self.NONEMPTY_COMPLEX_DATA) + def testReduceAll(self): self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA) diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3260e63b23226d736a7ddc0f21a94a8c791e0442 --- /dev/null +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -0,0 +1,229 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Functional tests for scan ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +def numpy_reverse(x, axis): + length = len(x.shape) + if axis < 0: + axis = length + axis + + ix = [ + slice(None, None, -1) if i == axis else slice(None) for i in range(length) + ] + return x[tuple(ix)] + + +def handle_options(func, x, axis, exclusive, reverse): + """Adds tf options to numpy scan ops.""" + length = len(x.shape) + if axis < 0: + axis = length + axis + + if reverse: + x = numpy_reverse(x, axis) + + if exclusive: + ix_head = [slice(0, 1) if i == axis else slice(None) for i in range(length)] + ix_init = [ + slice(0, -1) if i == axis else slice(None) for i in range(length) + ] + if func == np.cumsum: + init = np.zeros_like(x[tuple(ix_head)]) + elif func == np.cumprod: + init = np.ones_like(x[tuple(ix_head)]) + else: + raise ValueError("Unknown scan function.") + x = np.concatenate([init, func(x[tuple(ix_init)], axis)], axis=axis) + else: + x = func(x, axis=axis) + + if reverse: + x = numpy_reverse(x, axis) + return x + + +class CumsumTest(XLATestCase): + + valid_dtypes = [np.float32] + + def axis_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + + def _compare(self, x, axis, exclusive, reverse): + np_out = handle_options(np.cumsum, x, axis, exclusive, reverse) + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval( + feed_dict={p: x}) + + self.assertAllClose(np_out, 
tf_out) + + def _compareAll(self, x, axis): + for exclusive in [True, False]: + for reverse in [True, False]: + self._compare(x, axis, exclusive, reverse) + + def testEmpty(self): + for dtype in self.valid_dtypes: + x = np.zeros([0]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def testAxisType(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis_dtype in self.axis_dtypes(): + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + axis = constant_op.constant(0, axis_dtype) + math_ops.cumsum(p, axis).eval(feed_dict={p: x}) + + def test1D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def test2D(self): + for dtype in self.valid_dtypes: + x = np.arange(0, 10).reshape([2, 5]).astype(dtype) + for axis in (-2, -1, 0, 1): + self._compareAll(x, axis) + + def test3D(self): + for dtype in self.valid_dtypes: + x = np.arange(0, 20).reshape([2, 2, 5]).astype(dtype) + for axis in (-3, -2, -1, 0, 1, 2): + self._compareAll(x, axis) + + def test6D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) + for axis in range(-6, 6, 3): + self._compareAll(x, axis) + + def testInvalidAxis(self): + x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) + with self.test_session(), self.test_scope(): + input_tensor = ops.convert_to_tensor(x) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumsum(input_tensor, -3).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumsum(input_tensor, 2).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "axis must be a scalar" in str(e)): + 
math_ops.cumsum(input_tensor, [0]).eval() + + +class CumprodTest(XLATestCase): + + valid_dtypes = [np.float32] + + def axis_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + + def _compare(self, x, axis, exclusive, reverse): + np_out = handle_options(np.cumprod, x, axis, exclusive, reverse) + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + prod = math_ops.cumprod(p, axis, exclusive, reverse) + tf_out = prod.eval(feed_dict={p: x}) + + self.assertAllClose(np_out, tf_out) + + def _compareAll(self, x, axis): + for exclusive in [True, False]: + for reverse in [True, False]: + self._compare(x, axis, exclusive, reverse) + + def testEmpty(self): + for dtype in self.valid_dtypes: + x = np.zeros([0]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def testAxisType(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis_dtype in self.axis_dtypes(): + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + axis = constant_op.constant(0, axis_dtype) + math_ops.cumprod(p, axis).eval(feed_dict={p: x}) + + def test1D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def test2D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 11).reshape([2, 5]).astype(dtype) + for axis in (-2, -1, 0, 1): + self._compareAll(x, axis) + + def test3D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 21).reshape([2, 2, 5]).astype(dtype) + for axis in (-3, -2, -1, 0, 1, 2): + self._compareAll(x, axis) + + def test6D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) + for axis in range(-6, 6, 3): + self._compareAll(x, axis) + + def testInvalidAxis(self): + x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) + with self.test_session(), self.test_scope(): + input_tensor = 
ops.convert_to_tensor(x) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumprod(input_tensor, -3).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumprod(input_tensor, 2).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "axis must be a scalar" in str(e)): + math_ops.cumprod(input_tensor, [0]).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4336ebdbd184a081619f0a6951dd4514735c6eb6 --- /dev/null +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -0,0 +1,118 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for stateless random-number generation ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.contrib import stateless +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class StatelessRandomOpsTest(XLATestCase): + """Test cases for stateless random-number generator operators.""" + + def _random_types(self): + return [dtypes.float32] + + def testDeterminism(self): + # Stateless values should be equal iff the seeds are equal (roughly) + with self.test_session(), self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + seeds = [(x, y) for x in range(5) for y in range(5)] * 3 + for stateless_op in [ + stateless.stateless_random_uniform, stateless.stateless_random_normal + ]: + for shape in (), (3,), (2, 5): + for dtype in self._random_types(): + pure = stateless_op(shape, seed=seed_t, dtype=dtype) + values = [(seed, pure.eval(feed_dict={ + seed_t: seed + })) for seed in seeds] + for s0, v0 in values: + for s1, v1 in values: + self.assertEqual(s0 == s1, np.all(v0 == v1)) + + def testRandomUniformIsInRange(self): + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless.stateless_random_uniform( + shape=[1000], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertTrue(np.all(y >= 0)) + self.assertTrue(np.all(y < 1)) + + def _chi_squared(self, x, bins): + """Pearson's Chi-squared test.""" + x = np.ravel(x) + n = len(x) + histogram, _ = np.histogram(x, bins=bins, range=(0, 1)) + expected = n / float(bins) + return np.sum(np.square(histogram - 
expected) / expected) + + def testDistributionOfStatelessRandomUniform(self): + """Use Pearson's Chi-squared test to test for uniformity.""" + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + n = 1000 + x = stateless.stateless_random_uniform( + shape=[n], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [565656, 121212]}) + # Tests that the values are distributed amongst 10 bins with equal + # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with + # p=0.05. This test is probabilistic and would be flaky if the random + # seed were not fixed. + self.assertTrue(self._chi_squared(y, 10) < 16.92) + + def _normal_cdf(self, x): + """Cumulative distribution function for a standard normal distribution.""" + return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) + + def _anderson_darling(self, x): + """Anderson-Darling test for a standard normal distribution.""" + x = np.sort(np.ravel(x)) + n = len(x) + i = np.linspace(1, n, n) + z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) + + (2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x))) + return -n - z / n + + def testDistributionOfStatelessRandomNormal(self): + """Use Anderson-Darling test to test distribution appears normal.""" + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + n = 1000 + x = stateless.stateless_random_normal( + shape=[n], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [25252, 314159]}) + # The constant 2.492 is the 5% critical value for the Anderson-Darling + # test where the mean and variance are known. This test is probabilistic + # so to avoid flakiness the seed is fixed. 
+ self.assertTrue(self._anderson_darling(y) < 2.492) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 76644380bdf2e0c24f6d363ddfaabdff836495d7..0da7442a24201011e3126e53c9d884534a0d721e 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -33,6 +33,17 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest +def nhwc_to_format(x, data_format): + """Converts a numpy array from NHWC format to `data_format`.""" + rank = len(x.shape) + if data_format == "NCHW": + return np.transpose(x, [0, rank - 1] + list(range(1, rank - 1))) + elif data_format == "NHWC": + return x + else: + raise ValueError("Unknown format {}".format(data_format)) + + class UnaryOpsTest(XLATestCase): """Test cases for unary operators.""" @@ -76,6 +87,12 @@ class UnaryOpsTest(XLATestCase): array_ops.diag_part, np.arange(36).reshape([2, 3, 2, 3]).astype(dtype), np.array([[0, 7, 14], [21, 28, 35]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.diag, np.array([[1, 2], [3, 4]], dtype=dtype), + np.array( + [[[[1, 0], [0, 0]], [[0, 2], [0, 0]]], [[[0, 0], [3, 0]], + [[0, 0], [0, 4]]]], + dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.identity, @@ -86,6 +103,21 @@ class UnaryOpsTest(XLATestCase): array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype), np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype), + np.array( + [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], + dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.matrix_diag, + np.array( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype), + np.array( + [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], + [[4, 0, 0], [0, 5, 0], [0, 0, 6]]], + [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], + [[10, 
0, 0], [0, 11, 0], [0, 0, 12]]]], + dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.matrix_diag_part, np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype), @@ -330,12 +362,22 @@ class UnaryOpsTest(XLATestCase): def testComplexOps(self): for dtype in self.complex_types: - # TODO(b/65408531): math_ops.acosh (needs pow) - # TODO(b/65408531): math_ops.asinh (needs pow) # TODO(b/65408531): Wider support for log (needs atan2). atan2_supported = self.device == "XLA_GPU" if atan2_supported: + self._assertOpOutputMatchesExpected( + math_ops.acosh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arccosh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.asinh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arcsinh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( math_ops.atanh, np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), @@ -392,19 +434,26 @@ class UnaryOpsTest(XLATestCase): expected=np.log1p( np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) - # TODO(b/34703906): math_ops.rsqrt (needs pow) + val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.rsqrt, val, expected=1 / np.sqrt(val)) + + self._assertOpOutputMatchesExpected( + math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val))) - # TODO(b/34703906): math_ops.sigmoid (needs tanh) + self._assertOpOutputMatchesExpected( + math_ops.sqrt, val, expected=np.sqrt(val)) - # TODO(b/34703906): math_ops.sqrt (needs pow) + self._assertOpOutputMatchesExpected( + math_ops.tanh, + np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), + expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.tan, np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) - # TODO(b/34703906): 
math_ops.tanh (as itself) - ctypes = {np.complex64: np.float32} self._assertOpOutputMatchesExpected( math_ops.abs, @@ -624,55 +673,88 @@ class UnaryOpsTest(XLATestCase): equality_test=self.ListsAreClose) def testDepthToSpace(self): + def make_op(data_format): + def op(x): + return array_ops.depth_to_space(x, block_size=2, + data_format=data_format) + return op + for dtype in self.numeric_types: - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4]]]], dtype=dtype), - expected=np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype)) + for data_format in ["NCHW", "NHWC"]: + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), + data_format), + expected=nhwc_to_format(np.array([[[[1], [2]], + [[3], [4]]]], dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format( + np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], + dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], + dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype), - expected=np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format( + np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + 
np.array([[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]], dtype=dtype), + data_format)) def testSpaceToDepth(self): + def make_op(data_format): + def op(x): + return array_ops.space_to_depth(x, block_size=2, + data_format=data_format) + return op + for dtype in self.numeric_types: - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4]]]], dtype=dtype)) + for data_format in ["NCHW", "NHWC"]: + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1], [2]], + [[3], [4]]]], dtype=dtype), + data_format), + expected=nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], - dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], + dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8]], + 
[[9, 10, 11, 12], + [13, 14, 15, 16]]]], dtype=dtype), + data_format)) def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index c50342dee45eba6ae54f01653ecc81ef096b547b..b08d6ab21e0746558cb3d4818d4c822c45d2e9ee 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -107,11 +107,26 @@ class VariableOpsTest(XLATestCase): [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]], ).astype(dtype), sess.run(x)) + def testShape(self): + for dtype in self.numeric_types: + init = np.ones([2, 3]).astype(dtype) + with self.test_session() as session, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + session.run(variables.variables_initializer([v])) + h = v.handle + s32, s64 = session.run([ + resource_variable_ops.variable_shape(h), + resource_variable_ops.variable_shape(h, out_type=dtypes.int64) + ]) + self.assertEqual(s32.dtype, np.int32) + self.assertEqual(s64.dtype, np.int64) + self.assertAllEqual(s32, [2, 3]) + self.assertAllEqual(s64, [2, 3]) + def testReadWrite(self): """Tests initialization, reading, and writing a resource variable.""" for dtype in self.numeric_types: with self.test_session() as session: - print(ops.get_default_graph()) with self.test_scope(): with variable_scope.variable_scope("ascope", use_resource=True): x = variable_scope.get_variable( diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 912e819d8d63886c663aaabd3cbe3bd76a1ced07..5d1cb6d73570a1a3efbe0d2d37d9746bc0e2528f 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1,6 +1,6 @@ licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") package_group( name = "internal", @@ -25,6 +25,30 @@ package( 
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +cc_library( + name = "tf2xla_supported_ops_lib", + srcs = ["tf2xla_supported_ops.cc"], + hdrs = ["tf2xla_supported_ops.h"], + visibility = ["//visibility:public"], + deps = [ + ":xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_binary( + name = "tf2xla_supported_ops", + srcs = ["tf2xla_supported_ops_main.cc"], + visibility = ["//visibility:public"], + deps = [":tf2xla_supported_ops_lib"], +) + xla_proto_library( name = "tf2xla_proto", srcs = ["tf2xla.proto"], @@ -67,7 +91,6 @@ cc_library( # Keep dependencies to a minimum here; this library is used in every AOT # binary produced by tfcompile. "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", ], @@ -125,6 +148,7 @@ cc_library( ":functionalize_control_flow", ":sharding_util", ":tf2xla_util", + "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -178,7 +202,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -215,6 +238,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], @@ -357,13 +381,6 @@ tf_cc_test( ], 
) -cc_library( - name = "xla_local_runtime_context", - hdrs = ["xla_local_runtime_context.h"], - visibility = ["//visibility:public"], - deps = ["//tensorflow/core:framework_lite"], -) - cc_library( name = "dump_graph", srcs = [ @@ -400,6 +417,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index d57273d84442c17565a6ace1c29170a0f3ba583b..ab2f1e9a7ab577bbe704e568b21d9912439605ca 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -52,6 +52,8 @@ Status BackwardsConstAnalysis(const Graph& g, {"Conv2DBackpropInput", "input_sizes"}, {"Conv3DBackpropFilterV2", "filter_sizes"}, {"Conv3DBackpropInputV2", "input_sizes"}, + {"Cumprod", "axis"}, + {"Cumsum", "axis"}, {"DepthwiseConv2dNativeBackpropFilter", "filter_sizes"}, {"DepthwiseConv2dNativeBackpropInput", "input_sizes"}, {"DynamicStitch", "indices"}, @@ -78,6 +80,7 @@ Status BackwardsConstAnalysis(const Graph& g, {"Range", "limit"}, {"Range", "delta"}, {"Reshape", "shape"}, + {"ResizeBilinear", "size"}, {"ResourceStridedSliceAssign", "begin"}, {"ResourceStridedSliceAssign", "end"}, {"ResourceStridedSliceAssign", "strides"}, diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index ddd912b87315f7943915153b5bf73531107af54d..03603ee9baefd1d20d220faf63c9c1c427ebdf31 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -63,7 +63,12 @@ string MakeUniquePath(string name) { string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { string path = MakeUniquePath(name); - TF_CHECK_OK(WriteTextProto(Env::Default(), path, graph_def)); + Status status = WriteTextProto(Env::Default(), path, graph_def); + if (!status.ok()) { + VLOG(1) << 
"Failed to dump GraphDef to file: " << path << " : " << status; + path.clear(); + path = "(unavailable)"; + } return path; } @@ -79,7 +84,13 @@ string DumpGraphToFile(const string& name, Graph const& graph, string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { string path = MakeUniquePath(name); - TF_CHECK_OK(WriteTextProto(Env::Default(), path, fdef)); + Status status = WriteTextProto(Env::Default(), path, fdef); + if (!status.ok()) { + VLOG(1) << "Failed to dump FunctionDef to file: " << path << " : " + << status; + path.clear(); + path = "(unavailable)"; + } return path; } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 6ef4860f35835e59be3452b57204d42c82d0816b..dd67a1dea9656bac9cf3eaa09295a0d42e283706 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -528,251 +529,101 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, class FunctionalizeCond { public: - // Identifies the connected parts of the tf.Cond. - struct ClusterHandle { - explicit ClusterHandle(int representative = -1) - : representative(representative) {} - - bool operator==(const ClusterHandle& other) const { - return representative == other.representative; - } - - bool operator!=(const ClusterHandle& other) const { - return !(*this == other); - } + // All nodes are assumed to be either in no branch, then branch, else branch, + // or both branches (such as merge nodes). 
+ enum Branch { + kElseBranch = 0, + kThenBranch = 1, + kBoth = 2, + kNeither = 3, + kNumBranchTypes = 4 + }; - bool operator<(const ClusterHandle& other) const { - return representative < other.representative; + // Returns a textual representation of the Branch b. + static string Branch_Name(FunctionalizeCond::Branch b); + + // Comparison function used for sorting nodes consistently. + struct CondCmp { + bool operator()(const Node* lhs, const Node* rhs) const { + bool lhs_is_resource = + lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; + bool rhs_is_resource = + rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; + return std::tie(lhs_is_resource, lhs->name()) < + std::tie(rhs_is_resource, rhs->name()); } + }; - bool operator>(const ClusterHandle& other) const { - return representative > other.representative; - } + // Functionalize all the switch-merge nodes of a loop-free graph into XlaIf + // nodes. That is, attempt to transform every remaining switch and merge nodes + // in the graph into XlaIf nodes. + // Precondition: All while loops have been removed from graph. + static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); + private: + struct ForwardFlowNode { + explicit ForwardFlowNode(Branch branch = Branch::kNeither) + : branch(branch), count(0) {} string ToString() const { - return strings::StrCat("Cluster_", representative); + return strings::StrCat("branch=", Branch_Name(branch), " count=", count); } - - // Vector of UnionFind indexable by ClusterHandle and Node*. 
- struct Vector { - explicit Vector(size_t size) : clusters(size) {} - - UnionFind& at(const ClusterHandle& cluster) { - return clusters.at(cluster.representative); - } - - UnionFind& at(const Node* node) { - return clusters.at(node->id()); - } - - UnionFind& operator[](const Node* node) { - return clusters.at(node->id()); - } - - size_t size() const { return clusters.size(); } - - void resize(size_t count) { return clusters.resize(count); } - - private: - std::vector> clusters; - }; - - private: - int representative; - }; - - // Represents a node in the clustered graph consisting of switch_nodes, - // merge_nodes as well as the edges into and out of this node to other - // Clusters. Each Cluster corresponds to a ClusterHandle and has a - // corresponding representative. - struct Cluster { - std::unordered_set switch_nodes; - std::unordered_set merge_nodes; - std::unordered_set in_nodes; - std::unordered_set out_nodes; - - // A member of the ClusterHandle corresponding to this Cluster. - ClusterHandle representative; - bool visited = false; + Branch branch; + int count; }; - // Represent the clustered graph as map from cluster representative to - // Cluster. - using ClusteredGraph = std::map; - - // The arguments and condition of a XlaIf. The arguments are ordered by node - // id in the original graph. - struct CondArgs { - struct CondCmp { - bool operator()(const Node* lhs, const Node* rhs) const { - bool lhs_is_resource = - lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; - bool rhs_is_resource = - rhs->num_inputs() > 0 ? 
(rhs->input_type(0) == DT_RESOURCE) : false; - return std::tie(lhs_is_resource, lhs->name()) < - std::tie(rhs_is_resource, rhs->name()); - } - }; - Node* conditional = nullptr; - std::set args; - }; - - static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); - - private: FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) - : clusters_(graph->num_node_ids()), library_(library), graph_(graph) {} - - // Returns a vector of Merge nodes from the clustered graph where the nodes - // are sorted by the number of switch nodes minus number of merge nodes - // from a root of the clustered graph to the given Merge node, with ties - // broken by the representative of the Cluster. - std::vector> SortedMergeNodes(); - - // Returns whether the graph has no conditionals. - bool NoConditionals() const { return merge_nodes_.empty(); } - - // Construct the clustered graph by creating nodes for each cluster and the - // connections between the clusters. Switch and Merge nodes partition - // clusters, so iterate over those. Note: a Cluster may have neither a - // Merge or Switch but will have an in/out edge from a Cluster that has. - void CreateClusters(); - - // Creates the clustered graph by identifying all the edges between different - // clusters and collecting all switch and merge nodes that correspond to a - // cluster. - void CreateClusteredGraph(); - - // If `from` and `to` correspond to different clusters, then merge the nodes - // in the clustered graph corresponding to `from` and `to`. - // - // If `remove_from_graph` is specified then the `from` node is also removed - // from the clustered graph post contracting the edge. - void ContractEdge(Cluster* from, Cluster* to, bool remove_from_graph = false); + : library_(library), graph_(graph) {} + + // Perform the actual cond functionalization. Iterate over groups of switch + // nodes (linked by common predicate), from innermost to outermost, and + // extract into XlaIf nodes. 
+ Status FunctionalizeInternal(); // Converts a Merge node to a XlaIf. This encapsulates the process of // extracting the bodies needed for the then and else branch, creates a XlaIf // node, removing the nodes of the branches from the graph and replacing the // merge node with a XlaIf. - Status ConvertMergeToXlaIf(Cluster* merge_cluster); - - // Removes a Switch cluster feeding directly into a Merge cluster by removing - // the Switch and Merge nodes and collapsing into a single cluster. - Status RemoveTrivialMerge(Cluster* merge_cluster); - - // Returns the switch cluster corresponding to the merge node. This function - // only returns the switch cluster in the simple case where we have a switch - // node is the entry of a diamond corresponding to a conditional: - // - // Switch - // / \ - // Branch Branch - // \ / - // merge_cluster - // - // Note: either of the branches may be empty. The case where both branches are - // empty is handled by RemoveTrivialMerge. - gtl::optional GetSwitchCluster(const Cluster& merge_cluster); - - // Determines the arguments needed as input to the Merge cluster originating - // from the Switch cluster. - xla::StatusOr DetermineCondArgs(const Cluster& merge_cluster, - const Cluster& switch_cluster); - - // Builds a XlaIfOp to replace the Merge node with. - xla::StatusOr BuildAndAddXlaIfOp(const CondArgs& cond_args, - const Cluster& merge_cluster, - const std::vector& outputs); + Status ConvertCorrespondingMergeToXlaIf( + const std::vector& switch_nodes, + const std::vector& merge_nodes, Node* predicate); + + // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with. + xla::StatusOr BuildAndAddXlaIfOp( + const std::vector& switch_nodes, + const std::vector& merge_nodes, Node* predicate); // Extracts a function body corresponding to the given input edge of the merge // node. 
- Status ExtractBody(const CondArgs& cond_args, const Cluster& merge_cluster, - const std::vector& outputs, int input_edge, + Status ExtractBody(const std::vector& switch_nodes, + const std::vector& merge_nodes, int input_edge, Graph* body); // Adds all the input edges to `if_node` corresponding to the arguments. - Status AddInputEdges(const CondArgs& cond_args, Node* if_node); + Status AddInputEdges(const std::vector& cond_args, Node* predicate, + Node* if_node); // Adds all output edges from the `if_node`. Status AddOutputEdges(const std::vector& outputs, Node* if_node); - // Removes all nodes from the graph that are part of cluster. - void RemoveClusterNodes(Cluster* cluster); + // Returns the switches of graph_ in postorder. Dead switch nodes are skipped + // and removed from the graph. + std::vector DetermineSwitchOrder(); - // Removes all argument nodes that are unused. - template - void RemoveUnusedArgs(const T& args); + // Update the state for destination based on the state of source and the node + // being updated. + Status Join(const ForwardFlowNode& src_state, const Node* dst, + ForwardFlowNode* dst_state); - // Removes all Merge nodes in merge_cluster. - void RemoveMergeNodes(Cluster* merge_cluster); + // Validates that the branch_map and frontier of nodes for the conditional + // section are as expected. + Status ValidBranchMapAndFrontier( + const std::unordered_map& branch_map, + const std::unordered_set& frontier); - // Returns the representative member of the corresponding cluster. 
- ClusterHandle Representative(const Node* node) { - return clusters_.at(node).Get(); - } - - ClusteredGraph clustered_graph_; - ClusterHandle::Vector clusters_; - std::unordered_set merge_nodes_; - std::unordered_set switch_nodes_; FunctionLibraryDefinition* library_; Graph* graph_; }; -std::ostream& operator<<(std::ostream& os, - const FunctionalizeCond::ClusterHandle& c) { - os << c.ToString(); - return os; -} - -// Returns a dot representation of the clustered graph showing the connections -// between the nodes and the nodes in each cluster. -string DebugString(const Graph& graph, - FunctionalizeCond::ClusterHandle::Vector* clusters) { - string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n"; - std::map subgraphs; - for (Node* n : graph.nodes()) { - if (n->IsOp()) { - strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), - " [label=\"", n->name(), "\"];\n"); - } - } - for (auto kv : subgraphs) { - strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n", - "style=filled; color=lightgrey;", "label = \"", - kv.first.ToString(), "\";\n", kv.second, "}\n"); - } - for (Node* n : graph.nodes()) { - if (!n->IsOp()) { - continue; - } - for (Node* in : n->in_nodes()) { - if (in->IsOp()) { - strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); - } - } - } - return strings::StrCat(ret, "}"); -} - -string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { - string ret = "digraph {\ncompound=true;labeljust=\"r\";\n"; - auto name = [](const FunctionalizeCond::Cluster& cluster) { - return cluster.representative.ToString(); - }; - for (auto kv : clustered_graph) { - strings::StrAppend(&ret, kv.first.ToString(), " [label=\"", name(kv.second), - " (", kv.second.switch_nodes.size(), ", ", - kv.second.merge_nodes.size(), ")\"];\n"); - } - for (auto kv : clustered_graph) { - for (auto in : kv.second.in_nodes) { - strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n"); - } - } - return 
strings::StrCat(ret, "}"); -} - bool IsDeadSwitch(const Node* node) { for (const Edge* e : node->out_edges()) { const Node* dst = e->dst(); @@ -788,241 +639,212 @@ bool IsDeadSwitch(const Node* node) { return true; } -void FunctionalizeCond::CreateClusters() { - for (Node* node : graph_->nodes()) { - if (!node->IsOp()) { - continue; +string FunctionalizeCond::Branch_Name(FunctionalizeCond::Branch b) { + const string branch_name[FunctionalizeCond::kNumBranchTypes + 1] = { + "else", "then", "both", "neither", "count"}; + return branch_name[b]; +} + +Status FunctionalizeCond::ValidBranchMapAndFrontier( + const std::unordered_map& + branch_map, + const std::unordered_set& frontier) { + std::unordered_set pending[kNumBranchTypes]; + for (const auto& kv : branch_map) { + if (kv.second.count != kv.first->in_edges().size()) { + return errors::FailedPrecondition("Value ", kv.first->DebugString(), + " not dominated by switch nodes."); } - if (IsSwitch(node)) { - switch_nodes_.insert(node); - } else if (IsMerge(node)) { - merge_nodes_.insert(node); + if (VLOG_IS_ON(1)) { + // Append attribute to the graph if running with logging to make the + // changes clearer in the visualization. + kv.first->AddAttr("_XlaFunctionalizeBranch", + Branch_Name(kv.second.branch)); } - ClusterHandle& cluster = clusters_.at(node).Get(); - cluster = ClusterHandle(node->id()); } - - // If there are no Merge nodes, then terminate. - if (merge_nodes_.empty()) { - return; + for (Node* n : frontier) { + pending[branch_map.at(n).branch].insert(n); } - - // Remove all dead Switch nodes. - RemoveUnusedArgs(switch_nodes_); - - // All parent_'s are still nullptr so clusters_ may still be resized. Resize - // conservatively assuming all merge nodes become XlaIf nodes. - clusters_.resize(clusters_.size() + merge_nodes_.size()); - - // Merge a cluster with its input, unless the input is a Switch node or - // the node is a Merge node. 
- for (const Node* node : graph_->nodes()) { - if (IsMerge(node) || IsSwitch(node) || !node->IsOp()) { - continue; - } - for (const Node* in : node->in_nodes()) { - if (in->IsOp() && !IsSwitch(in) && !IsMerge(in)) { - clusters_.at(node).Merge(&clusters_.at(in)); - } + TF_RET_CHECK(pending[kNeither].empty()) << NodesToString(pending[kNeither]); + for (const Node* n : pending[kBoth]) { + TF_RET_CHECK(IsMerge(n)) << n->DebugString(); + // Merge nodes may be in then or else branch too + } + int index = (pending[kThenBranch].size() <= pending[kElseBranch].size()) + ? kThenBranch + : kElseBranch; + int other = 1 - index; + for (const Node* n : pending[index]) { + if (pending[other].find(n) != pending[other].end()) { + return errors::Internal( + "Node (", n->DebugString().c_str(), + ") in both Else and Then branch should be in Both."); } } + return Status::OK(); } -void FunctionalizeCond::ContractEdge(Cluster* from, Cluster* to, - bool remove_from_graph) { - VLOG(3) << "ContractEdge from = " << from->representative - << " to = " << to->representative; - if (from->representative == to->representative) { - return; - } - to->merge_nodes.insert(from->merge_nodes.begin(), from->merge_nodes.end()); - from->merge_nodes.clear(); - to->switch_nodes.insert(from->switch_nodes.begin(), from->switch_nodes.end()); - from->switch_nodes.clear(); - - for (Cluster* from_out : from->out_nodes) { - from_out->in_nodes.erase(from); - if (from_out->representative != to->representative) { - from_out->in_nodes.insert(to); - to->out_nodes.insert(from_out); +Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, + const Node* dst, ForwardFlowNode* dst_state) { + TF_RET_CHECK(dst_state->branch != Branch::kBoth && + dst_state->branch != Branch::kNumBranchTypes) + << "Unexpected/Invalid branch type: Merging " + << Branch_Name(src_state.branch) << " with " + << Branch_Name(dst_state->branch); + if (dst_state->branch == Branch::kNeither) { + dst_state->branch = src_state.branch; + } else if 
(src_state.branch != dst_state->branch && + src_state.branch != Branch::kNeither) { + if (IsMerge(dst)) { + dst_state->branch = Branch::kBoth; + } else { + return errors::Internal("Illegal merge: ", src_state.ToString(), " with ", + dst_state->ToString(), " for ", + dst->DebugString()); } } - from->out_nodes.clear(); + ++dst_state->count; + return Status::OK(); +} - for (Cluster* from_in : from->in_nodes) { - from_in->out_nodes.erase(from); - if (from_in->representative != to->representative) { - from_in->out_nodes.insert(to); - to->in_nodes.insert(from_in); +std::vector FunctionalizeCond::DetermineSwitchOrder() { + std::vector dead_switches; + std::vector switch_order; + DFS(*graph_, nullptr, [this, &dead_switches, &switch_order](Node* n) { + if (IsSwitch(n)) { + if (IsDeadSwitch(n)) { + dead_switches.push_back(n); + } else { + switch_order.push_back(n); + } } + }); + + // Remove all dead switch nodes. + for (Node* n : dead_switches) { + graph_->RemoveNode(n); } - from->in_nodes.clear(); - to->in_nodes.erase(from); - to->out_nodes.erase(from); - clusters_.at(to->representative).Merge(&clusters_.at(from->representative)); - from->visited = true; + return switch_order; +} - if (remove_from_graph) { - clustered_graph_.erase(from->representative); +Status FunctionalizeCond::FunctionalizeInternal() { + std::vector switch_order = DetermineSwitchOrder(); + // If there are no switch nodes, then terminate. + if (switch_order.empty()) { + return Status::OK(); } -} -void FunctionalizeCond::CreateClusteredGraph() { - auto update_cluster_for_node = [this](Node* node) -> Cluster& { - ClusterHandle repr = Representative(node); - Cluster& cluster_node = clustered_graph_[repr]; - cluster_node.representative = repr; - for (const Node* in : node->in_nodes()) { - ClusterHandle other_repr = Representative(in); - // Skip source, sink and internal edges. 
- if (!in->IsOp() || other_repr == repr) { - continue; - } - Cluster& cluster_node_in = clustered_graph_[other_repr]; - cluster_node.in_nodes.insert(&cluster_node_in); - cluster_node_in.out_nodes.insert(&cluster_node); - cluster_node_in.representative = other_repr; - } - for (const Node* out : node->out_nodes()) { - ClusterHandle other_repr = Representative(out); - // Skip source, sink and internal edges. - if (!out->IsOp() || other_repr == repr) { - continue; - } - Cluster& cluster_node_out = clustered_graph_[other_repr]; - cluster_node.out_nodes.insert(&cluster_node_out); - cluster_node_out.in_nodes.insert(&cluster_node); - cluster_node_out.representative = other_repr; - } - return cluster_node; + struct PredicateSwitches { + explicit PredicateSwitches(Node* predicate) : predicate(predicate) {} + + Node* predicate; + std::vector switches; }; - for (Node* node : switch_nodes_) { - update_cluster_for_node(node).switch_nodes.insert(node); - } - for (Node* node : merge_nodes_) { - update_cluster_for_node(node).merge_nodes.insert(node); - } // Merge Switch nodes with common predicate. - std::unordered_map> predicate_to_switch; - for (Node* node : switch_nodes_) { - Node* tmp; - TF_CHECK_OK(node->input_node(1, &tmp)); - predicate_to_switch[tmp].push_back(node); - } - for (auto kv : predicate_to_switch) { - Cluster& first = clustered_graph_.at(Representative(kv.second.front())); - for (Node* switch_node : kv.second) { - ClusterHandle handle = Representative(switch_node); - Cluster& cluster = clustered_graph_.at(handle); - ContractEdge(&cluster, &first, /*remove_from_graph=*/true); + std::vector predicate_switch_order; + std::unordered_map predicate_index; + // The nodes in switch_order are in reverse topological order, but the + // clustered switches need not be (i.e., when considered as a cluster one + // element of a cluster may be later in the topological order than another + // node whose cluster is later in the topological order of clustered + // switches). 
+ for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) { + Node* pred; + TF_CHECK_OK((*it)->input_node(1, &pred)); + if (predicate_index.find(pred) == predicate_index.end()) { + predicate_index[pred] = predicate_switch_order.size(); + predicate_switch_order.emplace_back(pred); } + predicate_switch_order[predicate_index[pred]].switches.push_back(*it); } - // Merge Merge nodes with common input together. - for (Node* node : merge_nodes_) { - Cluster& cluster = clustered_graph_.at(Representative(node)); - for (const Node* in : node->in_nodes()) { - if (!in->IsOp()) { - continue; - } - Cluster& cluster_node_in = clustered_graph_.at(Representative(in)); - // ContractEdge can modify out_nodes of cluster_node_in, so traverse - // over out_nodes assuming it does. - for (auto it = cluster_node_in.out_nodes.begin(); - it != cluster_node_in.out_nodes.end();) { - if (!(*it)->merge_nodes.empty()) { - ContractEdge(*it++, &cluster, /*remove_from_graph=*/true); - } else { - ++it; - } - } - } - } + // Iterate from innermost set of clustered switches to outermost, replacing + // matching switch->merge subgraphs with single XlaIf nodes. 
+ for (auto it = predicate_switch_order.rbegin(); + it != predicate_switch_order.rend(); ++it) { + auto& ps = *it; + VLOG(3) << "Flow down from: " << ps.predicate->name() << " -> " + << NodesToString(ps.switches); - VLOG(3) << "Graph with clusters: " << DebugString(*graph_, &clusters_); - VLOG(3) << "ClusteredGraph: " << DebugString(clustered_graph_); -} + std::unordered_map branch_map; + std::unordered_set frontier; -gtl::optional FunctionalizeCond::GetSwitchCluster( - const Cluster& merge_cluster) { - VLOG(3) << "GetSwitchCluster for " << merge_cluster.representative; - gtl::optional switch_cluster; - if (merge_cluster.in_nodes.size() > 2) { - return gtl::nullopt; - } - for (Cluster* in : merge_cluster.in_nodes) { - Cluster* cluster = in; - if (in->switch_nodes.empty()) { - if (in->in_nodes.size() != 1) { - return gtl::nullopt; - } - // There is only a single `in` cluster. - cluster = *in->in_nodes.begin(); - } - if (cluster->switch_nodes.empty()) { - return gtl::nullopt; - } + std::vector stack = ps.switches; + std::vector visited(graph_->num_node_ids(), false); + while (!stack.empty()) { + Node* n = stack.back(); + stack.pop_back(); - if (switch_cluster.has_value() && *switch_cluster != cluster) { - return gtl::nullopt; - } else { - switch_cluster = cluster; - } - } - return switch_cluster; -} + if (visited[n->id()]) { + continue; + } + visited[n->id()] = true; -xla::StatusOr FunctionalizeCond::DetermineCondArgs( - const Cluster& merge_cluster, const Cluster& switch_cluster) { - VLOG(2) << "DetermineCondArgs for " << merge_cluster.representative - << " with switch cluster " << switch_cluster.representative; - CondArgs ret; - auto feeds_into_branch_cluster = [&](Node* switch_cluster) { - for (Node* out : switch_cluster->out_nodes()) { - ClusterHandle repr = Representative(out); - if (repr == merge_cluster.representative) { - return true; + // Propagate branch state along each edge of a switch node. 
+ bool sink_only = true; + for (const Edge* e : n->out_edges()) { + Node* out = e->dst(); + if (!out->IsOp()) { + continue; + } + sink_only = false; + // Propagate branch information. + ForwardFlowNode& ffn = branch_map[out]; + if (IsSwitch(n)) { + int index = e->IsControlEdge() ? Branch::kNeither : e->src_output(); + TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn)); + } else { + TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn)); + } + if (IsMerge(out)) { + if (out->in_edges().size() == ffn.count) { + frontier.insert(out); + } + } else if (!visited[out->id()] && ffn.count == out->in_edges().size()) { + // If all predecessors are dominated by the switch nodes, then add + // the output to the stack. + stack.push_back(out); + } } - for (Cluster* in : merge_cluster.in_nodes) { - if (repr == in->representative) { - return true; + if (sink_only) { + if (!IsIdentity(n)) { + VLOG(1) << "Feeding into sink: " << n->DebugString(); } } } - return false; - }; - for (Node* switch_cluster_node : switch_cluster.switch_nodes) { - if (!feeds_into_branch_cluster(switch_cluster_node)) { - continue; - } - Node* tmp; - TF_RETURN_IF_ERROR(switch_cluster_node->input_node(1, &tmp)); - if (ret.conditional == nullptr) { - ret.conditional = tmp; - } else if (ret.conditional != tmp) { - return errors::Unimplemented( - "Switch statements with different conditionals cannot be " - "converted into functional conditional."); + TF_RETURN_IF_ERROR(ValidBranchMapAndFrontier(branch_map, frontier)); + VLOG(2) << "FunctionalizeControlFlow (before XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_bc", *graph_); + std::vector switch_nodes(ps.switches); + std::sort(switch_nodes.begin(), switch_nodes.end(), CondCmp()); + std::vector merge_nodes(frontier.begin(), frontier.end()); + std::sort(merge_nodes.begin(), merge_nodes.end(), CondCmp()); + TF_RETURN_IF_ERROR(ConvertCorrespondingMergeToXlaIf( + switch_nodes, merge_nodes, ps.predicate)); + for (auto& del_kv : 
branch_map) { + graph_->RemoveNode(del_kv.first); + } + for (Node* node : switch_nodes) { + graph_->RemoveNode(node); } - ret.args.insert(switch_cluster_node); + VLOG(2) << "FunctionalizeControlFlow (after XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_ac", *graph_); } - return ret; + return Status::OK(); } xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( - const CondArgs& cond_args, const Cluster& merge_cluster, - const std::vector& outputs) { - VLOG(2) << "Build if op for " << NodesToString(merge_cluster.merge_nodes) - << " with input " << NodesToString(cond_args.args); + const std::vector& switch_nodes, + const std::vector& merge_nodes, Node* predicate) { + VLOG(2) << "Build if op for " << NodesToString(merge_nodes) << " with input " + << NodesToString(switch_nodes); NodeDef if_def; // Create a new If node using the name of the merge node. - NodeDefBuilder builder( - strings::StrCat((*merge_cluster.merge_nodes.begin())->name(), "_If"), - "XlaIf"); + NodeDefBuilder builder(strings::StrCat(predicate->name(), "_If"), "XlaIf"); string branch[] = {"else_branch", "then_branch"}; for (int i = 0; i < 2; ++i) { static std::atomic sequence_num(0LL); @@ -1032,8 +854,7 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( body_name.set_name( strings::StrCat("_functionalize_if_", branch[i], "_", id)); auto body = xla::MakeUnique(graph_->op_registry()); - TF_RETURN_IF_ERROR( - ExtractBody(cond_args, merge_cluster, outputs, i, body.get())); + TF_RETURN_IF_ERROR(ExtractBody(switch_nodes, merge_nodes, i, body.get())); VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); FunctionDef body_fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); @@ -1044,7 +865,7 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( // Build input type. 
std::vector inputs; DataTypeVector in_arg_types; - for (const Node* arg : cond_args.args) { + for (const Node* arg : switch_nodes) { const Edge* in_edge; TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); if (in_edge->IsControlEdge()) { @@ -1060,17 +881,17 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( // Build output type. DataTypeVector out_type; - for (const Node* merge : merge_cluster.merge_nodes) { + for (const Node* merge : merge_nodes) { DataType dtype = merge->output_type(0); out_type.push_back(dtype); } builder.Attr("Tout", out_type); builder.Attr("Tcond", DT_BOOL); - builder.Device(cond_args.conditional->assigned_device_name()); + builder.Device(predicate->assigned_device_name()); // Conditional should be the first input ... - builder.Input(NodeDefBuilder::NodeOut(cond_args.conditional->name(), 0, - cond_args.conditional->output_type(0))); + builder.Input( + NodeDefBuilder::NodeOut(predicate->name(), 0, predicate->output_type(0))); // ... followed by the other inputs. builder.Input(inputs); @@ -1079,53 +900,15 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( return if_node; } -void FunctionalizeCond::RemoveClusterNodes(Cluster* cluster) { - VLOG(3) << "RemoveClusterNodes for " << cluster->representative; - ClusterHandle repr = cluster->representative; - std::deque to_delete; - for (Node* node : graph_->nodes()) { - if (Representative(node) == repr) { - to_delete.push_back(node); - } - } - for (Node* n : to_delete) { - graph_->RemoveNode(n); - } -} - -template -void FunctionalizeCond::RemoveUnusedArgs(const T& args) { - VLOG(2) << "RemoveUnusedArgs among: " << NodesToString(args); - - std::deque to_delete; - for (Node* arg : args) { - if (IsDeadSwitch(arg)) { - to_delete.push_back(arg); - for (Node* n : arg->out_nodes()) { - to_delete.push_back(n); - } - } - } - for (Node* n : to_delete) { - switch_nodes_.erase(n); - auto it = clustered_graph_.find(Representative(n)); - if (it != clustered_graph_.end()) { - it->second.switch_nodes.erase(n); 
- } - graph_->RemoveNode(n); - } -} - -Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, - const Cluster& merge_cluster, - const std::vector& outputs, +Status FunctionalizeCond::ExtractBody(const std::vector& switch_nodes, + const std::vector& merge_nodes, int input_edge, Graph* body) { - VLOG(2) << "ExtractBody for " << merge_cluster.representative - << " along edge " << input_edge; + VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge " + << input_edge; std::vector squash_src_outputs(graph_->num_node_ids(), false); std::vector node_map(graph_->num_node_ids(), nullptr); int arg_count = 0; - for (const auto* arg : cond_args.args) { + for (const auto* arg : switch_nodes) { DataType dtype = arg->input_type(0); TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(body, dtype, arg_count++)); @@ -1134,9 +917,9 @@ Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, } std::vector stack; - stack.reserve(outputs.size()); - for (int j = 0; j < outputs.size(); ++j) { - Node* node = outputs[j]; + stack.reserve(switch_nodes.size()); + for (int j = 0; j < merge_nodes.size(); ++j) { + Node* node = merge_nodes[j]; TF_ASSIGN_OR_RETURN(node_map.at(node->id()), BuildRetvalNode(body, node->output_type(0), /*index=*/j)); @@ -1147,7 +930,8 @@ Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, node_map.at(in->id()) = body->CopyNode(in); } - if (cond_args.args.find(in) == cond_args.args.end()) { + if (std::find(switch_nodes.begin(), switch_nodes.end(), in) == + switch_nodes.end()) { body->AddEdge(node_map.at(in->id()), in_edge->src_output(), node_map.at(node->id()), 0); } else { @@ -1162,12 +946,12 @@ Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, body); } -Status FunctionalizeCond::AddInputEdges(const CondArgs& cond_args, - Node* if_node) { +Status FunctionalizeCond::AddInputEdges(const std::vector& cond_args, + Node* predicate, Node* if_node) { VLOG(3) << "AddInputEdges for " << if_node->name(); int i = 0; 
- graph_->AddEdge(cond_args.conditional, 0, if_node, i++); - for (const Node* arg : cond_args.args) { + graph_->AddEdge(predicate, 0, if_node, i++); + for (const Node* arg : cond_args) { const Edge* in_edge; TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); if (in_edge->IsControlEdge()) { @@ -1204,176 +988,26 @@ Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, return Status::OK(); } -void FunctionalizeCond::RemoveMergeNodes(Cluster* merge_cluster) { - VLOG(3) << "RemoveMergeNodes for " << merge_cluster->representative; - // Remove all merge nodes now dead post extraction of If. - for (auto it = merge_cluster->merge_nodes.begin(); - it != merge_cluster->merge_nodes.end();) { - Node* node = *it; - graph_->RemoveNode(node); - merge_cluster->merge_nodes.erase(*it++); - } -} - -Status FunctionalizeCond::RemoveTrivialMerge(Cluster* merge_cluster) { - Cluster* switch_cluster = *merge_cluster->in_nodes.begin(); - if (switch_cluster->switch_nodes.empty()) { - return errors::FailedPrecondition( - "Not a trivial merge: no Switch node feeding into Merge node"); - } - - for (auto it = merge_cluster->merge_nodes.begin(); - it != merge_cluster->merge_nodes.end();) { - // We have the following structure: - // Op -> Switch -> Merge -> Consumer - // and we want to transform it to: - // Op -> Consumer - Node* merge_node = *it; - Node* switch_node; - const Edge* in = nullptr; - TF_RETURN_IF_ERROR(merge_node->input_node(0, &switch_node)); - TF_RETURN_IF_ERROR(switch_node->input_edge(0, &in)); - for (auto out : merge_node->out_edges()) { - int src_output = out->dst_input() == Graph::kControlSlot - ? 
Graph::kControlSlot - : in->src_output(); - graph_->AddEdge(in->src(), src_output, out->dst(), out->dst_input()); - } - graph_->RemoveNode(*it++); - } - RemoveUnusedArgs(switch_cluster->switch_nodes); - - return Status::OK(); -} - -Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) { - VLOG(1) << "ConvertMergeToXlaIf for " << merge_cluster->representative; - gtl::optional switch_cluster = GetSwitchCluster(*merge_cluster); - if (!switch_cluster.has_value()) { - return errors::FailedPrecondition( - "Merge cluster was not part of a simple conditional in the clustered " - "graph. Graph nodes in merge cluster ", - NodesToString(merge_cluster->merge_nodes)); - } - TF_ASSIGN_OR_RETURN(auto cond_args, - DetermineCondArgs(*merge_cluster, **switch_cluster)); - - // Sort the outputs by ID to produce more stable output. - std::vector outputs(merge_cluster->merge_nodes.begin(), - merge_cluster->merge_nodes.end()); - std::sort(outputs.begin(), outputs.end(), CondArgs::CondCmp()); +Status FunctionalizeCond::ConvertCorrespondingMergeToXlaIf( + const std::vector& switch_nodes, + const std::vector& merge_nodes, Node* predicate) { + VLOG(1) << "ConvertMergeToXlaIf for " << NodesToString(switch_nodes) << " -> " + << NodesToString(merge_nodes); // Extract bodies and builds a If operator. TF_ASSIGN_OR_RETURN(Node * if_node, - BuildAndAddXlaIfOp(cond_args, *merge_cluster, outputs)); - TF_RETURN_IF_ERROR(AddInputEdges(cond_args, if_node)); - TF_RETURN_IF_ERROR(AddOutputEdges(outputs, if_node)); - - // Remove the old nodes from the graph_ and contract the edges of the - // clustered graph. 
- for (auto in : merge_cluster->in_nodes) { - if (in != *switch_cluster) { - RemoveClusterNodes(in); - } - } - RemoveMergeNodes(merge_cluster); - RemoveUnusedArgs(cond_args.args); - auto in_nodes = merge_cluster->in_nodes; - for (auto it = in_nodes.begin(); it != in_nodes.end();) { - ContractEdge(*it++, merge_cluster); - } - ContractEdge(*switch_cluster, merge_cluster); - clusters_[if_node].Get() = ClusterHandle(merge_cluster->representative); + BuildAndAddXlaIfOp(switch_nodes, merge_nodes, predicate)); + TF_RETURN_IF_ERROR(AddInputEdges(switch_nodes, predicate, if_node)); + TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); return Status::OK(); } -std::vector> -FunctionalizeCond::SortedMergeNodes() { - VLOG(2) << "ProcessClusteredGraph"; - std::stack> stack; - for (auto& c : clustered_graph_) { - if (c.second.in_nodes.empty()) { - stack.push({0, &c.second}); - } - } - - // Perform a depth-first traversal of the clustered graph computing the - // switch-merge depth. - std::vector> queue; - std::unordered_set visited; - while (!stack.empty()) { - Cluster* n = stack.top().second; - size_t depth = stack.top().first; - stack.pop(); - - auto inserted = visited.insert(n); - if (!inserted.second) { - continue; - } - - size_t new_depth = depth; - if (!n->merge_nodes.empty()) { - queue.emplace_back(depth, n); - --new_depth; - } - if (!n->switch_nodes.empty()) { - ++new_depth; - } - for (Cluster* e : n->out_nodes) { - stack.emplace(new_depth, e); - } - } - - // Sort in reverse order of switch-merge depth with ties broken by the - // ClusterHandle. 
- std::sort(queue.begin(), queue.end(), - [](const std::pair& lhs, - const std::pair& rhs) { - return std::tie(lhs.first, lhs.second->representative) > - std::tie(rhs.first, rhs.second->representative); - }); - - return queue; -} - Status FunctionalizeCond::Functionalize(Graph* graph, FunctionLibraryDefinition* library) { VLOG(1) << "FunctionalizeCond::Functionalize"; FunctionalizeCond fc(graph, library); - fc.CreateClusters(); - if (fc.NoConditionals()) { - return Status::OK(); - } - fc.CreateClusteredGraph(); - - auto queue = fc.SortedMergeNodes(); - for (auto it = queue.begin(); it != queue.end();) { - Cluster* merge_cluster = (*it).second; - ++it; - if (merge_cluster->in_nodes.size() == 1) { - TF_RETURN_IF_ERROR(fc.RemoveTrivialMerge(merge_cluster)); - } else { - TF_RETURN_IF_ERROR(fc.ConvertMergeToXlaIf(merge_cluster)); - } - - // Contract newly Merge free merge_cluster with incoming nodes without - // Switch or Merge nodes. - std::vector in_nodes(merge_cluster->in_nodes.begin(), - merge_cluster->in_nodes.end()); - for (auto in : in_nodes) { - if (in->merge_nodes.empty() && in->switch_nodes.empty()) { - fc.ContractEdge(in, merge_cluster); - } - } - } - - if (!fc.switch_nodes_.empty()) { - return errors::Internal( - "Failed to functionalize control flow with Switch nodes remaining: ", - NodesToString(fc.switch_nodes_)); - } - return Status::OK(); + return fc.FunctionalizeInternal(); } } // namespace diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 01d2b282751f387cfa9c8887cdeb48090c96bff4..71f12a13339b9b5495631b8f9350579f6a0785a3 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -109,7 +109,7 @@ TEST(FunctionalizeControlFlow, Conditional) { auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); auto less = 
ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::XlaIf(scope.WithOpName("cond/Merge_If"), less, + auto if_op = ops::XlaIf(scope.WithOpName("cond/Less_If"), less, std::initializer_list{less, y, x}, then_fn, else_fn, {DT_INT32}); GraphDef expected; diff --git a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md new file mode 100644 index 0000000000000000000000000000000000000000..82b3b46a2f1e97001d1e0c6b993ec243170bc7d8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md @@ -0,0 +1,242 @@ +**Supported operators for device: XLA_CPU_JIT** + +Operator | Type Constraint +------------------------------------- | --------------- +`Abs` | `T={double,float,int32,int64}` +`Acosh` | `T={complex64,double,float}` +`Add` | `T={complex64,double,float,int32,int64}` +`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`All` | `Tidx={int32,int64}` +`Angle` | `Tout={double,float}`
`T={complex64}` +`Any` | `Tidx={int32,int64}` +`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={float}` +`ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asinh` | `T={complex64,double,float}` +`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan2` | `T={double,float}` +`Atanh` | `T={complex64,double,float}` +`AvgPool` | `T={double,float}` +`AvgPool3D` | `T={double,float}` +`AvgPool3DGrad` | `T={double,float}` +`AvgPoolGrad` | `T={double,float}` +`BatchMatMul` | `T={complex64,double,float,int32}` +`BatchToSpace` | `Tidx={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BatchToSpaceND` | `Tcrops={int32,int64}`
`Tblock_shape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BitwiseAnd` | `T={int32,int64,uint32,uint64}` +`BitwiseOr` | `T={int32,int64,uint32,uint64}` +`BroadcastArgs` | `T={int32,int64}` +`BroadcastGradientArgs` | `T={int32,int64}` +`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Ceil` | `T={double,float}` +`Cholesky` | `T={complex64,double,float}` +`Complex` | `Tout={complex64}`
`T={double,float}` +`ComplexAbs` | `Tout={double,float}`
`T={complex64}` +`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ConcatOffset` | +`ConcatV2` | `Tidx={int32}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Conj` | `T={complex64}` +`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ControlTrigger` | +`Conv2D` | `T={float}` +`Conv2DBackpropFilter` | `T={float}` +`Conv2DBackpropInput` | `T={float}` +`Conv3D` | `T={double,float}` +`Conv3DBackpropFilterV2` | `T={double,float}` +`Conv3DBackpropInputV2` | `T={double,float}` +`Cos` | `T={complex64,double,float}` +`Cosh` | `T={complex64,double,float}` +`Cross` | `T={double,float,int32,int64,uint32,uint64}` +`Cumprod` | `Tidx={int32,int64}`
`T={float}` +`Cumsum` | `Tidx={int32,int64}`
`T={float}` +`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`DepthwiseConv2dNative` | `T={double,float}` +`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}` +`DepthwiseConv2dNativeBackpropInput` | `T={double,float}` +`Diag` | `T={complex64,double,float,int32,int64}` +`DiagPart` | `T={complex64,double,float,int32,int64}` +`Div` | `T={complex64,double,float,int32,int64}` +`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Elu` | `T={double,float}` +`EluGrad` | `T={double,float}` +`Equal` | `T={bool,complex64,double,float,int32,int64}` +`Exp` | `T={complex64,double,float}` +`ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Expm1` | `T={complex64,double,float}` +`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Floor` | `T={double,float}` +`FloorDiv` | `T={complex64,double,float,int32,int64}` +`FloorMod` | `T={double,float,int32,int64}` +`FusedBatchNorm` | `T={float}` +`FusedBatchNormGrad` | `T={float}` +`FusedBatchNormGradV2` | `U={float}`
`T={float}` +`FusedBatchNormV2` | `U={float}`
`T={float}` +`Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Greater` | `T={double,float,int32,int64,uint32,uint64}` +`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Imag` | `Tout={double,float}`
`T={complex64}` +`Inv` | `T={complex64,double,float,int32,int64}` +`Invert` | `T={int32,int64,uint32,uint64}` +`InvertPermutation` | `T={int32}` +`IsFinite` | `T={double,float}` +`IsInf` | `T={double,float}` +`IsNan` | `T={double,float}` +`L2Loss` | `T={double,float}` +`LRN` | `T={float}` +`LRNGrad` | `T={float}` +`LeftShift` | `T={int32,int64,uint32,uint64}` +`Less` | `T={double,float,int32,int64,uint32,uint64}` +`LessEqual` | `T={double,float,int32,int64,uint32,uint64}` +`LinSpace` | `Tidx={int32,int64}`
`T={double,float}` +`Log` | `T={complex64,double,float}` +`Log1p` | `T={complex64,double,float}` +`LogSoftmax` | `T={double,float}` +`LogicalAnd` | +`LogicalNot` | +`LogicalOr` | +`MatMul` | `T={complex64,double,float}` +`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`MaxPool` | `T={double,float,int32,int64}` +`MaxPool3D` | `T={float}` +`MaxPool3DGrad` | `TInput={float}`
`T={float}` +`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Maximum` | `T={double,float,int32,int64}` +`Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Minimum` | `T={double,float,int32,int64}` +`MirrorPad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Mod` | `T={double,float,int32,int64}` +`Mul` | `T={complex64,double,float,int32,int64}` +`Multinomial` | `output_dtype={int32,int64}`
`T={double,float,int32,int64,uint32,uint64}` +`Neg` | `T={complex64,double,float,int32,int64}` +`NoOp` | +`NotEqual` | `T={bool,complex64,double,float,int32,int64}` +`OneHot` | `TI={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`OnesLike` | `T={bool,complex64,double,float,int32,int64}` +`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`PadV2` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pow` | `T={complex64,double,float,int32,int64}` +`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`QuantizeAndDequantizeV2` | `T={double,float}` +`RandomStandardNormal` | `dtype={float}` +`RandomUniform` | `T={int32,int64}`
`dtype={double,float}` +`RandomUniformInt` | `T={int32,int64}`
`Tout={int32,int64}` +`Range` | `Tidx={double,float,int32,int64}` +`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Real` | `Tout={double,float}`
`T={complex64}` +`RealDiv` | `T={complex64,double,float,int32,int64}` +`Reciprocal` | `T={complex64,double,float,int32,int64}` +`ReciprocalGrad` | `T={complex64,double,float}` +`Relu` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` +`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResourceApplyAdagrad` | `T={double,float}` +`ResourceApplyAdam` | `T={double,float}` +`ResourceApplyFtrl` | `T={double,float}` +`ResourceApplyFtrlV2` | `T={double,float}` +`ResourceApplyGradientDescent` | `T={double,float}` +`ResourceApplyMomentum` | `T={double,float}` +`ResourceApplyRMSProp` | `T={double,float}` +`ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` +`ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` +`RightShift` | `T={int32,int64,uint32,uint64}` +`Rint` | `T={double,float}` +`Round` | `T={complex64,double,float,int32,int64}` +`Rsqrt` | `T={complex64,double,float}` +`RsqrtGrad` | `T={complex64,double,float}` +`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Selu` | `T={double,float}` +`SeluGrad` | `T={double,float}` +`Shape` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ShapeN` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sigmoid` | `T={complex64,double,float}` +`SigmoidGrad` | `T={complex64,double,float}` +`Sign` | `T={complex64,double,float,int32,int64}` +`Sin` | `T={complex64,double,float}` +`Sinh` | `T={complex64,double,float}` +`Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Softmax` | `T={double,float}` +`SoftmaxCrossEntropyWithLogits` | `T={double,float}` +`Softplus` | `T={double,float,int32,int64,uint32,uint64}` +`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Softsign` | `T={double,float,int32,int64,uint32,uint64}` +`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}` +`SpaceToBatch` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToBatchND` | `Tblock_shape={int32,int64}`
`Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SparseMatMul` | `Tb={float}`
`Ta={float}` +`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`
`T={double,float}` +`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SplitV` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sqrt` | `T={complex64,double,float}` +`SqrtGrad` | `T={complex64,double,float}` +`Square` | `T={complex64,double,float,int32,int64}` +`SquaredDifference` | `T={complex64,double,float,int32,int64}` +`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackCloseV2` | +`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StatelessRandomNormal` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StatelessRandomUniform` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSlice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSliceGrad` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sub` | `T={complex64,double,float,int32,int64}` +`Sum` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tan` | `T={complex64,double,float,int32,int64}` +`Tanh` | `T={complex64,double,float}` +`TanhGrad` | `T={complex64,double,float}` +`TensorArrayCloseV3` | +`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGradV3` | +`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArraySizeV3` | +`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tile` | `Tmultiples={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Transpose` | `Tperm={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TruncateDiv` | `T={complex64,double,float,int32,int64}` +`TruncateMod` | `T={double,float,int32,int64}` +`TruncatedNormal` | `T={int32,int64}`
`dtype={double,float}` +`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`
`Tindices={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`VarIsInitializedOp` | +`VariableShape` | `out_type={int32,int64}` +`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` + +To regenerate this table, run: + +```shell +bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_CPU_JIT +``` diff --git a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md new file mode 100644 index 0000000000000000000000000000000000000000..d4b7621ad2858fe17e93d292dd807e4f7c1c336b --- /dev/null +++ b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md @@ -0,0 +1,238 @@ +**Supported operators for device: XLA_GPU_JIT** + +Operator | Type Constraint +------------------------------------- | --------------- +`Abs` | `T={double,float,int32,int64}` +`Acosh` | `T={complex64,double,float}` +`Add` | `T={complex64,double,float,int32,int64}` +`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`All` | `Tidx={int32,int64}` +`Angle` | `Tout={double,float}`
`T={complex64}` +`Any` | `Tidx={int32,int64}` +`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asinh` | `T={complex64,double,float}` +`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan2` | `T={double,float}` +`Atanh` | `T={complex64,double,float}` +`AvgPool` | `T={double,float}` +`AvgPool3D` | `T={double,float}` +`AvgPool3DGrad` | `T={double,float}` +`AvgPoolGrad` | `T={double,float}` +`BatchMatMul` | `T={complex64,double,float,int32}` +`BatchToSpace` | `Tidx={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BatchToSpaceND` | `Tcrops={int32,int64}`
`Tblock_shape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BitwiseAnd` | `T={int32,int64,uint32,uint64}` +`BitwiseOr` | `T={int32,int64,uint32,uint64}` +`BroadcastArgs` | `T={int32,int64}` +`BroadcastGradientArgs` | `T={int32,int64}` +`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Ceil` | `T={double,float}` +`Cholesky` | `T={complex64,double,float}` +`Complex` | `Tout={complex64}`
`T={double,float}` +`ComplexAbs` | `Tout={double,float}`
`T={complex64}` +`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ConcatOffset` | +`ConcatV2` | `Tidx={int32}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Conj` | `T={complex64}` +`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ControlTrigger` | +`Conv2D` | `T={float}` +`Conv2DBackpropFilter` | `T={float}` +`Conv2DBackpropInput` | `T={float}` +`Conv3D` | `T={double,float}` +`Conv3DBackpropFilterV2` | `T={double,float}` +`Conv3DBackpropInputV2` | `T={double,float}` +`Cos` | `T={complex64,double,float}` +`Cosh` | `T={complex64,double,float}` +`Cross` | `T={double,float,int32,int64,uint32,uint64}` +`Cumprod` | `Tidx={int32,int64}`
`T={float}` +`Cumsum` | `Tidx={int32,int64}`
`T={float}` +`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`DepthwiseConv2dNative` | `T={double,float}` +`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}` +`DepthwiseConv2dNativeBackpropInput` | `T={double,float}` +`Diag` | `T={complex64,double,float,int32,int64}` +`DiagPart` | `T={complex64,double,float,int32,int64}` +`Div` | `T={complex64,double,float,int32,int64}` +`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Elu` | `T={double,float}` +`EluGrad` | `T={double,float}` +`Equal` | `T={bool,complex64,double,float,int32,int64}` +`Exp` | `T={complex64,double,float}` +`ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Expm1` | `T={complex64,double,float}` +`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Floor` | `T={double,float}` +`FloorDiv` | `T={complex64,double,float,int32,int64}` +`FloorMod` | `T={double,float,int32,int64}` +`FusedBatchNorm` | `T={float}` +`FusedBatchNormGrad` | `T={float}` +`FusedBatchNormGradV2` | `U={float}`
`T={float}` +`FusedBatchNormV2` | `U={float}`
`T={float}` +`Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Greater` | `T={double,float,int32,int64,uint32,uint64}` +`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Imag` | `Tout={double,float}`
`T={complex64}` +`Inv` | `T={complex64,double,float,int32,int64}` +`Invert` | `T={int32,int64,uint32,uint64}` +`InvertPermutation` | `T={int32}` +`IsFinite` | `T={double,float}` +`IsInf` | `T={double,float}` +`IsNan` | `T={double,float}` +`L2Loss` | `T={double,float}` +`LRN` | `T={float}` +`LRNGrad` | `T={float}` +`LeftShift` | `T={int32,int64,uint32,uint64}` +`Less` | `T={double,float,int32,int64,uint32,uint64}` +`LessEqual` | `T={double,float,int32,int64,uint32,uint64}` +`LinSpace` | `Tidx={int32,int64}`
`T={double,float}` +`Log` | `T={complex64,double,float}` +`Log1p` | `T={complex64,double,float}` +`LogSoftmax` | `T={double,float}` +`LogicalAnd` | +`LogicalNot` | +`LogicalOr` | +`MatMul` | `T={complex64,double,float}` +`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`MaxPool` | `T={double,float,int32,int64}` +`MaxPool3D` | `T={float}` +`MaxPool3DGrad` | `TInput={float}`
`T={float}` +`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Maximum` | `T={double,float,int32,int64}` +`Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Minimum` | `T={double,float,int32,int64}` +`MirrorPad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Mod` | `T={double,float,int32,int64}` +`Mul` | `T={complex64,double,float,int32,int64}` +`Multinomial` | `output_dtype={int32,int64}`
`T={double,float,int32,int64,uint32,uint64}` +`Neg` | `T={complex64,double,float,int32,int64}` +`NoOp` | +`NotEqual` | `T={bool,complex64,double,float,int32,int64}` +`OneHot` | `TI={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`OnesLike` | `T={bool,complex64,double,float,int32,int64}` +`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`PadV2` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pow` | `T={complex64,double,float,int32,int64}` +`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`QuantizeAndDequantizeV2` | `T={double,float}` +`Range` | `Tidx={double,float,int32,int64}` +`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Real` | `Tout={double,float}`
`T={complex64}` +`RealDiv` | `T={complex64,double,float,int32,int64}` +`Reciprocal` | `T={complex64,double,float,int32,int64}` +`ReciprocalGrad` | `T={complex64,double,float}` +`Relu` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` +`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResourceApplyAdagrad` | `T={double,float}` +`ResourceApplyAdam` | `T={double,float}` +`ResourceApplyFtrl` | `T={double,float}` +`ResourceApplyFtrlV2` | `T={double,float}` +`ResourceApplyGradientDescent` | `T={double,float}` +`ResourceApplyMomentum` | `T={double,float}` +`ResourceApplyRMSProp` | `T={double,float}` +`ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` +`ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` +`RightShift` | `T={int32,int64,uint32,uint64}` +`Rint` | `T={double,float}` +`Round` | `T={complex64,double,float,int32,int64}` +`Rsqrt` | `T={complex64,double,float}` +`RsqrtGrad` | `T={complex64,double,float}` +`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Selu` | `T={double,float}` +`SeluGrad` | `T={double,float}` +`Shape` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ShapeN` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sigmoid` | `T={complex64,double,float}` +`SigmoidGrad` | `T={complex64,double,float}` +`Sign` | `T={complex64,double,float,int32,int64}` +`Sin` | `T={complex64,double,float}` +`Sinh` | `T={complex64,double,float}` +`Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Softmax` | `T={double,float}` +`SoftmaxCrossEntropyWithLogits` | `T={double,float}` +`Softplus` | `T={double,float,int32,int64,uint32,uint64}` +`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Softsign` | `T={double,float,int32,int64,uint32,uint64}` +`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}` +`SpaceToBatch` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToBatchND` | `Tblock_shape={int32,int64}`
`Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SparseMatMul` | `Tb={float}`
`Ta={float}` +`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`
`T={double,float}` +`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SplitV` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sqrt` | `T={complex64,double,float}` +`SqrtGrad` | `T={complex64,double,float}` +`Square` | `T={complex64,double,float,int32,int64}` +`SquaredDifference` | `T={complex64,double,float,int32,int64}` +`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackCloseV2` | +`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StatelessRandomNormal` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StatelessRandomUniform` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSlice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSliceGrad` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sub` | `T={complex64,double,float,int32,int64}` +`Sum` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tan` | `T={complex64,double,float,int32,int64}` +`Tanh` | `T={complex64,double,float}` +`TanhGrad` | `T={complex64,double,float}` +`TensorArrayCloseV3` | +`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGradV3` | +`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArraySizeV3` | +`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tile` | `Tmultiples={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Transpose` | `Tperm={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TruncateDiv` | `T={complex64,double,float,int32,int64}` +`TruncateMod` | `T={double,float,int32,int64}` +`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`
`Tindices={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`VarIsInitializedOp` | +`VariableShape` | `out_type={int32,int64}` +`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` + +To regenerate this table, run: + +```shell +bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_GPU_JIT +``` diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 13d06177f0fe2eb1a71e5cf684d74d87e263cfc5..3e24cf042e17ad4e212d82ac4f24fec06a6c780f 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -19,6 +19,7 @@ tf_kernel_library( "binary_ops.cc", "cast_op.cc", "categorical_op.cc", + "cholesky_op.cc", "concat_op.cc", "const_op.cc", "conv_ops.cc", @@ -34,6 +35,7 @@ tf_kernel_library( "gather_op.cc", "gather_op_helpers.h", "identity_op.cc", + "image_resize_ops.cc", "index_ops.cc", "l2loss_op.cc", "lrn_ops.cc", @@ -53,17 +55,20 @@ tf_kernel_library( "reshape_op.cc", "retval_op.cc", "reverse_op.cc", + "scan_ops.cc", "segment_reduction_ops.cc", "select_op.cc", "sendrecv_ops.cc", "sequence_ops.cc", "shape_op.cc", + "shape_util.cc", "slice_op.cc", "softmax_op.cc", "spacetobatch_op.cc", "spacetodepth_op.cc", "split_op.cc", "stack_ops.cc", + "stateless_random_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", "tile_ops.cc", @@ -76,12 +81,17 @@ tf_kernel_library( hdrs = [ "gather_op.h", "index_ops.h", + "shape_util.h", ], deps = [ ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/lib:batch_dot", + "//tensorflow/compiler/tf2xla/lib:cholesky", + "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", + "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ 
-90,8 +100,11 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/core:framework", + "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core:linalg_ops_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:stateless_random_ops_op_lib", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:concat_lib", "//tensorflow/core/kernels:constant_op", diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 73ccc151c1d6bdf70105badd962903297f090abe..a015b8e0e8949f8aaa03a78b0f88b7ea8d6aaa1c 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -13,11 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// XLA-specific BatchMatMul Op. -// The current implementation simply unrolls the computation along the batch -// dimension. -// TODO(dominikg,phawkins): Use a real batched matmul instead of unrolling. - +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -32,110 +28,10 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - const TensorShape x_shape = ctx->InputShape(0); - const TensorShape y_shape = ctx->InputShape(1); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - OP_REQUIRES(ctx, x_shape.dims() == y_shape.dims(), - errors::InvalidArgument("In[0] and In[1] has different ndims: ", - x_shape.DebugString(), " vs. 
", - y_shape.DebugString())); - const int ndims = x_shape.dims(); - OP_REQUIRES( - ctx, ndims >= 2, - errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims)); - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector dimensions; - int batch_count = 1; - for (int i = 0; i < ndims - 2; ++i) { - OP_REQUIRES( - ctx, x_shape.dim_size(i) == y_shape.dim_size(i), - errors::InvalidArgument("In[0].dim(", i, ") and In[1].dim(", i, - ") must be the same: ", x_shape.DebugString(), - " vs ", y_shape.DebugString())); - dimensions.push_back(x_shape.dim_size(i)); - batch_count *= x_shape.dim_size(i); - } - - int x_inner_dim = adj_x_ ? (ndims - 2) : (ndims - 1); - int y_inner_dim = adj_y_ ? (ndims - 1) : (ndims - 2); - OP_REQUIRES( - ctx, x_shape.dim_size(x_inner_dim) == y_shape.dim_size(y_inner_dim), - errors::InvalidArgument( - "In[0] mismatch In[1] shape: ", x_shape.dim_size(x_inner_dim), - " vs. ", y_shape.dim_size(y_inner_dim), ": ", x_shape.DebugString(), - " ", y_shape.DebugString(), " ", adj_x_, " ", adj_y_)); - - int x_outer_dim = adj_x_ ? (ndims - 1) : (ndims - 2); - int y_outer_dim = adj_y_ ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dim_size(x_outer_dim)); - dimensions.push_back(y_shape.dim_size(y_outer_dim)); - - xla::ComputationBuilder* builder = ctx->builder(); - - xla::ComputationDataHandle x_handle = ctx->Input(0); - if (BaseType(input_type(0)) == DT_COMPLEX64 && adj_x_) { - x_handle = builder->Conj(x_handle); - } - xla::ComputationDataHandle y_handle = ctx->Input(1); - if (BaseType(input_type(1)) == DT_COMPLEX64 && adj_y_) { - y_handle = builder->Conj(y_handle); - } - - // Reshape input tensors into 3D tensors by flattening the batch - // dimensions. This makes it easier to unroll the batch dimension. 
- auto x_flat = - builder->Reshape(x_handle, {batch_count, x_shape.dim_size(ndims - 2), - x_shape.dim_size(ndims - 1)}); - auto y_flat = - builder->Reshape(y_handle, {batch_count, y_shape.dim_size(ndims - 2), - y_shape.dim_size(ndims - 1)}); - - // Slice batches into individual matrices and multiply them. - std::vector out_slices; - for (int i = 0; i < batch_count; ++i) { - // Slice off individual matrices and reshape to 2D tensors. - auto x_slice = builder->Slice( - x_flat, {i, 0, 0}, - {i + 1, x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}, - {1, 1, 1}); - x_slice = builder->Reshape( - x_slice, {x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}); - auto y_slice = builder->Slice( - y_flat, {i, 0, 0}, - {i + 1, y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}, - {1, 1, 1}); - y_slice = builder->Reshape( - y_slice, {y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}); - - // Transpose if needed. - auto lhs = adj_x_ ? builder->Transpose(x_slice, {1, 0}) : x_slice; - auto rhs = adj_y_ ? builder->Transpose(y_slice, {1, 0}) : y_slice; - - // Multiply matrices and add an outer singleton dimension to the output - // so we can concatenate along the flattened batch dimension later. - auto out = builder->Dot(lhs, rhs); - out = builder->Reshape(out, - {1, dimensions[ndims - 2], dimensions[ndims - 1]}); - out_slices.push_back(out); - } - - // Concatenate output slices and reshape to original number of dimensions. - xla::ComputationDataHandle data; - if (out_slices.empty()) { - // It is illegal to pass an empty list to ConcatInDim. - // The batch count is empty, so both inputs must have zero elements. - // Arbitrarily use the left input as the argument to Reshape(). 
- data = x_handle; - } else { - data = builder->ConcatInDim(out_slices, 0); - } - data = builder->Reshape(data, dimensions); - - ctx->SetOutput(0, data); + auto result = + BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), adj_x_, adj_y_); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 248e9d111e556dcdd75581aa6562a66fc8b57063..a249b1869f547f8e5aa725f9f5cf391b10429928 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // XLA implementation of BatchNorm operations. -#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -26,43 +26,63 @@ namespace { class FusedBatchNormOp : public XlaOpKernel { public: explicit FusedBatchNormOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - TensorFormat tensor_format; - if (ctx->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(ctx, FormatFromString(data_format, &tensor_format), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES( - ctx, (tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW), - errors::InvalidArgument("Not supported format")); - feature_index_ = GetTensorFeatureDimIndex(/*num_dims=*/4, tensor_format); - } + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); 
+ OP_REQUIRES( + ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format: ", data_format_str)); + OP_REQUIRES(ctx, + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + errors::InvalidArgument( + "Unsupported data format ", ToString(data_format_), + "; supported formats are NHWC and NCHW")); } void Compile(XlaOpKernelContext* ctx) override { + xla::PrimitiveType input_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(0), &input_type)); + xla::PrimitiveType scale_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(1), &scale_type)); + + xla::ComputationBuilder* builder = ctx->builder(); + + xla::ComputationDataHandle input = ctx->Input(0); + TensorShape input_shape = ctx->InputShape(0); + + int feature_index = + GetTensorFeatureDimIndex(input_shape.dims(), data_format_); + + // TODO(b/69928690): support mixed precision in the XLA batch normalization + // operators. As a workaround, cast everything to the statistics type (which + // may be more precise than the input type). + input = builder->ConvertElementType(input, scale_type); + if (is_training_) { - xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining( - ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, - feature_index_); + xla::ComputationDataHandle output = builder->BatchNormTraining( + input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index); // In training mode, outputs the normalized value as well as the // calculated mean and variance. - for (int i = 0; i < 3; i++) { - ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); - } + ctx->SetOutput(0, builder->ConvertElementType( + builder->GetTupleElement(output, 0), input_type)); + ctx->SetOutput(1, builder->GetTupleElement(output, 1)); + ctx->SetOutput(2, builder->GetTupleElement(output, 2)); + // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved // space 1 & 2". 
They are used to pass the per-batch mean and // variance to the gradient. Here we maintain the same behavior by setting // them to the mean and variance calculated by BatchNormTraining. - ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); - ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); + ctx->SetOutput(3, builder->GetTupleElement(output, 1)); + ctx->SetOutput(4, builder->GetTupleElement(output, 2)); } else { - xla::ComputationDataHandle output = ctx->builder()->BatchNormInference( - ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3), - ctx->Input(4), epsilon_, feature_index_); - ctx->SetOutput(0, output); + xla::ComputationDataHandle output = builder->BatchNormInference( + input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), + epsilon_, feature_index); + ctx->SetOutput(0, builder->ConvertElementType(output, input_type)); // Directly send input to output as mean and variance in inference mode. ctx->SetOutput(1, ctx->Input(3)); ctx->SetOutput(2, ctx->Input(4)); @@ -73,55 +93,113 @@ class FusedBatchNormOp : public XlaOpKernel { private: float epsilon_; - int64 feature_index_; + TensorFormat data_format_; bool is_training_; }; REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp); +REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp); class FusedBatchNormGradOp : public XlaOpKernel { public: explicit FusedBatchNormGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - bool is_training; - OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training)); - CHECK(is_training) << "FusedBatchNormGradOp with is_training=False cannot " - "be used with XLA for now!"; - TensorFormat tensor_format; - if (ctx->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(ctx, FormatFromString(data_format, &tensor_format), - errors::InvalidArgument("Invalid data format")); - 
OP_REQUIRES( - ctx, (tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW), - errors::InvalidArgument("Not supported format")); - feature_index_ = GetTensorFeatureDimIndex(4, tensor_format); - } + OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES( + ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format: ", data_format_str)); + OP_REQUIRES(ctx, + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + errors::InvalidArgument( + "Unsupported data format ", ToString(data_format_), + "; supported formats are NHWC and NCHW")); } void Compile(XlaOpKernelContext* ctx) override { - auto grad_output = ctx->Input(0); - auto activation = ctx->Input(1); + xla::ComputationBuilder* b = ctx->builder(); + + auto grad_backprop = ctx->Input(0); + auto activations = ctx->Input(1); auto scale = ctx->Input(2); auto mean = ctx->Input(3); auto var = ctx->Input(4); - xla::ComputationDataHandle output = ctx->builder()->BatchNormGrad( - activation, scale, mean, var, grad_output, epsilon_, feature_index_); - for (int i = 0; i < 3; i++) { - ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); + TensorShape input_shape = ctx->InputShape(0); + int feature_index = + GetTensorFeatureDimIndex(input_shape.dims(), data_format_); + + DataType input_dtype = ctx->input_type(0); + DataType scale_dtype = ctx->input_type(2); + xla::PrimitiveType input_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_dtype, &input_type)); + xla::PrimitiveType scale_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(scale_dtype, &scale_type)); + + // TODO(b/69928690): support mixed precision in the XLA batch normalization + // operators. For now, cast everything to the statistics type (which + // may be more precise than the input type). 
+ grad_backprop = b->ConvertElementType(grad_backprop, scale_type); + activations = b->ConvertElementType(activations, scale_type); + + xla::ComputationDataHandle x_backprop; + xla::ComputationDataHandle scale_backprop; + xla::ComputationDataHandle offset_backprop; + if (is_training_) { + xla::ComputationDataHandle output = + b->BatchNormGrad(activations, scale, mean, var, grad_backprop, + epsilon_, feature_index); + + x_backprop = b->GetTupleElement(output, 0); + scale_backprop = b->GetTupleElement(output, 1); + offset_backprop = b->GetTupleElement(output, 2); + } else { + // Reduce over all dimensions except the feature dim. + std::vector reduction_dims(input_shape.dims() - 1); + std::iota(reduction_dims.begin(), reduction_dims.begin() + feature_index, + 0); + std::iota(reduction_dims.begin() + feature_index, reduction_dims.end(), + feature_index + 1); + // offset_backprop = sum(y_backprop) + // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + + // epsilon)) + // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon)) + offset_backprop = + b->Reduce(grad_backprop, XlaHelpers::Zero(b, scale_dtype), + *ctx->GetOrCreateAdd(scale_dtype), reduction_dims); + + // scratch1 = rsqrt(pop_var + epsilon) + auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); + auto scratch1 = + b->Pow(b->Add(var, b->ConstantR0(epsilon_)), neg_half); + + // scratch2 = sum(y_backprop * (x - mean)) + auto scratch2 = b->Reduce( + b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})), + XlaHelpers::Zero(b, scale_dtype), *ctx->GetOrCreateAdd(scale_dtype), + reduction_dims); + + x_backprop = + b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index}); + scale_backprop = b->Mul(scratch1, scratch2); } - ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); - ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); + + ctx->SetOutput(0, b->ConvertElementType(x_backprop, input_type)); + ctx->SetOutput(1, scale_backprop); + 
ctx->SetOutput(2, offset_backprop); + ctx->SetConstantOutput(3, Tensor(scale_dtype, {})); + ctx->SetConstantOutput(4, Tensor(scale_dtype, {})); } private: + TensorFormat data_format_; float epsilon_; - int64 feature_index_; + bool is_training_; }; REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp); +REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 1de91924326464338352b1ac9edf77141f25ad35..2436a6074a11ad66387b232dd1c5aa135875bfc3 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" namespace tensorflow { namespace { @@ -75,7 +76,7 @@ static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b, auto abs_y = b->Abs(y); auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one)); auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y)); - if (dtype == DT_FLOAT || dtype == DT_DOUBLE) { + if (DataTypeIsFloating(dtype)) { result = b->Floor(result); } return result; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_namespace_compat.h b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc similarity index 50% rename from tensorflow/core/distributed_runtime/rpc/grpc_namespace_compat.h rename to tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index c178927f5d5411e30bee2470b8b544ff76c28396..87d858f763560be454c162e0cf40307c68217663 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_namespace_compat.h +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -13,20 +13,27 @@ See the License for the specific language governing 
permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_NAMESPACE_COMPAT_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_NAMESPACE_COMPAT_H_ - -// This file is a transitional place-holder until gRPC versions consistently -// use namespace grpc::internal for library-internal structures - -namespace grpc { -// ensure internal namespace exists -namespace internal { -// bring in contents of external namespace -using namespace ::grpc; -} // namespace internal -// bring in contents of internal namespace -using namespace internal; -} // namespace grpc - -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_NAMESPACE_COMPAT_H_ +#include "tensorflow/compiler/tf2xla/lib/cholesky.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +class CholeskyOp : public XlaOpKernel { + public: + explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + auto result = Cholesky(ctx->builder(), ctx->Input(0)); + if (!result.ok()) { + ctx->SetStatus(result.status()); + return; + } + ctx->SetOutput(0, result.ValueOrDie()); + } +}; + +REGISTER_XLA_OP(Name("Cholesky"), CholeskyOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 9833323d851e00e7ca76d0b39cd2b216748a17fa..8f78b4c8f90cf00d5fa9ba71a78bb1c0fe280dc6 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -40,6 +40,11 @@ class ConstOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape shape(proto_.tensor_shape()); + if (proto_.dtype() == DT_STRING) { + LOG(WARNING) << "Not computing Const of type 
DT_STRING"; + ctx->SetInvalidOutput(0); + return; + } xla::ComputationBuilder* b = ctx->builder(); // To avoid blowups for large constants filled with the same value, diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 885f716afafca7ba23770e38f6693eed1ba50982..aaddbe811c6fbf6da296640eb5a75e82b2fedcfa 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -46,72 +46,130 @@ TensorShape ExpandedFilterShapeForDepthwiseConvolution( return expanded_shape; } +// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. +xla::ComputationDataHandle CreateExpandedZero( + const TensorShape& filter_shape, DataType dtype, + xla::ComputationBuilder* builder) { + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + return builder->Broadcast(XlaHelpers::Zero(builder, dtype), + expanded_filter_shape.dim_sizes()); +} + +// Create a mask for depthwise convolution that will make a normal convolution +// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] +// depthwise filter this returns a [2, 2, 3, 6] tesnsor +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// The first step is to create a one tensor, A, that is [3] +// 0 1 2 +// +// and another tensor, B, that is [3 * 2] +// 0 1 2 3 4 5 +// +// and divide B it by 2 to get +// 0 0 1 1 2 2 +// +// then we broadcast the B to [2, 2, 3, 3 * 2] +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and broadcasted B in dimension 2 amd return the result at +// the beginning of the comment. 
+xla::ComputationDataHandle CreateExpandedFilterMask( + const TensorShape& filter_shape, xla::ComputationBuilder* builder) { + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); + int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); + + // Create a M sized linspace and an M*N sized linspace that will be + // broadcasted into perpendicular dimensions and compared. + xla::ComputationDataHandle input_feature_iota; + // DT_INT32 Iota will always return status::OK(). + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, + &input_feature_iota)); + xla::ComputationDataHandle expanded_feature_iota; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, + input_feature * depthwise_multiplier, + &expanded_feature_iota)); + + // Divide the M*N sized linspace by the depthwise_multiplier to create + // [0 0 1 1 2 2] in the example in the function comment. + expanded_feature_iota = + builder->Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); + + // Broadcast the N*M linspace to [H, W, ..., M, M*N]. + auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); + expanded_feature_broadcast_dims.pop_back(); + auto broadcasted_expanded_feature_iota = builder->Broadcast( + expanded_feature_iota, expanded_feature_broadcast_dims); + + // Compare the broadcasted linspace to the input feature linspace in the + // input feature dimension to create a diagonal predicate. + return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dims() - 2}); +} + // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding // zeros for the cross-depth filters. Used to build a depthwise convolution. 
xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution( const TensorShape& filter_shape, DataType dtype, const xla::ComputationDataHandle& filter, xla::ComputationBuilder* builder) { - // Filter has shape [H, W, ..., M, N] - // Dilate to [H, W, ..., M*M, N] using M inter-element padding, and then - // reshape to [H, W, ..., M, M*N]. - int num_spatial_dims = filter_shape.dims() - 2; - const int64 in_depth = filter_shape.dim_size(num_spatial_dims); - xla::PaddingConfig padding = xla::MakeNoPaddingConfig(filter_shape.dims()); - padding.mutable_dimensions(num_spatial_dims)->set_interior_padding(in_depth); - auto dilated_filter = - builder->Pad(filter, XlaHelpers::Zero(builder, dtype), padding); - + int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); + int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return builder->Reshape(dilated_filter, expanded_filter_shape.dim_sizes()); + + // Create a [H, W, ..., 1, N*M] reshape of the filter. + TensorShape implicit_broadcast_filter_shape = expanded_filter_shape; + implicit_broadcast_filter_shape.set_dim( + implicit_broadcast_filter_shape.dims() - 2, 1); + implicit_broadcast_filter_shape.set_dim( + implicit_broadcast_filter_shape.dims() - 1, + depthwise_multiplier * input_feature); + auto implicit_broadcast_filter = + builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); + + // Broadcast the filter to [H, W, ..., M, M*N]. + auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); + auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero); + + // If the filter mask is set, choose the broadcasted filter, othwerwise, + // choose zero. + return builder->Select(CreateExpandedFilterMask(filter_shape, builder), + expanded_filter, expanded_zero); } // Inverse of ExpandFilterForDepthwiseConvolution. 
xla::ComputationDataHandle ContractFilterForDepthwiseBackprop( - const TensorShape& filter_shape, DataType dtype, + XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype, const xla::ComputationDataHandle& filter_backprop, xla::ComputationBuilder* builder) { - int num_spatial_dims = filter_shape.dims() - 2; - - // Reshape to [H, W, ..., M*M, N] - TensorShape shape = filter_shape; - int64 in_depth = filter_shape.dim_size(num_spatial_dims); - shape.set_dim(num_spatial_dims, in_depth * in_depth); - auto reshaped = builder->Reshape(filter_backprop, shape.dim_sizes()); - - std::vector zeros(filter_shape.dims()); - std::vector strides(filter_shape.dims(), 1LL); - strides[num_spatial_dims] = in_depth + 1; - return builder->Slice(reshaped, zeros, shape.dim_sizes(), strides); - - // Alternate implementation for backends without strided Slice() support. - // TODO(phawkins): Remove when all backends support strided slice. - // // Pad [..., M * (M + 1), N] - // xla::PaddingConfig config = - // xla::MakeNoPaddingConfig(filter_shape.dims()); - // config.mutable_dimensions(num_spatial_dims) - // ->set_edge_padding_high(in_depth); - // auto zero = XlaHelpers::Zero(builder, dtype); - // auto padded = builder->Pad(reshaped, zero, config); - // - // // Reshape to [..., M, M + 1, N] - // shape = filter_shape; - // shape.set_dim(num_spatial_dims, in_depth); - // shape.set_dim(num_spatial_dims + 1, in_depth + 1); - // int64 out_depth = filter_shape.dim_size(num_spatial_dims + 1); - // shape.AddDim(out_depth); - // reshaped = builder->Reshape(padded, shape.dim_sizes()); - // - // // Slice to [..., M, 1, N] - // std::vector zeros(shape.dims()); - // std::vector strides(shape.dims(), 1LL); - // shape.set_dim(num_spatial_dims + 1, 1); - // auto sliced = builder->Slice(reshaped, zeros, shape.dim_sizes(), - // strides); - // - // // Reshape to [..., M, N] - // return builder->Reshape(sliced, filter_shape.dim_sizes()); + TensorShape expanded_filter_shape = + 
ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + auto masked_expanded_filter = builder->Select( + CreateExpandedFilterMask(filter_shape, builder), filter_backprop, + CreateExpandedZero(filter_shape, dtype, builder)); + return builder->Reshape( + builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), + *ctx->GetOrCreateAdd(dtype), + {expanded_filter_shape.dims() - 2}), + filter_shape.dim_sizes()); } class ConvOp : public XlaOpKernel { @@ -121,6 +179,7 @@ class ConvOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); @@ -144,6 +203,23 @@ class ConvOp : public XlaOpKernel { errors::Unimplemented("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES( + ctx, dilations_[input_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the ", + i, "th spatial dimension.")); + } + const TensorShape input_shape = ctx->InputShape(0); // Input filter is of the following dimensions: // [ filter_rows, filter_cols, ..., in_depth, out_depth] @@ -184,10 +260,11 @@ class ConvOp : public XlaOpKernel { dims.set_input_feature_dimension(feature_dim); dims.set_output_feature_dimension(feature_dim); for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = 
GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dims.add_spatial_dimensions(input_dim); + const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + dims.add_input_spatial_dimensions(dim); dims.add_kernel_spatial_dimensions(i); - window_strides.push_back(strides_.at(input_dim)); + dims.add_output_spatial_dimensions(dim); + window_strides.push_back(strides_.at(dim)); } dims.set_kernel_input_feature_dimension(num_spatial_dims_); dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); @@ -203,6 +280,7 @@ class ConvOp : public XlaOpKernel { protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -240,6 +318,7 @@ class ConvBackpropInputOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); string data_format; @@ -262,6 +341,23 @@ class ConvBackpropInputOp : public XlaOpKernel { errors::Unimplemented("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES( + ctx, dilations_[input_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the ", + i, "th spatial dimension.")); + } + TensorShape input_shape; OP_REQUIRES_OK(ctx, 
ctx->ConstantInputAsShape(0, &input_shape)); @@ -302,9 +398,10 @@ class ConvBackpropInputOp : public XlaOpKernel { std::vector lhs_dilation(num_spatial_dims_); std::vector ones(num_spatial_dims_, 1); for (int i = 0; i < num_spatial_dims_; ++i) { - dnums.add_spatial_dimensions( - GetTensorSpatialDimIndex(num_dims(), data_format_, i)); + int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(i); + dnums.add_output_spatial_dimensions(dim); kernel_spatial_dims[i] = i; padding[i] = {dims.spatial_dims[i].pad_before, @@ -334,6 +431,7 @@ class ConvBackpropInputOp : public XlaOpKernel { protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -371,6 +469,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); string data_format; @@ -390,6 +489,23 @@ class ConvBackpropFilterOp : public XlaOpKernel { errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES( + ctx, dilations_[input_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the ", + i, "th 
spatial dimension.")); + } + const TensorShape activations_shape = ctx->InputShape(0); TensorShape filter_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape)); @@ -424,9 +540,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { // Swap n_dim and c_dim in the activations. dnums.set_input_batch_dimension(c_dim); - dnums.set_output_batch_dimension(c_dim); dnums.set_input_feature_dimension(n_dim); - dnums.set_output_feature_dimension(n_dim); // The gradients become the RHS of the convolution. // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] @@ -438,9 +552,16 @@ class ConvBackpropFilterOp : public XlaOpKernel { std::vector rhs_dilation(num_spatial_dims_); std::vector ones(num_spatial_dims_, 1); + // Tensorflow filter shape is [ H, W, ..., inC, outC ]. for (int i = 0; i < num_spatial_dims_; ++i) { - int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dnums.add_spatial_dimensions(dim); + dnums.add_output_spatial_dimensions(i); + } + dnums.set_output_batch_dimension(num_spatial_dims_); + dnums.set_output_feature_dimension(num_spatial_dims_ + 1); + + for (int i = 0; i < num_spatial_dims_; ++i) { + int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(dim); // We will also need to pad the input with zeros such that after the @@ -498,31 +619,17 @@ class ConvBackpropFilterOp : public XlaOpKernel { /*window_strides=*/ones, padding, /*lhs_dilation=*/ones, rhs_dilation, dnums); - // The layout of filter_backprop will match the layout of - // padded_activations - // and so will have layout: [out_feature, h, w, ..., in_feature] - // Tensorflow filter shape is [ H, W, ..., inC, outC ], so we transpose the - // output. 
- std::vector transpose_dims; - transpose_dims.reserve(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - transpose_dims.push_back(dnums.spatial_dimensions(i)); - } - transpose_dims.push_back(c_dim); - transpose_dims.push_back(n_dim); - xla::ComputationDataHandle filter_backprop_reshaped = - b->Transpose(filter_backprop, transpose_dims); - if (depthwise_) { - filter_backprop_reshaped = ContractFilterForDepthwiseBackprop( - filter_shape, ctx->input_type(0), filter_backprop_reshaped, b); + filter_backprop = ContractFilterForDepthwiseBackprop( + ctx, filter_shape, ctx->input_type(0), filter_backprop, b); } - ctx->SetOutput(0, filter_backprop_reshaped); + ctx->SetOutput(0, filter_backprop); } protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index a4ea65ea89e348cb77412efb0c5c0fcb1a9f33f3..96d7809f7995634b6bc31ab801b93526d9da7e6f 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { namespace { @@ -23,6 +24,16 @@ namespace { class DepthToSpaceOp : public XlaOpKernel { public: explicit DepthToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + + OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_), + "; expected formats NHWC or NCHW")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -31,18 +42,79 @@ class DepthToSpaceOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_tensor_shape = ctx->InputShape(0); - // The input is presumed to be [batch, height, width, depth] int input_rank = input_tensor_shape.dims(); static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, - errors::InvalidArgument("Input rank should be: ", kRequiredDims, - " instead of: ", input_rank)); + errors::InvalidArgument("Input rank should be ", kRequiredDims, + "; got: ", input_rank)); const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle input = ctx->Input(0); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + + std::vector reshaped_shape; + std::vector transpose_order; + std::vector output_shape; + reshaped_shape.reserve(input_rank); + transpose_order.reserve(input_rank); + output_shape.reserve(input_rank); + if 
(data_format_ == FORMAT_NHWC) { + reshaped_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[1 + i]); + } + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(block_size_); + block_elems *= block_size_; + } + reshaped_shape.push_back(input_shape[feature_dim] / block_elems); + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i + 1); + transpose_order.push_back(i + 1 + num_spatial_dims); + } + transpose_order.push_back(feature_dim + num_spatial_dims); + + output_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[1 + i] * block_size_); + } + output_shape.push_back(input_shape[feature_dim] / block_elems); + } else { + // NCHW format. + reshaped_shape.push_back(input_shape[0]); + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(block_size_); + block_elems *= block_size_; + } + reshaped_shape.push_back(input_shape[feature_dim] / block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[2 + i]); + } + + transpose_order.push_back(0); + transpose_order.push_back(1 + num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(2 + num_spatial_dims + i); + transpose_order.push_back(1 + i); + } + + output_shape.push_back(input_shape[0]); + output_shape.push_back(input_shape[feature_dim] / block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[2 + i] * block_size_); + } + } + + // Note: comments are given in NHWC format; NCHW is similar with a different + // dimension order. // 1. 
Reshape `input` to `reshaped` of shape: // // [batch, @@ -51,14 +123,14 @@ class DepthToSpaceOp : public XlaOpKernel { // block_size_, // block_size_, // depth / (block_size_ * block_size_)] - OP_REQUIRES(ctx, input_shape[3] % (block_size_ * block_size_) == 0, + OP_REQUIRES(ctx, + input_shape[feature_dim] % (block_size_ * block_size_) == 0, errors::InvalidArgument( "Input depth dimension (", input_shape[3], ") is not divisible by square of the block size (", block_size_, ")")); - xla::ComputationDataHandle reshaped = b->Reshape( - input, {input_shape[0], input_shape[1], input_shape[2], block_size_, - block_size_, input_shape[3] / (block_size_ * block_size_)}); + + xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -70,7 +142,7 @@ class DepthToSpaceOp : public XlaOpKernel { // block_size_, // depth / (block_size_ * block_size_)] xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, {0, 1, 3, 2, 4, 5}); + b->Transpose(reshaped, transpose_order); // 3. 
Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -80,15 +152,14 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2] * block_size_, // depth / (block_size_ * block_size_)] // - xla::ComputationDataHandle output = b->Reshape( - permuted_reshaped, {input_shape[0], input_shape[1] * block_size_, - input_shape[2] * block_size_, - input_shape[3] / (block_size_ * block_size_)}); + xla::ComputationDataHandle output = + b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } private: + TensorFormat data_format_; int block_size_; }; REGISTER_XLA_OP(Name("DepthToSpace"), DepthToSpaceOp); diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index ec5017f6ab96bd3fc273a746b77fbb7e74fd9f35..765ea922a532a085a552192348ab360c4c30ff0a 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -22,6 +24,62 @@ limitations under the License. namespace tensorflow { namespace { +// Create a diagonal / batch diagonal matrix with 'input' on the diagonal. 
+xla::StatusOr CreateDiagonal( + const xla::ComputationDataHandle& input, int64 last_dim_size, + tensorflow::gtl::ArraySlice other_dims, XlaOpKernelContext* ctx, + xla::ComputationBuilder* builder) { + // Create two matrices that have the following forms, and compare them: + // + // [[0, 0, 0, 0] [[0, 1, 2, 3] + // [1, 1, 1, 1] [0, 1, 2, 3] + // [2, 2, 2, 2] [0, 1, 2, 3] + // [3, 3, 3, 3]] [0, 1, 2, 3]] + // + // This produces a predicate matrix of the right size, with "true" on the + // diagonal. + xla::ComputationDataHandle iota; + TF_RETURN_IF_ERROR( + XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); + xla::ComputationDataHandle iota_broadcast = + builder->Broadcast(iota, {last_dim_size}); + xla::ComputationDataHandle mask = builder->Eq(iota_broadcast, iota, {0}); + + // If this is a batched diagonal, broadcast the mask across the other + // dimensions. + if (!other_dims.empty()) { + mask = builder->Broadcast(mask, other_dims); + } + + // Broadcast the input, and then use the mask computed above to select the + // diagonal: + // e.g, in 2D: + // [[t, f, f] [[1, 1, 1] [[0, 0, 0] [[1, 0, 0] + // select( [f, t, f] , [4, 4, 4] , [0, 0, 0] ) = [0, 4, 0] + // [f, f, t]] [9, 9, 9]] [0, 0, 0]] [0, 0, 9]] + // + // Broadcasting the input is less-than-trivial, since we need to broadcast + // into a "middle" dimension. We can do this with a reshape + implicit + // broadcast. + // TODO(b/30112114): Replace with in-dim broadcast when those are supported. 
+ std::vector broadcast_dims(other_dims.begin(), other_dims.end()); + broadcast_dims.push_back(1LL); + broadcast_dims.push_back(last_dim_size); + xla::ComputationDataHandle input_broadcast = + builder->Reshape(input, broadcast_dims); + + broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; + xla::PrimitiveType element_type; + TF_RETURN_IF_ERROR( + DataTypeToPrimitiveType(ctx->input_type(0), &element_type)); + auto broadcast_shape = + xla::ShapeUtil::MakeShape(element_type, broadcast_dims); + xla::ComputationDataHandle zeros = Zeros(builder, broadcast_shape); + + input_broadcast = builder->Add(input_broadcast, zeros); + return builder->Select(mask, input_broadcast, zeros); +} + class DiagOp : public XlaOpKernel { public: explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -29,6 +87,8 @@ class DiagOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); + OP_REQUIRES(ctx, ctx->num_inputs() >= 1, + errors::InvalidArgument("Diag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -36,7 +96,7 @@ class DiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle diag = ctx->Input(0); + xla::ComputationDataHandle input = ctx->Input(0); // Picture: // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0] @@ -46,13 +106,13 @@ class DiagOp : public XlaOpKernel { // Flattens the input to 1D. int64 size = input_shape.num_elements(); - diag = builder->Reshape(diag, {size}); + input = builder->Reshape(input, {size}); - // Adds inter-element padding of 'size'. - xla::PaddingConfig config; - auto* dim = config.add_dimensions(); - dim->set_interior_padding(size); - diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); + // Create an R2 with the R1 diagonal. 
+ auto diag_or_status = + CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder); + OP_REQUIRES_OK(ctx, diag_or_status.status()); + xla::ComputationDataHandle diag = diag_or_status.ValueOrDie(); // Reshapes to the final shape. std::vector new_dims(dims.size() * 2); @@ -141,6 +201,8 @@ class MatrixDiagOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); + OP_REQUIRES(ctx, ctx->num_inputs() >= 1, + errors::InvalidArgument("MatrixDiag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -152,17 +214,13 @@ class MatrixDiagOp : public XlaOpKernel { int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); + tensorflow::gtl::ArraySlice other_dims(dims); + other_dims.pop_back(); - // Adds inter-element padding of 'last_dim_size' to the last dimension. - xla::PaddingConfig config = xla::MakeNoPaddingConfig(dims.size()); - auto* dim = config.mutable_dimensions(last_dim); - dim->set_interior_padding(last_dim_size); - diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); - - // Reshapes to the final shape. - dims.push_back(last_dim_size); - diag = builder->Reshape(diag, dims); - + auto diag_or_status = + CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder); + OP_REQUIRES_OK(ctx, diag_or_status.status()); + diag = diag_or_status.ValueOrDie(); ctx->SetOutput(0, diag); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..d91ebb500b4479dbb3c8e2ea7719bc79dc24ba4f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -0,0 +1,367 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/math/math_util.h" + +namespace tensorflow { +namespace { + +// We implement bilinear interpolation by upsampling followed by convolution. +// The basic idea is as follows. To scale from NxN to RxR: +// +// 1. S := (N - 1) / gcd(N-1, R-1) +// 2. k := (R - 1) / gcd(N-1, R-1) +// 3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1) +// +// For example, to scale from 7x7 -> 15x15: +// +// 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3 +// 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7 +// 3. Convolution(13x13, stride=3, lhs_dilation=7, padding=6) +// +// +// The 7x7 -> 15x15 case is much too large to write out in full as an +// example. The smallest interesting example is 3x3 -> 4x4. 
+// +// S := 2 +// k := 3 +// +// 00 03 06 00 00 00 00 00 00 00 00 00 00 00 00 02 04 06 +// 09 12 15 -> 00 00 00 00 00 00 00 00 00 00 00 -> 06 08 10 12 +// 18 21 24 00 00 00 00 00 03 00 00 06 00 00 12 14 16 18 +// 00 00 00 00 00 00 00 00 00 00 00 18 20 22 24 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 09 00 00 12 00 00 15 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 18 00 00 21 00 00 24 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// +// with the following convolutional kernel, with stride [2, 2]: +// 1 2 3 2 1 +// 2 4 6 4 2 +// 1/9 * 3 6 9 6 3 +// 2 4 6 4 2 +// 1 2 3 2 1 + +// Computes the size of the convolutional kernel and stride to use when resizing +// from in_size to out_size. +struct ResizeConvolutionDims { + // Size of the kernel to use. + std::vector kernel_size; + + // Stride of the convolution to use. + std::vector stride; +}; +ResizeConvolutionDims ComputeResizeConvolutionParameters( + gtl::ArraySlice in_size, gtl::ArraySlice out_size) { + CHECK_EQ(in_size.size(), out_size.size()); + int num_spatial_dims = in_size.size(); + ResizeConvolutionDims dims; + dims.kernel_size.resize(num_spatial_dims); + dims.stride.resize(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1) { + // We must handle input size 1 specially because XLA convolution does + // not allow stride 0. + dims.stride[i] = dims.kernel_size[i] = 1; + } else if (out_size[i] == 1) { + // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first + // entry before resizing. 
+ dims.stride[i] = dims.kernel_size[i] = 1; + } else { + int64 gcd = MathUtil::GCD(static_cast(in_size[i] - 1), + static_cast(out_size[i] - 1)); + dims.stride[i] = (in_size[i] - 1) / gcd; + dims.kernel_size[i] = (out_size[i] - 1) / gcd; + } + } + return dims; +} + +xla::ComputationDataHandle MakeBilinearResizeKernel( + xla::ComputationBuilder* builder, gtl::ArraySlice kernel_size, + int64 channels) { + // Form a 2D convolution kernel like: + // 1 2 3 2 1 + // 2 4 6 4 2 + // 1/9 * 3 6 9 6 3 + // 2 4 6 4 2 + // 1 2 3 2 1 + // by multiplying two 1D kernels of the form: + // 1/3 * [1 2 3 2 1] + auto make_1d_kernel = [](int64 n) { + std::vector kernel(n * 2 - 1); + for (int64 i = 0; i < n; ++i) { + float v = i + 1; + kernel[i] = v; + kernel[n * 2 - 2 - i] = v; + } + return kernel; + }; + + // Form a block diagonal kernel where each channel interacts only with itself. + xla::Array4D diag(1, 1, channels, channels, 0.0f); + for (int i = 0; i < channels; ++i) { + diag(0, 0, i, i) = 1.0f / (kernel_size[0] * kernel_size[1]); + } + return builder->Mul( + builder->ConstantR1(make_1d_kernel(kernel_size[0])), + builder->Mul(builder->ConstantR1(make_1d_kernel(kernel_size[1])), + builder->ConstantR4FromArray4D(diag), + /*broadcast_dimensions=*/{1}), + /*broadcast_dimensions=*/{0}); +} + +class ResizeBilinearOp : public XlaOpKernel { + public: + explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES( + ctx, align_corners_ == true, + errors::Unimplemented( + "ResizeBilinear with align_corners=False is not yet implemented")); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + const int64 batch = input_shape.dim_size(0); + const std::vector in_size = 
{input_shape.dim_size(1), + input_shape.dim_size(2)}; + const int64 channels = input_shape.dim_size(3); + OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, + errors::InvalidArgument("input size must be positive, got [", + in_size[0], ",", in_size[1], "]")); + + std::vector out_size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size)); + OP_REQUIRES(ctx, out_size.size() == 2, + errors::InvalidArgument("output size must be length 2, got ", + out_size.size())); + OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0, + errors::InvalidArgument("output size must be positive, got [", + out_size[0], ",", out_size[1], "]")); + + const int num_spatial_dims = 2; + + xla::ComputationDataHandle input = ctx->Input(0); + + // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in + // dimension i. + std::vector slice_size = in_size; + bool slice_input = false; + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] > 1 && out_size[i] == 1) { + // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first + // entry before resizing. + slice_input = true; + slice_size[i] = 1; + } + } + if (slice_input) { + input = b->Slice(input, {0, 0, 0, 0}, + {batch, slice_size[0], slice_size[1], channels}, + {1, 1, 1, 1}); + } + + // Output is always type float. 
+ input = b->ConvertElementType(input, xla::F32); + + // Picture for a 1x3 to 1x4 resize: + // stride = 2, kernel size = 3 + // Input: + // 3 6 9 + // Input with dilation and padding: + // 0 0 3 0 0 6 0 0 9 0 0 + // Convolution kernel: + // 1/3 * [1 2 3 2 1] + // Output: + // 3 5 7 9 + xla::ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + for (int i = 0; i < num_spatial_dims; ++i) { + dnums.add_input_spatial_dimensions(1 + i); + dnums.add_output_spatial_dimensions(1 + i); + dnums.add_kernel_spatial_dimensions(i); + } + dnums.set_kernel_input_feature_dimension(num_spatial_dims); + dnums.set_kernel_output_feature_dimension(num_spatial_dims + 1); + + ResizeConvolutionDims dims = + ComputeResizeConvolutionParameters(in_size, out_size); + xla::ComputationDataHandle kernel = + MakeBilinearResizeKernel(b, dims.kernel_size, channels); + xla::ComputationDataHandle output = b->ConvGeneralDilated( + input, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dnums); + + // Add broadcasts to handle expanding from a size == 1 dimension to a + // size > 1 dimension. 
+ for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && out_size[i] > 1) { + output = b->Add(output, b->ConstantR1(out_size[i], 0), + /*broadcast_dimensions=*/{1 + i}); + } + } + + ctx->SetOutput(0, output); + } + + private: + bool align_corners_; +}; + +REGISTER_XLA_OP(Name("ResizeBilinear"), ResizeBilinearOp); + +class ResizeBilinearGradOp : public XlaOpKernel { + public: + explicit ResizeBilinearGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES( + ctx, align_corners_ == true, + errors::Unimplemented("ResizeBilinearGrad with align_corners=False is " + "not yet implemented")); + + DataType output_dtype; + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_dtype, &output_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + TensorShape input_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + const int64 batch = input_shape.dim_size(0); + const std::vector in_size = {input_shape.dim_size(1), + input_shape.dim_size(2)}; + const int64 channels = input_shape.dim_size(3); + OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, + errors::InvalidArgument("input size must be positive, got [", + in_size[0], ",", in_size[1], "]")); + + TensorShape grad_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, grad_shape.dims() == 4, + errors::InvalidArgument("gradient must be 4-dimensional", + grad_shape.DebugString())); + const int64 grad_batch = grad_shape.dim_size(0); + const std::vector grad_size = {grad_shape.dim_size(1), + grad_shape.dim_size(2)}; + const int64 grad_channels = grad_shape.dim_size(3); + OP_REQUIRES(ctx, batch == grad_batch, + errors::InvalidArgument( + "activations and gradients must have the same batch size (", + batch, " vs. 
", grad_batch, ")")); + OP_REQUIRES(ctx, grad_size[0] > 0 && grad_size[1] > 0, + errors::InvalidArgument("gradient size must be positive, got [", + grad_size[0], ",", grad_size[1], "]")); + OP_REQUIRES( + ctx, channels == grad_channels, + errors::InvalidArgument( + "activations and gradients must have the same number of channels (", + channels, " vs. ", grad_channels, ")")); + + const int num_spatial_dims = 2; + + xla::ComputationDataHandle grad = ctx->Input(0); + + ResizeConvolutionDims dims = + ComputeResizeConvolutionParameters(in_size, grad_size); + + // To form the backward convolution, we keep the kernel unchanged (it is + // already symmetric) and swap the roles of strides and LHS dilation. + xla::ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + for (int i = 0; i < num_spatial_dims; ++i) { + dnums.add_input_spatial_dimensions(1 + i); + dnums.add_output_spatial_dimensions(1 + i); + dnums.add_kernel_spatial_dimensions(i); + } + dnums.set_kernel_input_feature_dimension(num_spatial_dims); + dnums.set_kernel_output_feature_dimension(num_spatial_dims + 1); + xla::ComputationDataHandle kernel = + MakeBilinearResizeKernel(b, dims.kernel_size, channels); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. 
+ for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && grad_size[i] > 1) { + kernel = b->Add(kernel, b->ConstantR1(grad_size[i], 0), + /*broadcast_dimensions=*/{i}); + } + } + + xla::ComputationDataHandle output = b->ConvGeneralDilated( + grad, kernel, /*window_strides=*/dims.kernel_size, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.stride, + /*rhs_dilation=*/{1, 1}, dnums); + + // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. + // Opposite of the slice performed by the forward op. + xla::PaddingConfig padding = xla::MakeNoPaddingConfig(4); + bool pad_output = false; + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] > 1 && grad_size[i] == 1) { + pad_output = true; + padding.mutable_dimensions(1 + i)->set_edge_padding_high(in_size[i] - + 1); + } + } + if (pad_output) { + output = b->Pad(output, b->ConstantR0(0.0f), padding); + } + + output = b->ConvertElementType(output, output_type_); + ctx->SetOutput(0, output); + } + + private: + bool align_corners_; + xla::PrimitiveType output_type_; +}; + +REGISTER_XLA_OP(Name("ResizeBilinearGrad"), ResizeBilinearGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index fcef497e5845d9080bc83b54e92dcf2fdecf5f12..644abd5905c6ce5a8f61792a1986560bab891040 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -23,8 +23,8 @@ limitations under the License. 
namespace tensorflow { namespace { -constexpr std::array kMatmulTypes = { - {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; +constexpr std::array kMatmulTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; class MatMulOp : public XlaOpKernel { public: @@ -85,10 +85,7 @@ class SparseMatMulOp : public MatMulOp { ~SparseMatMulOp() override = default; }; -REGISTER_XLA_OP(Name("SparseMatMul") - .TypeConstraint("Ta", kFloatTypes) - .TypeConstraint("Tb", kFloatTypes), - SparseMatMulOp); +REGISTER_XLA_OP(Name("SparseMatMul"), SparseMatMulOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..650f8c7dc8be0cb08997ec641ca3f82352166fdd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -0,0 +1,141 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +// TODO(phawkins): implement double-sized windowed reductions in XLA and remove +// the type constraint. 
+constexpr std::array kScanOpTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT}}; + +class ScanOp : public XlaOpKernel { + public: + ScanOp(OpKernelConstruction* ctx, bool sum) : XlaOpKernel(ctx), sum_(sum) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("reverse", &reverse_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("exclusive", &exclusive_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape tensor_axis_shape = ctx->InputShape(1); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_axis_shape), + errors::InvalidArgument("ScanOp: axis must be a scalar, not ", + tensor_axis_shape.DebugString())); + + int64 axis; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &axis)); + if (axis < 0) { + axis += input_shape.dims(); + } + OP_REQUIRES( + ctx, FastBoundsCheck(axis, input_shape.dims()), + errors::InvalidArgument("ScanOp: Expected scan axis in the range [", + -input_shape.dims(), ", ", input_shape.dims(), + "), but got ", axis)); + + DataType dtype = ctx->input_type(0); + + if (input_shape.num_elements() == 0) { + // Exit early if there is nothing to compute. + ctx->SetOutput(0, ctx->Input(0)); + return; + } + + xla::ComputationBuilder* builder = ctx->builder(); + + std::vector window_strides(input_shape.dims(), 1); + std::vector window_dims(input_shape.dims(), 1); + window_dims[axis] = input_shape.dim_size(axis); + + std::vector> padding(input_shape.dims(), {0, 0}); + padding[axis].first = input_shape.dim_size(axis) - 1; + // In exclusive mode, add an extra padding element so there is a complete + // window of padding before the data starts. 
+ if (exclusive_) { + ++padding[axis].first; + } + if (reverse_) { + std::swap(padding[axis].first, padding[axis].second); + } + + xla::ComputationDataHandle input = ctx->Input(0); + xla::ComputationDataHandle init; + const xla::Computation* reducer; + if (sum_) { + init = XlaHelpers::Zero(builder, dtype); + reducer = ctx->GetOrCreateAdd(dtype); + } else { + init = XlaHelpers::One(builder, dtype); + reducer = ctx->GetOrCreateMul(dtype); + } + auto output = builder->ReduceWindowWithGeneralPadding( + ctx->Input(0), init, *reducer, window_dims, window_strides, padding); + + // In exclusive mode, we have computed an extra element containing the sum + // of all the input elements. Slice off this extra "last" element. + if (exclusive_) { + if (reverse_) { + output = builder->SliceInDim(output, 1, input_shape.dim_size(axis) + 1, + 1, axis); + + } else { + output = + builder->SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis); + } + } + ctx->SetOutput(0, output); + } + + private: + const bool sum_; // True=cumulative sum. False=cumulative product. + bool reverse_; + bool exclusive_; +}; + +class CumsumOp : public ScanOp { + public: + explicit CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {} +}; +REGISTER_XLA_OP(Name("Cumsum").TypeConstraint("T", kScanOpTypes), CumsumOp); + +class CumprodOp : public ScanOp { + public: + explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {} +}; +REGISTER_XLA_OP(Name("Cumprod").TypeConstraint("T", kScanOpTypes), CumprodOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 24a99f253d6dc8bb699fff587c363b12c227e821..e205fadd2b1bcae96a7bfa1bc83096d405ce22c4 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Shape Ops. 
+#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -27,56 +28,42 @@ namespace { class ShapeOp : public XlaOpKernel { public: - explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); - const int rank = input_shape.dims(); - Tensor shape_constant(DT_INT32, TensorShape({rank})); - auto vec = shape_constant.vec(); - // TODO(dga): support int64. b/28119922. - for (int i = 0; i < rank; ++i) { - int64 dim_size = input_shape.dim_size(i); - OP_REQUIRES( - ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), - errors::InvalidArgument("Shape does not support tensors > int32max", - " but dim ", i, " is ", dim_size)); - vec(i) = static_cast(dim_size); - } - + Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); ctx->SetConstantOutput(0, shape_constant); } + + private: + DataType out_dtype_; }; REGISTER_XLA_OP(Name("Shape"), ShapeOp); class ShapeNOp : public XlaOpKernel { public: - explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { for (int i = 0; i < ctx->num_inputs(); ++i) { - const TensorShape shape = ctx->InputShape(i); - const int dims = shape.dims(); - Tensor shape_constant(DT_INT32, TensorShape({dims})); - auto vec = shape_constant.vec(); - - // TODO(dga): support int64. b/28119922. 
- for (int j = 0; j < dims; ++j) { - int64 dim_size = shape.dim_size(j); - OP_REQUIRES( - ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), - errors::InvalidArgument("Shape does not support tensors > int32max", - " but shape ", i, " dim ", j, " is ", - dim_size)); - vec(j) = static_cast(dim_size); - } - + const TensorShape input_shape = ctx->InputShape(i); + Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); ctx->SetConstantOutput(i, shape_constant); } } bool IsExpensive() override { return false; } + + private: + DataType out_dtype_; }; REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..76ea5f525598f511f295eb5a30f3cf603fbf57aa --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" + +#include + +#include "tensorflow/core/kernels/bounds_check.h" + +namespace tensorflow { + +Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant) { + const int dims = input_shape.dims(); + if (shape_constant->dtype() == DT_INT32) { + auto vec = shape_constant->vec(); + for (int i = 0; i < dims; ++i) { + int64 dim_size = input_shape.dim_size(i); + if (!FastBoundsCheck(dim_size, std::numeric_limits::max())) { + return errors::InvalidArgument( + "Shape with out_type=int32 does not support tensors > int32max", + " but dim ", i, " is ", dim_size); + } + vec(i) = static_cast(dim_size); + } + } else { + auto vec = shape_constant->vec(); + for (int i = 0; i < dims; ++i) { + int64 dim_size = input_shape.dim_size(i); + vec(i) = dim_size; + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.h b/tensorflow/compiler/tf2xla/kernels/shape_util.h new file mode 100644 index 0000000000000000000000000000000000000000..575086e118080f6799a54d3ae6409b2b641c4341 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// Converts a TensorShape to a constant Tensor. +// +// The input TensorShape input_shape is used to populate the elements of +// shape_constant, which is modified in place. +Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 89befda346ec06fec23ab1d1c9d910ded8cd806d..806fda632cde64c1b37ae3b9199028d6b6b0a215 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { namespace { @@ -23,6 +24,16 @@ namespace { class SpaceToDepthOp : public XlaOpKernel { public: explicit SpaceToDepthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + + OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_), + "; expected formats NHWC or NCHW")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -31,34 +42,100 @@ class SpaceToDepthOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_tensor_shape = ctx->InputShape(0); - // The input is presumed to be [batch, height, width, depth] int input_rank = input_tensor_shape.dims(); static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, - errors::InvalidArgument("Input rank should be: ", kRequiredDims, - " instead of: ", input_rank)); + errors::InvalidArgument("Input rank should be ", kRequiredDims, + "; got ", input_rank)); const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle input = ctx->Input(0); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + + std::vector reshaped_shape; + std::vector transpose_order; + std::vector output_shape; + reshaped_shape.reserve(input_rank); + transpose_order.reserve(input_rank); + output_shape.reserve(input_rank); + if 
(data_format_ == FORMAT_NHWC) { + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, + errors::InvalidArgument( + "input shape[", 1 + i, "]=", input_shape[1 + i], + " is not divisible by block_size=", block_size_)); + block_elems *= block_size_; + } + + reshaped_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[1 + i] / block_size_); + reshaped_shape.push_back(block_size_); + } + reshaped_shape.push_back(input_shape[feature_dim]); + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 1); + } + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 2); + } + transpose_order.push_back(feature_dim + num_spatial_dims); + + output_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[1 + i] / block_size_); + } + output_shape.push_back(input_shape[feature_dim] * block_elems); + } else { + // FORMAT_NCHW + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + OP_REQUIRES(ctx, input_shape[2 + i] % block_size_ == 0, + errors::InvalidArgument( + "input shape[", 2 + i, "]=", input_shape[2 + i], + " is not divisible by block_size=", block_size_)); + block_elems *= block_size_; + } + + reshaped_shape.push_back(input_shape[0]); + reshaped_shape.push_back(input_shape[feature_dim]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[2 + i] / block_size_); + reshaped_shape.push_back(block_size_); + } + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 3); + } + transpose_order.push_back(feature_dim); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 2); + } + + output_shape.push_back(input_shape[0]); + 
output_shape.push_back(input_shape[feature_dim] * block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[2 + i] / block_size_); + } + } + + // Note: comments are given in NHWC format; NCHW is similar with a different + // dimension order. // 1. Reshape `input` to `reshaped` of shape: // // [batch, // input_shape[1] / block_size_, block_size_, // input_shape[2] / block_size_, block_size_, // depth] - const int block_rank = 2; - for (int i = 0; i < block_rank; ++i) { - OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, - errors::InvalidArgument( - "input shape[", 1 + i, "]=", input_shape[1 + i], - " is not divisible by block_size=", block_size_)); - } - xla::ComputationDataHandle reshaped = b->Reshape( - input, {input_shape[0], input_shape[1] / block_size_, block_size_, - input_shape[2] / block_size_, block_size_, input_shape[3]}); + xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -69,7 +146,7 @@ class SpaceToDepthOp : public XlaOpKernel { // block_size_, block_size_, // depth] xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, {0, 1, 3, 2, 4, 5}); + b->Transpose(reshaped, transpose_order); // 3. 
Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -79,15 +156,14 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_ * block_size_ * depth] // - xla::ComputationDataHandle output = b->Reshape( - permuted_reshaped, {input_shape[0], input_shape[1] / block_size_, - input_shape[2] / block_size_, - block_size_ * block_size_ * input_shape[3]}); + xla::ComputationDataHandle output = + b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } private: + TensorFormat data_format_; int block_size_; }; REGISTER_XLA_OP(Name("SpaceToDepth"), SpaceToDepthOp); diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..b10880de77e6b9811008076cd4a959c284e558d1 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -0,0 +1,279 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/math/math_util.h" + +namespace tensorflow { +namespace { + +// Rotates a 32-bit integer 'v' left by 'distance' bits. +xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& v, + int distance) { + return builder->Or( + builder->ShiftLeft(v, builder->ConstantR0(distance)), + builder->ShiftRightLogical(v, builder->ConstantR0(32 - distance))); +} + +// TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than +// building XOR out of other bitwise operators. +xla::ComputationDataHandle BitwiseXor(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& y) { + return builder->Or(builder->And(x, builder->Not(y)), + builder->And(builder->Not(x), y)); +} + +using ThreeFry2x32State = std::array; + +// Implements the ThreeFry counter-based PRNG algorithm. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder, + ThreeFry2x32State input, ThreeFry2x32State key) { + // Rotation distances specified by the Threefry2x32 algorithm. + constexpr std::array rotations = {13, 15, 26, 6, 17, 29, 16, 24}; + ThreeFry2x32State x; + + std::array ks; + // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. 
+ ks[2] = builder->ConstantR0(0x1BD11BDA); + for (int i = 0; i < 2; ++i) { + ks[i] = key[i]; + x[i] = input[i]; + ks[2] = BitwiseXor(builder, ks[2], key[i]); + } + + x[0] = builder->Add(x[0], ks[0]); + x[1] = builder->Add(x[1], ks[1]); + + // Performs a single round of the Threefry2x32 algorithm, with a rotation + // amount 'rotation'. + auto round = [builder](ThreeFry2x32State v, int rotation) { + v[0] = builder->Add(v[0], v[1]); + v[1] = RotateLeftS32(builder, v[1], rotation); + v[1] = BitwiseXor(builder, v[0], v[1]); + return v; + }; + + // There are no known statistical flaws with 13 rounds of Threefry2x32. + // We are conservative and use 20 rounds. + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = builder->Add(x[0], ks[1]); + x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0(1)); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = builder->Add(x[0], ks[2]); + x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0(2)); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = builder->Add(x[0], ks[0]); + x[1] = builder->Add(builder->Add(x[1], ks[1]), builder->ConstantR0(3)); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = builder->Add(x[0], ks[1]); + x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0(4)); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = builder->Add(x[0], ks[2]); + x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0(5)); + + return x; +} + +// Returns a tensor of 'shape' random values uniformly distributed in the range +// [minval, maxval) +xla::ComputationDataHandle 
RandomUniform(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& seed, + const TensorShape& shape, + double minval, double maxval) { + // Split the seed into two 32-bit scalars to form a key. + auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {}); + ThreeFry2x32State key = {seed0, seed1}; + const int64 size = shape.num_elements(); + + const int64 half_size = MathUtil::CeilOfRatio(size, 2); + const bool size_is_odd = (half_size * 2 != size); + + // Fill the generator inputs with unique counter values. + ThreeFry2x32State inputs; + TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0])); + inputs[1] = builder->Add(inputs[0], builder->ConstantR0(half_size)); + ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key); + + if (size_is_odd) { + outputs[1] = builder->Slice(outputs[1], {0}, {half_size - 1}, {1}); + } + + auto bits = + builder->Reshape(builder->ConcatInDim(outputs, 0), shape.dim_sizes()); + + // Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit + // forces the random bits into the mantissa. + constexpr int kFloatBits = 32; + constexpr int kMantissaBits = 23; + bits = builder->Or( + builder->ShiftRightLogical( + bits, builder->ConstantR0(kFloatBits - kMantissaBits)), + builder->ConstantR0(bit_cast(1.0f))); + auto floats = builder->BitcastConvertType(bits, xla::F32); + + // We have a floating point number in the range [1.0, 2.0). + // Subtract 1.0f to shift to the range [0.0, 1.0) + floats = builder->Sub(floats, builder->ConstantR0(1.0f)); + // Multiply and add to shift to the range [minval, maxval). + floats = builder->Mul(floats, builder->ConstantR0(maxval - minval)); + floats = builder->Add(floats, builder->ConstantR0(minval)); + return floats; +} + +// Approximation for the inverse error function from +// Giles, M., "Approximating the erfinv function". 
+// The approximation has the form: +// w = -log((1 - x) * (1 + x)) +// if ( w < 5 ) { +// w = w - 2.5 +// p = sum_{i=1}^n lq[i]*w^i +// } else { +// w = sqrt(w) - 3 +// p = sum_{i=1}^n gq[i]*w^i +// } +// return p*x +xla::ComputationDataHandle ErfInvF32(xla::ComputationBuilder* b, + const xla::ComputationDataHandle& x, + const TensorShape& shape) { + constexpr int kDegree = 9; + constexpr std::array w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + + auto one = b->ConstantR0(1.0); + auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x)))); + + auto lt = b->Lt(w, b->ConstantR0(5.0)); + auto coefficient = [&](int i) { + return b->Select( + lt, + b->Broadcast(b->ConstantR0(w_less_than_5_constants[i]), + shape.dim_sizes()), + b->Broadcast(b->ConstantR0(w_greater_than_5_constants[i]), + shape.dim_sizes())); + }; + w = b->Select(lt, b->Sub(w, b->ConstantR0(2.5f)), + b->Sub(b->SqrtF32(w), b->ConstantR0(3.0f))); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = b->Add(coefficient(i), b->Mul(p, w)); + } + return b->Mul(p, x); +} + +} // namespace + +class StatelessRandomUniformOp : public XlaOpKernel { + public: + explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + TensorShape seed_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2, + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + 
xla::ComputationDataHandle seed = ctx->Input(1); + ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp); +}; + +// TODO(phawkins): generalize to non-float, non-int32 seed types. +REGISTER_XLA_OP(Name("StatelessRandomUniform") + .TypeConstraint("dtype", DT_FLOAT) + .TypeConstraint("Tseed", DT_INT32), + StatelessRandomUniformOp); + +class StatelessRandomNormalOp : public XlaOpKernel { + public: + explicit StatelessRandomNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + TensorShape seed_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, seed_shape == TensorShape({2}), + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + xla::ComputationDataHandle seed = ctx->Input(1); + xla::ComputationBuilder* builder = ctx->builder(); + auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0); + // Convert uniform distribution to normal distribution by computing + // sqrt(2) * erfinv(x) + auto normal = builder->Mul(builder->ConstantR0(std::sqrt(2.0)), + ErfInvF32(builder, uniform, shape)); + ctx->SetOutput(0, normal); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp); +}; + +// TODO(phawkins): generalize to non-float, non-int32 seed types. 
+REGISTER_XLA_OP(Name("StatelessRandomNormal") + .TypeConstraint("dtype", DT_FLOAT) + .TypeConstraint("Tseed", DT_INT32), + StatelessRandomNormalOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 351fda251798e43b607fb445f2c98abd57b3d86b..03c22354a9425189e6cf7ee5a7201c90ecb1908d 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -311,6 +311,32 @@ class TensorArrayGatherOp : public XlaOpKernel { xla::ComputationDataHandle ta = resource->value; + // Look for the case where the gather takes a simple slice from the + // tensor array (0, 1, 2, 3, 4, ..., N) + std::vector const_indices; + Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + if (status.ok()) { + bool gather_is_dense_slice = true; + for (auto i = 0; i < const_indices.size(); i++) { + if (const_indices[i] != i) { + gather_is_dense_slice = false; + break; + } + } + + if (gather_is_dense_slice) { + std::vector begin(ta_shape.dims(), 0); + std::vector strides(ta_shape.dims(), 1); + std::vector end(ta_shape.dims(), 1); + end[0] = const_indices.size(); + for (auto i = 1; i < ta_shape.dims(); i++) { + end[i] = ta_shape.dim_size(i); + } + ctx->SetOutput(0, b->Slice(ta, begin, end, strides)); + return; + } + } + xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b); ctx->SetOutput(0, gather); @@ -352,28 +378,47 @@ class TensorArrayScatterOp : public XlaOpKernel { const xla::ComputationDataHandle value = ctx->Input(2); const xla::ComputationDataHandle flow = ctx->Input(3); - auto slice_dims = value_shape.dim_sizes(); - slice_dims[0] = 1LL; - - std::vector value_starts(value_shape.dims(), 0); - auto value_ends = value_shape.dim_sizes(); - - std::vector value_strides(value_shape.dims(), 1); - - // For every (index, value) pair, 
update the corresponding TensorArray - // storage. - for (int i = 0; i < num_indices; ++i) { - // Slice out part of the value. - value_starts[0] = i; - value_ends[0] = i + 1; - auto slice = b->Slice(value, value_starts, value_ends, value_strides); + // Look for the case where the scatter is for each sub-tensor in order. The + // tensor array implementation allows for this to be a straight addition. + bool scatter_all_elements_in_order = false; + std::vector const_indices; + Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + if (status.ok() && num_indices == value_shape.dim_size(0)) { + scatter_all_elements_in_order = true; + for (auto i = 0; i < num_indices; i++) { + if (const_indices[i] != i) { + scatter_all_elements_in_order = false; + break; + } + } + } - // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto index = b->Slice(indices, {i}, {i + 1}, {1}); - auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); - ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + if (scatter_all_elements_in_order) { + ta = b->Add(ta, value); + } else { + auto slice_dims = value_shape.dim_sizes(); + slice_dims[0] = 1LL; + + std::vector value_starts(value_shape.dims(), 0); + auto value_ends = value_shape.dim_sizes(); + + std::vector value_strides(value_shape.dims(), 1); + + // For every (index, value) pair, update the corresponding TensorArray + // storage. + for (int i = 0; i < num_indices; ++i) { + // Slice out part of the value. + value_starts[0] = i; + value_ends[0] = i + 1; + auto slice = b->Slice(value, value_starts, value_ends, value_strides); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 
+ auto index = b->Slice(indices, {i}, {i + 1}, {1}); + auto start_indices = + b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + } } resource->value = ta; diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index b19ea22f50d2dd44e8d1d81f5930263f364030e1..68847ae7a2cb926edd9d29007e24b0db7fb5a75f 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -22,6 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/no_op.h" namespace tensorflow { @@ -121,5 +123,26 @@ class ResourceGatherOp : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes), ResourceGatherOp); +class VariableShapeOp : public XlaOpKernel { + public: + explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType variable_dtype; + TensorShape shape; + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); + Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); + ctx->SetConstantOutput(0, shape_constant); + } + + private: + DataType out_dtype_; +}; + +REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..21ad21f73737a289390ed1ea767db1078d05b466 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -0,0 +1,120 @@ +# Utilities for building XLA computations. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = ["//tensorflow/compiler/tf2xla:friends"], +) + +# Filegroup used to collect source files for dependency checking. 
+filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") + +cc_library( + name = "batch_dot", + srcs = ["batch_dot.cc"], + hdrs = ["batch_dot.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "cholesky", + srcs = ["cholesky.cc"], + hdrs = ["cholesky.h"], + deps = [ + ":batch_dot", + ":triangular_solve", + ":util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "triangular_solve", + srcs = ["triangular_solve.cc"], + hdrs = ["triangular_solve.h"], + deps = [ + ":batch_dot", + ":util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "triangular_solve_test", + srcs = ["triangular_solve_test.cc"], + deps = [ + ":triangular_solve", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + 
"//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc new file mode 100644 index 0000000000000000000000000000000000000000..9b0e6174475c22e325c090bec5f1d56822e106bc --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" + +#include +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +// The current implementation simply unrolls the computation along the batch +// dimension. +xla::StatusOr BatchDot( + xla::ComputationBuilder* builder, xla::ComputationDataHandle x, + xla::ComputationDataHandle y, bool transpose_x, bool transpose_y) { + TF_ASSIGN_OR_RETURN(std::unique_ptr x_shape, + builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(std::unique_ptr y_shape, + builder->GetShape(y)); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) { + return errors::InvalidArgument( + "Arguments to BatchedDot have different ranks: ", + xla::ShapeUtil::HumanString(*x_shape), " vs. ", + xla::ShapeUtil::HumanString(*y_shape)); + } + const int ndims = xla::ShapeUtil::Rank(*x_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to BatchedDot must have rank >= 2: ", ndims); + } + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector batch_dimension_numbers; + for (int i = 0; i < ndims - 2; ++i) { + if (x_shape->dimensions(i) != y_shape->dimensions(i)) { + return errors::InvalidArgument( + "Dimension ", i, " of inputs to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(*x_shape), " vs ", + xla::ShapeUtil::HumanString(*y_shape)); + } + batch_dimension_numbers.push_back(i); + } + + int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); + int y_inner_dim = transpose_y ? 
(ndims - 1) : (ndims - 2); + if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) { + return errors::InvalidArgument( + "Dimensions ", x_inner_dim, " and ", y_inner_dim, + " of arguments to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x, + " vs. ", xla::ShapeUtil::HumanString(*y_shape), + " transpose: ", transpose_y); + } + + // Check for zero lhs/rhs dim size. + if (xla::ShapeUtil::HasZeroElements(*x_shape) || + xla::ShapeUtil::HasZeroElements(*y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); + int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); + dimensions.push_back(x_shape->dimensions(x_outer_dim)); + dimensions.push_back(y_shape->dimensions(y_outer_dim)); + return builder->Broadcast( + builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())), + dimensions); + } + + if (x_shape->element_type() == xla::C64 && transpose_x) { + x = builder->Conj(x); + } + if (y_shape->element_type() == xla::C64 && transpose_y) { + y = builder->Conj(y); + } + + // If there are no batch dimensions, use a regular Dot. + // TODO(b/69062148) Remove this code when Dot emitters can be passed + // dimensions to transpose directly (i.e. without requiring a Transpose HLO). + if (batch_dimension_numbers.empty()) { + auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; + auto rhs = transpose_y ? 
builder->Transpose(y, {1, 0}) : y; + return builder->Dot(lhs, rhs); + } + + xla::DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); + } + return builder->DotGeneral(x, y, dot_dnums); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h new file mode 100644 index 0000000000000000000000000000000000000000..b46bc7417d29dc5b7e9649ac28cc78b57d4b619c --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" + +namespace tensorflow { + +// Multiplies slices of two tensors in batches. + +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. 
Each of the +// individual slices can optionally be transposed before multiplication by +// setting the `transpose_x` or `transpose_y` flag to `true`. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. +// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = c_x if transpose_x else r_x +// c_o = r_y if transpose_y else c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +// TODO(phawkins): add an option to take the complex conjugate of the LHS or +// RHS. +xla::StatusOr BatchDot( + xla::ComputationBuilder* builder, xla::ComputationDataHandle x, + xla::ComputationDataHandle y, bool transpose_x, bool transpose_y); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc new file mode 100644 index 0000000000000000000000000000000000000000..b3cc489adf6042acb3f56b3a0a6c8fbe43bde629 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -0,0 +1,166 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/cholesky.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +namespace { + +// def cholesky_unblocked(a): +// assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1] +// n = a.shape[-2] +// l = np.zeros_like(a) +// for j in xrange(n): +// r = l[..., j, :j] +// l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(r, r)) +// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], +// np.transpose(r))) / l[..., j, j] +// return l +xla::StatusOr CholeskyUnblocked( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a) { + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(a)); + xla::ComputationDataHandle l = Zeros(builder, *shape); + const int64 n = xla::ShapeUtil::GetDimension(*shape, -2); + for (int j = 0; j < n; ++j) { + // Picture of block structure: + // ... \ + // \ + // -- r -- d + // |\ + // B c \ + // | \ + // | ... 
+ // + // ^ + // column j + TF_ASSIGN_OR_RETURN(auto d, + SliceInMinorDims(builder, a, {j, j}, {j + 1, j + 1})); + TF_ASSIGN_OR_RETURN(auto c, + SliceInMinorDims(builder, a, {j + 1, j}, {n, j + 1})); + xla::ComputationDataHandle new_d_squared = d; + xla::ComputationDataHandle br; + if (j > 0) { + TF_ASSIGN_OR_RETURN(auto r, + SliceInMinorDims(builder, l, {j, 0}, {j + 1, j})); + TF_ASSIGN_OR_RETURN(auto b, + SliceInMinorDims(builder, l, {j + 1, 0}, {n, j})); + TF_ASSIGN_OR_RETURN(auto r_squared, + BatchDot(builder, r, r, /*transpose_x=*/false, + /*transpose_y=*/true)); + new_d_squared = builder->Sub(new_d_squared, r_squared); + + TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false, + /*transpose_y=*/true)); + } + auto new_d_inv = builder->Pow( + new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5)); + auto new_d = builder->Mul(new_d_inv, new_d_squared); + TF_ASSIGN_OR_RETURN(l, UpdateSliceInMinorDims(builder, l, new_d, {j, j})); + + if (j > 0) { + c = builder->Sub(c, br); + } + auto new_c = builder->Mul(c, new_d_inv); + TF_ASSIGN_OR_RETURN(l, + UpdateSliceInMinorDims(builder, l, new_c, {j + 1, j})); + } + return l; +} + +} // namespace + +xla::StatusOr Cholesky( + xla::ComputationBuilder* builder, xla::ComputationDataHandle a, + int64 block_size) { + TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, + builder->GetShape(a)); + const int ndims = xla::ShapeUtil::Rank(*a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to Cholesky must have rank >= 2: ", ndims); + } + + const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); + if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) { + return errors::InvalidArgument( + "Arguments to Cholesky must be square matrices: ", + xla::ShapeUtil::HumanString(*a_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to Cholesky must be >= 1; got ", block_size); + } + + // Blocked left-looking Cholesky factorization. 
+ // Algorithm 1 from + // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only + // execution." Proceedings of General Purpose GPUs. ACM, 2017. + xla::ComputationDataHandle l = Zeros(builder, *a_shape); + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + if (i > 0) { + // TODO(phawkins): consider implementing SYRK for the diagonal part of + // the panel. + // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) + TF_ASSIGN_OR_RETURN(auto lhs, + SliceInMinorDims(builder, l, {i, 0}, {n, i})); + TF_ASSIGN_OR_RETURN(auto rhs, + SliceInMinorDims(builder, l, {i, 0}, {i + k, i})); + TF_ASSIGN_OR_RETURN(auto delta, + BatchDot(builder, lhs, rhs, /*transpose_x=*/false, + /*transpose_y=*/true)); + TF_ASSIGN_OR_RETURN(auto before, + SliceInMinorDims(builder, a, {i, i}, {n, i + k})); + TF_ASSIGN_OR_RETURN( + a, UpdateSliceInMinorDims(builder, a, builder->Sub(before, delta), + {i, i})); + } + + // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) + TF_ASSIGN_OR_RETURN(auto x, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto factorized, CholeskyUnblocked(builder, x)); + TF_ASSIGN_OR_RETURN(l, + UpdateSliceInMinorDims(builder, l, factorized, {i, i})); + + if (i + k < n) { + // l[i+k:, i:i+k] = trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) + TF_ASSIGN_OR_RETURN(auto panel, + SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); + TF_ASSIGN_OR_RETURN(auto update, + TriangularSolve(builder, factorized, panel, + /*block_size=*/8)); + TF_ASSIGN_OR_RETURN( + l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); + } + } + return l; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h new file mode 100644 index 0000000000000000000000000000000000000000..2bead7359baaf3582c1230adf0cd4a90046859d2 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -0,0 +1,38 @@ +/* Copyright 2017 
The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" + +namespace tensorflow { + +// Computes the Cholesky decompositions of a batch of symmetric positive +// definite matrices. +// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the +// two minor dimensions equal. +// The algorithm implements a blocked Cholesky decomposition; `block_size` is +// the block size to use. +// TODO(phawkins): check for negative values on the diagonal and return an +// error, instead of silently yielding NaNs. +xla::StatusOr Cholesky( + xla::ComputationBuilder* builder, xla::ComputationDataHandle a, + int64 block_size = 256); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc new file mode 100644 index 0000000000000000000000000000000000000000..579944c3a381e7018b7fee5013d0509158ce21cc --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -0,0 +1,175 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +xla::StatusOr TriangularSolve( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, + xla::ComputationDataHandle b, int64 block_size) { + TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, + builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, + builder->GetShape(b)); + if (xla::ShapeUtil::Rank(*a_shape) != xla::ShapeUtil::Rank(*b_shape)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have different ranks: ", + xla::ShapeUtil::HumanString(*a_shape), " vs. ", + xla::ShapeUtil::HumanString(*b_shape)); + } + const int ndims = xla::ShapeUtil::Rank(*a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to TriangularSolve must have rank >= 2: ", ndims); + } + // The batch dimensions must be equal. 
+ std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape->dimensions(i); + int64 b_size = b_shape->dimensions(i); + if (a_size != b_size) { + return errors::InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal: ", + xla::ShapeUtil::HumanString(*a_shape), " vs ", + xla::ShapeUtil::HumanString(*b_shape)); + } + batch_dimensions.push_back(a_size); + } + + const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); + const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); + if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) { + return errors::InvalidArgument( + "The 'a' arguments to TriangularSolve must be square matrices: ", + xla::ShapeUtil::HumanString(*a_shape)); + } + if (n != xla::ShapeUtil::GetDimension(*b_shape, -1)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes: ", + xla::ShapeUtil::HumanString(*a_shape), " vs ", + xla::ShapeUtil::HumanString(*b_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got ", + block_size); + } + + // Returns [b1, b2, ... , bn, indices[0], indices[1]]. 
+ auto prepend_batch_dims = [&](std::array indices) { + std::vector output(ndims); + std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin()); + std::copy(indices.begin(), indices.end(), + output.begin() + batch_dimensions.size()); + return output; + }; + + std::map base_computations; + auto get_base_triangular_solve = + [&](int k) -> xla::StatusOr { + xla::Computation& computation = base_computations[k]; + if (computation.IsNull()) { + std::unique_ptr sub = builder->CreateSubBuilder( + tensorflow::strings::StrCat("trsm_base_", k)); + + auto a_param = + sub->Parameter(0, + xla::ShapeUtil::MakeShape(b_shape->element_type(), + prepend_batch_dims({k, k})), + "a"); + + auto b_param = + sub->Parameter(1, + xla::ShapeUtil::MakeShape(b_shape->element_type(), + prepend_batch_dims({m, k})), + "b"); + + // TODO(phawkins): it might make sense to use a while loop here, rather + // than unrolling. + // TODO(phawkins): the left-looking variant of the algorithm might be more + // efficient at block size 1. + TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, + /*block_size=*/1) + .status()); + + TF_ASSIGN_OR_RETURN(computation, sub->Build()); + } + return &computation; + }; + + xla::ComputationDataHandle output = Zeros(builder, *b_shape); + + // Right-looking blocked triangular solve. + // For an explanation of the algorithm, see the TRSM discussion in: + // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation + // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 + // (2008): 4. 
+ for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + + // if k > 1: + // output[..., :, i:i+k] = triangular_solve( + // a[..., i:i+k, ..., i:i+k], b[..., :, i:i+k], side='Right', + // kind='Lower', transpose=True, block_size=1) + // else: + // output[..., :, i] = b[..., :, i] / a[..., i, i] + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {0, i}, {m, i + k})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, a_slice); + } + + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {0, i})); + // b[..., :, i+k:] -= np.dot(output[..., :, i:i+k], + // np.transpose(..., a[i+k:, i:i+k])) + if (i + k < n) { + TF_ASSIGN_OR_RETURN(auto a_slice_2, + SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/true)); + + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {0, i + k}, {m, n})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); + } + } + return output; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h new file mode 100644 index 0000000000000000000000000000000000000000..501d026411c80359c7efa406ece5929a2e46ac1f --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" + +namespace tensorflow { + +// Solves systems of linear equations with upper or lower triangular matrices by +// backsubstitution. +// +// `a` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form +// square matrices. The strictly upper triangular part of each inner-most matrix +// is assumed to be zero and not accessed. +// `b` is a tensor of shape `[..., M, K]`. +// +// The innermost matrices in the output satisfy matrix equations +// `output[..., i, j] * adjoint(a[..., k, j]) = b[..., i, k]`. +// +// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no +// blocking is used. +// TODO(phawkins): equivalent to the BLAS TRSM routine with side=right, +// kind=lower, and transposed_a=true. Implement the other possible combinations +// of side, kind and transposed_a. 
+xla::StatusOr TriangularSolve( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, + xla::ComputationDataHandle b, int64 block_size = 256); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..671d9aa4fe0c042a3cc44468074653d51c2be75d --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +using TriangularSolveTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(TriangularSolveTest, Simple) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::Array2D a_vals({ + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }); + xla::Array2D b_vals({ + {1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10, 11, 12}, + }); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(b_vals, 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 0.08333334, 0.04629629, 0.03367003}, + {2.5, -0.25, -0.1388889, -0.1010101}, + {4.5, -0.58333331, -0.32407406, -0.23569024}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(2e-3, 2e-3)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc new file mode 100644 index 0000000000000000000000000000000000000000..943248aedbdce5e81baa341fdab82fea9a48302d --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -0,0 +1,110 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/util.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, + xla::Shape& shape) { + return builder->Broadcast( + builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), + xla::AsInt64Slice(shape.dimensions())); +} + +xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, + xla::PrimitiveType type, double value) { + switch (type) { + case xla::F16: + return builder->ConstantR0(static_cast(value)); + break; + case xla::BF16: + return builder->ConstantR0(static_cast(value)); + break; + case xla::F32: + return builder->ConstantR0(static_cast(value)); + break; + case xla::F64: + return builder->ConstantR0(value); + break; + case xla::C64: + return builder->ConstantR0(value); + break; + default: + LOG(FATAL) << "unhandled element type " << type; + } +} + +xla::StatusOr SliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + gtl::ArraySlice start, gtl::ArraySlice end) { + TF_RET_CHECK(start.size() == 
end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape->dimensions()), + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. + std::vector padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector strides(n_dims, 1); + return builder->Slice(x, padded_start, padded_end, strides); +} + +xla::StatusOr UpdateSlice( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, gtl::ArraySlice start) { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. 
+ std::vector start_as_int32(start.begin(), start.end()); + return builder->DynamicUpdateSlice( + x, update, builder->ConstantR1(start_as_int32)); +} + +xla::StatusOr UpdateSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, gtl::ArraySlice start) { + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(builder, x, update, padded_start); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h new file mode 100644 index 0000000000000000000000000000000000000000..8fba6b5cf247e9b2c26533c53ece8b0d7d4f4c36 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Returns a zero-filled tensor with shape `shape`. +xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, + xla::Shape& shape); + +// Returns a floating point scalar constant of 'type' with 'value'. +// If 'type' is complex, returns a real value with zero imaginary component. +xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, + xla::PrimitiveType type, double value); + +// Performs a slice in the minor dimensions of a Tensor. +xla::StatusOr SliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + gtl::ArraySlice start, gtl::ArraySlice end); + +// Updates a slice of 'x', i.e., +// x[start[0], ..., start[n]] = update +xla::StatusOr UpdateSlice( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, gtl::ArraySlice start); + +// Updates a slice of 'x', where 'start' contains a list of minor dimensions: +// x[..., start[0], ..., start[n]] = update +xla::StatusOr UpdateSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, gtl::ArraySlice start); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index d9c839b61019b92b6de3a77a7bec610ae848a9a4..b08a7583cb5ab7efa30a1fa27b973d04992584a7 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -14,34 +14,59 
@@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { +namespace { +const char kDeviceSuffixReplicatedCore[] = "REPLICATED_CORE"; +const char kShardingAttribute[] = "_XlaSharding"; +} // namespace -static const char DEVICE_SUFFIX_REPLICATED_CORE[] = "REPLICATED_CORE"; +namespace { +xla::StatusOr> +GetShardingFromNodeDef(const NodeDef& node_def) { + if (!HasNodeAttr(node_def, kShardingAttribute)) { + return tensorflow::gtl::optional(); + } + string value; + xla::OpSharding sharding; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value)); + if (!sharding.ParseFromString(value)) { + return xla::InvalidArgument( + "Experimental _XlaSharding attribute was not a valid encoded " + "xla::OpSharding proto."); + } + return tensorflow::gtl::optional(sharding); +} -static Status CoreOutOfRangeError(int core, int num_cores_per_replica) { +Status CoreOutOfRangeError(int core, int num_cores_per_replica) { return errors::InvalidArgument( "Invalid replicated core id: ", core, "; num_cores_per_replica=", num_cores_per_replica); } +} // namespace xla::StatusOr> -ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) { +ParseShardingFromDevice( + const string& device_name, int num_cores_per_replica, + tensorflow::gtl::optional explicit_sharding) { if (device_name.empty()) { return tensorflow::gtl::optional(); } - DeviceNameUtils::ParsedName parsed_device; if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) { return errors::InvalidArgument("Malformed assigned device '", device_name, "'"); } - if (!parsed_device.has_type || - !StringPiece(parsed_device.type) - .ends_with(DEVICE_SUFFIX_REPLICATED_CORE)) { 
+ + if (explicit_sharding.has_value()) { + return explicit_sharding; + } else if (!parsed_device.has_type || !parsed_device.has_id || + !StringPiece(parsed_device.type) + .contains(kDeviceSuffixReplicatedCore)) { return tensorflow::gtl::optional(); } else { const int core = parsed_device.id; @@ -53,20 +78,34 @@ ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) { } } +xla::StatusOr> +ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica) { + const string& device_name = node_def.device(); + TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + GetShardingFromNodeDef(node_def)); + return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); +} + xla::StatusOr> ParseShardingFromDevice(const Node& node, int num_cores_per_replica) { string device_name = node.assigned_device_name(); if (device_name.empty()) { device_name = node.requested_device(); } - return ParseShardingFromDevice(device_name, num_cores_per_replica); + TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + GetShardingFromNodeDef(node.def())); + return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); } + void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) { string device_name = src.assigned_device_name(); if (device_name.empty()) { device_name = src.requested_device(); } dst->set_assigned_device_name(device_name); + if (const AttrValue* attr = src.attrs().Find(kShardingAttribute)) { + dst->AddAttr(kShardingAttribute, *attr); + } } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h index f6468bba9f950fec88dcc6b3ec760f014d3a0ef3..9e430e30a1247c7d01910b6d57f7c577964e1dd1 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.h +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -29,14 +29,21 @@ namespace tensorflow { // - if the device name is invalid. 
// - the core is parsed and is out of the range [0, num_cores_per_replica). // -// Otherwise, returns either a non-value or a sharding set as per -// xla:ShardingBuilder::AssignDevice. +// Otherwise, returns either: +// - explicit_sharding if explicit_sharding.has_value() +// - a non-value if there is no assigned core or +// - a sharding set as per xla::ShardingBuilder::AssignDevice. xla::StatusOr> -ParseShardingFromDevice(const string& device_name, int num_cores_per_replica); +ParseShardingFromDevice(const string& device_name, int num_cores_per_replica, + tensorflow::gtl::optional + explicit_sharding = tensorflow::gtl::nullopt); xla::StatusOr> ParseShardingFromDevice(const Node& node, int num_cores_per_replica); +xla::StatusOr> +ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica); + void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index a14c93a2b9494b89f579bc20ee0510c136f8f01b..906f2290433face4cce3296b2f815d50d8c496ce 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -253,8 +253,7 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. 
Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context) { + xla::Computation* computation) { XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { node->set_assigned_device_name( @@ -277,7 +276,6 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(), "tfcompile", std::move(graph), xla_args, &result)); - *requires_runtime_context = result.requires_runtime_context; *computation = std::move(*result.computation); int num_const_results = 0; @@ -352,12 +350,10 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context) { + xla::Computation* computation) { std::unique_ptr graph; TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); - TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation, - requires_runtime_context)); + TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index ab99beebf7946237425d4d304a858ac6817177b8..473c431b12d441c652f1d0d6c11c5e87836ab36d 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -30,13 +30,9 @@ namespace tensorflow { // // The computation is built in the context of the given `client`, which may // subsequently be used to compile or execute the computation. -// -// If `requires_runtime_context` is filled with true, this indicates the last -// argument of the computation is XlaLocalRuntimeContext*. 
Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context); + xla::Computation* computation); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..7aca889a266439538c4cd1c153460e6cc871b246 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace tf2xla { +namespace { + +void PrintSupportedOps(const string& device, const string& regen_run) { + XlaOpRegistry::RegisterCompilationKernels(); + + std::vector kdefs = + XlaOpRegistry::DeviceKernels(device, + /*include_compilation_only_kernels=*/true); + std::sort( + kdefs.begin(), kdefs.end(), + [](const KernelDef* a, const KernelDef* b) { return a->op() < b->op(); }); + + std::cout << "**Supported operators for device: " << device << "**\n\n" + << "Operator | Type Constraint\n" + << "-------- | ---------------" << std::endl; + for (const KernelDef* kdef : kdefs) { + std::vector constraints; + for (const KernelDef::AttrConstraint& constraint : kdef->constraint()) { + std::vector types; + for (int type : constraint.allowed_values().list().type()) { + types.push_back(DataTypeString(static_cast(type))); + } + std::sort(types.begin(), types.end()); + constraints.push_back("`" + constraint.name() + "={" + + str_util::Join(types, ",") + "}`"); + } + std::cout << "`" << kdef->op() << "` | " + << str_util::Join(constraints, "
") << std::endl; + } + + std::cout << "\nTo regenerate this table, run:\n\n```shell\n" + << regen_run << " --device=" << device << "\n```" << std::endl; +} + +} // namespace + +void SupportedOpsMain(int argc, char** argv, const char* regen_run) { + std::vector device_names = XlaOpRegistry::BackendNames(); + std::sort(device_names.begin(), device_names.end()); + + // Set up and parse flags. + string device; + std::vector flag_list = { + {"device", &device, + "Name of the compilation device for which to print supported ops, " + "one of: " + + str_util::Join(device_names, ",")}, + }; + string usage = Flags::Usage(argv[0], flag_list); + bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list); + QCHECK(parsed_flags_ok) << "\n" << usage; + QCHECK(XlaOpRegistry::IsBackendRegistered(device)) + << "\nUnknown device: " << device << "\n" + << usage; + + // Run the program. + port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(argc == 1) << "\nERROR: This command does not take any arguments " + "other than flags\n\n" + << usage; + PrintSupportedOps(device, regen_run); +} + +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..1b45fb4cdd3b0173b04e130b7416874a9a406dc5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ + +namespace tensorflow { +namespace tf2xla { + +// The implementation of a main function for a binary that prints a table of +// supported tf2xla operators for a given device, along with their type +// constraints, to stdout. +// +// Pass the argc and argv from main, unmodified. Use regen_run to specify the +// command used to regenerate the table. +void SupportedOpsMain(int argc, char** argv, const char* regen_run); + +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..690666c2400d45e33c1a5d1818b68a86a70a5be3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc @@ -0,0 +1,22 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h" + +int main(int argc, char** argv) { + const char* regen_run = + "bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops"; + tensorflow::tf2xla::SupportedOpsMain(argc, argv, regen_run); +} diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index ecd15652fe84b0c19d2f7fc18f877236547f9be9..a9978e697b091715ce120f0d18fdddd259e08b32 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -70,10 +70,7 @@ TEST(ConvertGraphDefToXla, Sum) { xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); xla::Computation computation; - bool requires_runtime_context; - TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation, - &requires_runtime_context)); - ASSERT_FALSE(requires_runtime_context); + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. 
auto x_literal = xla::Literal::CreateR0(10); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 55f2f3149c6ba7bfa18608f961c8a76103a50756..f428a194328935fec1210ea96245344de859e611 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -88,8 +88,8 @@ Status ValidateConfig(const tf2xla::Config& config) { TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names)); } TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names)); - if (config.feed().empty() || config.fetch().empty()) { - return errors::InvalidArgument("feeds and fetches must be specified"); + if (config.fetch().empty()) { + return errors::InvalidArgument("fetches must be specified"); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 436039e154842443f779aba276bc571fc2ab7537..ed10d80609641b090cf78bf2e17364fe2fa89c31 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -58,24 +58,14 @@ TEST(ValidateConfig, Good) { TEST(ValidateConfig, BadEmpty) { tf2xla::Config config; - ExpectErrorContains(ValidateConfig(config), - "feeds and fetches must be specified"); -} - -TEST(ValidateConfig, BadNoFeed) { - tf2xla::Config config; - tf2xla::Fetch* fetch = config.add_fetch(); - fetch->mutable_id()->set_node_name("foo"); - ExpectErrorContains(ValidateConfig(config), - "feeds and fetches must be specified"); + ExpectErrorContains(ValidateConfig(config), "fetches must be specified"); } TEST(ValidateConfig, BadNoFetch) { tf2xla::Config config; tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("foo"); - ExpectErrorContains(ValidateConfig(config), - "feeds and fetches must be specified"); + ExpectErrorContains(ValidateConfig(config), "fetches must be specified"); } TEST(ValidateConfig, BadFeedNodeName) { diff --git 
a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index 1efbe0ffb17dad5332aa700b2e255d4a99fbef72..c969212a1bfaa6cab0d896ee074cfd4e2b283ae4 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -49,6 +49,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_UINT64: *type = xla::U64; return Status::OK(); + case tensorflow::DT_BFLOAT16: + *type = xla::BF16; + return Status::OK(); case tensorflow::DT_HALF: *type = xla::F16; return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 4f32c29954b2d809d31ef8c584b6a6c3dcdf5cef..cc459dc87c00f19230c65341d53da213e07fe364 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -100,7 +100,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, b->SetOpMetadata(metadata); auto sharding_parse_result = ParseShardingFromDevice( - op_kernel->requested_device(), std::numeric_limits::max()); + op_kernel->def(), std::numeric_limits::max()); OP_REQUIRES_OK(context, sharding_parse_result.status()); tensorflow::gtl::optional op_sharding = sharding_parse_result.ValueOrDie(); diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index b5c17c5273bb15e20184b2fefd93880d4828105e..79da701fd244a461a60588153b601d5c1870fa89 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -28,9 +28,10 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, temps_(new void*[static_data.num_temps]), arg_names_(static_data.arg_names), result_names_(static_data.result_names), - program_shape_(static_data.program_shape) { + program_shape_(static_data.program_shape), + 
hlo_profile_printer_(static_data.hlo_profile_printer) { // Allocate arg and temp buffers. - if (alloc_mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { + if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( static_data.arg_sizes, static_data.num_args, args_, /*annotate_initialized=*/false); @@ -39,9 +40,13 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, static_data.temp_sizes, static_data.num_temps, temps_, /*annotate_initialized=*/true); - // The runtime context is always the last arg, if it is required. - if (static_data.requires_runtime_context) { - args_[static_data.num_args - 1] = &context_; + // If Hlo profiling is enabled the generated code expects an appropriately + // sized buffer to be passed in as the last argument. If Hlo profiling is + // disabled the last function argument is still present in the function + // signature, but it is ignored by the generated code and we pass in null for + // it. + if (hlo_profiling_enabled()) { + profile_counters_ = new int64[static_data.profile_counters_size](); } } @@ -50,6 +55,7 @@ XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); delete[] args_; delete[] temps_; + delete[] profile_counters_; } namespace { diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index f49a7889222ff989144217ab10b27595f89e4311..e0ae3ed9a811bcc49ce8862037a67d293e879e57 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -16,10 +16,9 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ -#include +#include #include -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/core/platform/types.h" @@ -27,6 +26,7 @@ limitations under the License. // never use this functionality. namespace xla { class ProgramShape; +class HloProfilePrinter; } namespace tensorflow { @@ -48,12 +48,10 @@ namespace tensorflow { class XlaCompiledCpuFunction { public: // Type of the raw function, produced by either JIT or AOT. - // - // TODO(toddw): Add support for hlo profiling, and replace std::function with - // a raw function pointer, for some codesize savings. - using RawFunction = std::function; + using RawFunction = void (*)(void* result, + const xla::ExecutableRunOptions* run_options, + const void** args, void** temps, + int64* profile_counters); // StaticData represents the state necessary to run an XLA-compiled // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for @@ -71,9 +69,6 @@ class XlaCompiledCpuFunction { // The 0-based index of the result tuple, in the temp buffers. size_t result_index = 0; - // Is the final arg XlaLocalRuntimeContext? - bool requires_runtime_context = false; - // [Optional] Arrays of arg and result names. These are arrays of C-style // strings, where the array is terminated by nullptr. const char** arg_names = nullptr; @@ -81,21 +76,29 @@ class XlaCompiledCpuFunction { // [Optional] Arg and result shapes. const xla::ProgramShape* program_shape = nullptr; + + // [Optional] Profile printer. Null if profiling is disabled. + const xla::HloProfilePrinter* hlo_profile_printer = nullptr; + + // [Optional] The number of profile counters expected in the profile counter + // buffer by the generated code and hlo_profile_printer. 0 if profiling is + // disabled. 
+ int64 profile_counters_size = 0; }; // AllocMode controls the buffer allocation mode. enum class AllocMode { - // Allocate all buffers - args, results and temps. - ARGS_RESULTS_AND_TEMPS, + // Allocate all buffers - args, results, profile and temps. + ARGS_RESULTS_PROFILES_AND_TEMPS, - // Only allocate result and temp buffers. + // Only allocate result, profile and temp buffers. // Use set_arg_data to set argument buffers before Run is called. - RESULTS_AND_TEMPS_ONLY, + RESULTS_PROFILES_AND_TEMPS_ONLY, }; XlaCompiledCpuFunction( const StaticData& static_data, - AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS); + AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS); virtual ~XlaCompiledCpuFunction(); XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; @@ -104,21 +107,22 @@ class XlaCompiledCpuFunction { // Sets the intra-op thread pool used to run individual ops concurrently. void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { run_options_.set_intra_op_thread_pool(pool); - context_.thread_pool = pool; } // Runs the computation, with inputs read from arg buffers, and outputs // written to result buffers. Returns true on success and false on failure. bool Run() { - context_.error = false; - context_.error_msg.clear(); raw_function_(temps_[result_index_], &run_options_, - const_cast(args_), temps_); - return !context_.error; + const_cast(args_), temps_, profile_counters_); + return true; } // Returns the error message from the previous failed Run call. - const string& error_msg() const { return context_.error_msg; } + // + // TODO(fschneider): For now this always returns an empty string because there + // is no support for error reporting in XLA. Remove this once all callers are + // updated. + string error_msg() const { return {}; } // ------------------------------ // Arg methods for managing input buffers. Buffers are in row-major order. 
@@ -141,10 +145,6 @@ class XlaCompiledCpuFunction { // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in // tensorflow/compiler/aot/runtime.h to ensure correct alignment. // - // If StaticData.requires_runtime_context==true, the final argument is an - // XlaLocalRuntimeContext, which is managed internally by this class, and - // should not be changed. - // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. void set_arg_data(size_t index, void* data) { args_[index] = data; } @@ -162,6 +162,16 @@ class XlaCompiledCpuFunction { return static_cast(temps_[result_index_]); } + // Profile counters for this XLA computation. + // + // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in + // this case) these counters are non-null and are automatically populated by + // `Run`. The counters can then be pretty-printed using + // `hlo_profile_printer()`. + // + // When Hlo profiling is disabled, this accessor returns null. + const int64* profile_counters() const { return profile_counters_; } + // Returns the buffer for the positional result at the given `index`. void* result_data(size_t index) { return results()[index]; } const void* result_data(size_t index) const { return results()[index]; } @@ -195,6 +205,12 @@ class XlaCompiledCpuFunction { // program shape isn't available. const xla::ProgramShape* ProgramShape() const { return program_shape_; } + bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; } + const xla::HloProfilePrinter& hlo_profile_printer() const { + assert(hlo_profiling_enabled()); + return *hlo_profile_printer_; + } + private: const RawFunction raw_function_; const size_t result_index_; @@ -208,14 +224,17 @@ class XlaCompiledCpuFunction { void* alloc_args_ = nullptr; void* alloc_temps_ = nullptr; + // Backing memory for profiling counters. + int64* profile_counters_ = nullptr; + // Options and context passed to the compiled function. 
xla::ExecutableRunOptions run_options_; - tensorflow::XlaLocalRuntimeContext context_; // Optional metadata. const char** arg_names_ = nullptr; const char** result_names_ = nullptr; const xla::ProgramShape* program_shape_ = nullptr; + const xla::HloProfilePrinter* hlo_profile_printer_ = nullptr; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 48cebdf74c71f974bf075e0255626ec57eb9a149..4c01e6732128fbb62fb134ad7fa3233725f53ebb 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -543,8 +543,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options.resolve_compile_time_constants); core::ScopedUnref context_unref(context); - result->tuple_arg = options.use_tuple_arg; - std::vector arg_expressions; std::vector arg_cores; TF_RETURN_IF_ERROR(BuildArguments( @@ -564,11 +562,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->resource_updates)); - result->requires_runtime_context = context->has_context_parameter(); - - // Tuple arguments and runtime context parameters are incompatible. 
- TF_RET_CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); - VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; result->outputs.resize(context->retvals().size()); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 4d40ca5825a0c864c63826c901169607d5080c09..380e24e96bc713af4453f92a5359995e9ab4734a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -54,8 +54,6 @@ namespace tensorflow { // +---------------------+-----------------------------------------+ // Within each block, the arguments are arranged by the _Arg index from which // they were derived. -// If `Options::requires_runtime_context` is true, then an additional runtime -// context argument is passed as a final argument. // // The run-time outputs of the XLA computation are arranged in the following // order: @@ -191,16 +189,9 @@ class XlaCompiler { // original arguments, and are not necessarily in the same order.) std::vector input_mapping; - // Does the computation require the local runtime context to be passed as - // the last argument? - bool requires_runtime_context = false; - // Input shapes of the computation. std::vector xla_input_shapes; - // Should the arguments be packed into a single tuple? - bool tuple_arg; - // Output shape in XLA format. The output shape is always a tuple. xla::Shape xla_output_shape; @@ -232,16 +223,9 @@ class XlaCompiler { int graph_def_version = TF_GRAPH_DEF_VERSION; // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall() - // for CPU; additionally, an optional XlaLocalRuntimeContext* may be passed - // to the computation. + // for CPU. 
bool allow_cpu_custom_calls = false; - // If 'local_executable_has_hybrid_result', the top-level pointers of the - // result tuple of compiled programs are stored in host memory and the - // nested buffers in device memory, otherwise the whole result tuple is - // stored in device memory. - bool local_executable_has_hybrid_result = false; - // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation // device is created, and can be used to create metadata objects diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 651bafd6c5d946adfedd63ebbe93e4ea016f0b37..5d19dd353fc04744e196bb50c35cb60b35d8b258 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -70,24 +70,6 @@ XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, allow_cpu_custom_calls_(allow_cpu_custom_calls), resolve_compile_time_constants_(resolve_compile_time_constants) {} -const xla::ComputationDataHandle& -XlaContext::GetOrCreateRuntimeContextParameter() { - CHECK(allow_cpu_custom_calls_); - if (has_context_parameter_) return context_parameter_; - has_context_parameter_ = true; - - // Allocate the next available parameter for the context parameter. 
- int num_parameters = 0; - for (const XlaExpression& arg : args_) { - if (!arg.has_constant_value()) { - ++num_parameters; - } - } - context_parameter_ = builder_->Parameter( - num_parameters, xla::ShapeUtil::MakeOpaqueShape(), "tf_context"); - return context_parameter_; -} - string XlaContext::DebugString() { return "TLA JIT context"; } // This is called by the Retval Op to associate a computed value @@ -178,6 +160,20 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { }); } +const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) { + return LookupOrCreate(type, &mul_func_, [this, type] { + const string type_string = DataTypeString(type); + VLOG(1) << "Building Mul() for " << type_string; + xla::ComputationBuilder b(builder()->client(), "mul<" + type_string + ">"); + xla::PrimitiveType xla_type; + TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); + auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + b.Mul(x, y); + return b.Build().ConsumeValueOrDie(); + }); +} + const xla::Computation* XlaContext::LookupOrCreate( DataType type, ComputationMap* out, const std::function& create) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index de8aafa3628e6eebdabbc508cd95a2ac86e3472f..ebd758d1540eba5483714265565ad22c244ca4a3 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -56,15 +56,10 @@ class XlaContext : public ResourceBase { xla::ComputationBuilder* builder(); bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } - bool has_context_parameter() const { return has_context_parameter_; } const std::vector& args() const { return args_; } void set_args(std::vector args); - // Get the runtime context parameter, adding one if it does not already exist. - // Dies if not compiling a local executable. 
- const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter(); - const std::vector& retvals() { return retvals_; } // This is called by the Retval Op to associate a computed value @@ -102,6 +97,11 @@ class XlaContext : public ResourceBase { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); + // Get an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateMul(const DataType type); + // The name of the XlaContext resource during symbolic graph execution. static const char kXlaContextResourceName[]; @@ -119,13 +119,6 @@ class XlaContext : public ResourceBase { // run-time computation outptus. const bool resolve_compile_time_constants_; - // When 'has_context_parameter_' is true, this is the computation handle - // for an additional final parameter to the computation, through which will be - // passed a XlaLocalRuntimeContext* at runtime. Created on demand by - // GetOrCreateRuntimeContextParameter(). - bool has_context_parameter_ = false; - xla::ComputationDataHandle context_parameter_; - // Arguments to the Tensorflow graph, indexed by _Arg index. // Includes both compile-time constant arguments and runtime parameters. std::vector args_; @@ -155,6 +148,9 @@ class XlaContext : public ResourceBase { // Cached computation to compute Sum of two elements, specialized by type. ComputationMap add_func_; + // Cached computation to compute Mul of two elements, specialized by type. + ComputationMap mul_func_; + // Cached computation to compute Sigmoid of an element, specialized by type. 
ComputationMap sigmoid_func_; diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc index d504613d232c779e47a506657d2825d052e726dc..8ca757e72355d890c13b8b448d35c327d3986696 100644 --- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc @@ -21,8 +21,6 @@ namespace tensorflow { bool GpuOpFilter(KernelDef* kdef) { // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to // slow code. - // TODO(b/34969189) The implementation of TruncatedNormal generates illegal - // code on GPU. if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" || kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") { return false; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 1df6173275a95bca66f64b3f6df2db9c7a03580b..ec9e535b707beec6ea26dc81c7ee76b1d4da9225 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines helper routines for Tla JIT compilation. +// This file defines helper routines for XLA compilation. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -120,6 +121,8 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, DataType data_type) { switch (data_type) { + case DT_BFLOAT16: + return b->ConstantR0(bfloat16::epsilon()); case DT_FLOAT: return b->ConstantR0(std::numeric_limits::epsilon()); case DT_DOUBLE: @@ -168,6 +171,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( case xla::S16: case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; + case xla::BF16: + literal = *xla::Literal::CreateR0(static_cast(value)); + break; case xla::F16: literal = *xla::Literal::CreateR0(static_cast(value)); @@ -185,25 +191,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, DataType data_type, double value) { - xla::Literal literal; xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - switch (type) { - case xla::F16: - return b->ConstantR0(static_cast(value)); - break; - case xla::F32: - return b->ConstantR0(static_cast(value)); - break; - case xla::F64: - return b->ConstantR0(value); - break; - case xla::C64: - return b->ConstantR0(value); - break; - default: - LOG(FATAL) << "unhandled element type " << type; - } + return ::tensorflow::FloatLiteral(b, type, value); } /* static */ Status XlaHelpers::ReshapeLiteral( diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 1dd454ea8d57e21526e5bcde0c8efc5514983b93..584417bc72c8f6645c05912e857b031cfb394e54 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -37,27 +37,14 @@ 
namespace { // Returns a vector of positional argument buffer sizes. xla::StatusOr> ComputeArgSizes( - const xla::ProgramShape& program_shape, bool requires_runtime_context) { + const xla::ProgramShape& program_shape) { std::vector arg_sizes; const size_t num_args = program_shape.parameters_size(); arg_sizes.reserve(num_args); for (int i = 0; i < num_args; ++i) { const xla::Shape& arg_shape = program_shape.parameters(i); - if (i == num_args - 1 && requires_runtime_context) { - // If the compiled function needs an XlaLocalRuntimeContext* arg, it's - // always last, and must be represented as an opaque type. - const xla::PrimitiveType type = arg_shape.element_type(); - if (type != xla::OPAQUE) { - return errors::InvalidArgument( - "expected final context arg to be opaque, but got type: ", - xla::PrimitiveType_Name(type), ", from program shape: ", - xla::ShapeUtil::HumanString(program_shape)); - } - arg_sizes.push_back(-1); - } else { - constexpr size_t kPointerSize = sizeof(void*); - arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); - } + constexpr size_t kPointerSize = sizeof(void*); + arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); } return std::move(arg_sizes); } @@ -90,21 +77,6 @@ xla::StatusOr ComputeResultIndex( return result_slice.index(); } -// Adapt ComputeFunctionType, which includes a final profile_counters arg, to -// RawFunction, which doesn't include that final arg. -// -// TODO(toddw): Change RawFunction and AOT to also pass the final -// profile_counters arg, and remove this adapter. -XlaCompiledCpuFunction::RawFunction RawFunctionAdapter( - xla::cpu::CpuExecutable::ComputeFunctionType compute_function) { - return [compute_function](void* result, - const xla::ExecutableRunOptions* run_options, - const void** args, void** temps) { - return compute_function(result, run_options, args, temps, - /*profile_counters=*/nullptr); - }; -} - // Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. 
We hold // the actual strings in nonempty_names, and hold arrays of pointers in // name_ptrs, terminated by a nullptr entry. @@ -144,9 +116,8 @@ XlaJitCompiledCpuFunction::Compile( TF_ASSIGN_OR_RETURN(xla::LocalClient * client, xla::ClientLibrary::GetOrCreateLocalClient()); xla::Computation computation; - bool requires_runtime_context; - TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla( - graph_def, config, client, &computation, &requires_runtime_context)); + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla(graph_def, config, client, + &computation)); // Get and verify the program shape. TF_ASSIGN_OR_RETURN(std::unique_ptr program_shape, @@ -177,14 +148,13 @@ XlaJitCompiledCpuFunction::Compile( const xla::cpu::CpuExecutable* cpu_executable = static_cast(executable->executable()); XlaCompiledCpuFunction::RawFunction raw_function = - RawFunctionAdapter(cpu_executable->compute_function()); + cpu_executable->compute_function(); const xla::BufferAssignment& buffer_assignment = cpu_executable->buffer_assignment(); // Compute buffer sizes and the result index, needed to run the raw function. - TF_ASSIGN_OR_RETURN( - std::vector arg_sizes, - ComputeArgSizes(*program_shape, requires_runtime_context)); + TF_ASSIGN_OR_RETURN(std::vector arg_sizes, + ComputeArgSizes(*program_shape)); TF_ASSIGN_OR_RETURN(std::vector temp_sizes, ComputeTempSizes(buffer_assignment)); TF_ASSIGN_OR_RETURN(size_t result_index, @@ -203,7 +173,6 @@ XlaJitCompiledCpuFunction::Compile( jit->static_data_.temp_sizes = jit->temp_sizes_.data(); jit->static_data_.num_temps = jit->temp_sizes_.size(); jit->static_data_.result_index = result_index; - jit->static_data_.requires_runtime_context = requires_runtime_context; // Optional metadata is collected and set below. 
CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); CollectNames(config.fetch(), &jit->nonempty_result_names_, @@ -211,6 +180,14 @@ XlaJitCompiledCpuFunction::Compile( jit->static_data_.arg_names = jit->arg_names_.data(); jit->static_data_.result_names = jit->result_names_.data(); jit->static_data_.program_shape = jit->program_shape_.get(); + + if (cpu_executable->hlo_profiling_enabled()) { + jit->static_data_.hlo_profile_printer = + &cpu_executable->hlo_profile_printer(); + jit->static_data_.profile_counters_size = + cpu_executable->hlo_profile_printer().profile_counters_size(); + } + return std::move(jit_unique_ptr); } diff --git a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h deleted file mode 100644 index dca420d6ee3fec45f88ac3b450ab0cb4fb83d38a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ -#define TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ - -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -// Forward-declare the ThreadPoolDevice so that it can be ignored unless it's -// actually used. E.g. 
some ahead-of-time compiled computations don't need a -// thread pool. -namespace Eigen { -struct ThreadPoolDevice; -} - -namespace tensorflow { - -// An instance of this class is passed to each call from tensorflow into a -// compiled XLA computation. See xla_launch_ops.cc. -struct XlaLocalRuntimeContext { - public: - XlaLocalRuntimeContext() {} - - // Kernels implemented using custom call ops set this if they encounter an - // error. The error is checked after the entire XLA computation is - // complete. - // - // error+error_msg are used instead of Status to reduce the binary size - // overhead for ahead-of-time compiled binaries. - bool error = false; - string error_msg; - - // Kernels that need a thread pool can get it from here. - const Eigen::ThreadPoolDevice* thread_pool = nullptr; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalRuntimeContext); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index b948dfee6ab33651e52ca5045cfce600c788bc3b..79d501b511bf37ba4a79ab9d375d6f789a36889b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -345,6 +345,16 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { expression->set_constant_value(constant); } +void XlaOpKernelContext::SetInvalidOutput(int index) { + Tensor* output = nullptr; + OP_REQUIRES_OK(context_, + context_->allocate_output(index, TensorShape({}), &output)); + XlaExpression* expression = CastExpressionFromUninitializedTensor(output); + xla::ComputationDataHandle handle; + handle.set_handle(0); + expression->set_handle(handle); +} + void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { Tensor* output = nullptr; // The shape of the output tensor is the shape of the resource itself @@ -407,6 +417,11 @@ const xla::Computation* 
XlaOpKernelContext::GetOrCreateAdd( return XlaContext::Get(context_).GetOrCreateAdd(type); } +const xla::Computation* XlaOpKernelContext::GetOrCreateMul( + const DataType type) { + return XlaContext::Get(context_).GetOrCreateMul(type); +} + XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} void XlaOpKernel::Compute(OpKernelContext* context) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 5519e89252ca5a3964dcdaaeb3d08ce6c9da6bd4..f1ae81a5aa9d507a3e0dd577568377385b1844e6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -142,6 +142,10 @@ class XlaOpKernelContext { // SetConstantOutput where possible. void SetConstantOutput(int index, const Tensor& host_tensor); + // Sets output 'index' to an invalid value. + // Any subsequent attempt to consume this output will cause an error. + void SetInvalidOutput(int index); + // Status handling. void SetStatus(const Status& status) { context_->SetStatus(status); } Status status() { return context_->status(); } @@ -174,7 +178,7 @@ class XlaOpKernelContext { // If this kernel invocation is within a function execution, // call_frame() returns the call frame for the function call. - FunctionCallFrame* call_frame() const { return context_->call_frame(); } + CallFrameInterface* call_frame() const { return context_->call_frame(); } FunctionLibraryRuntime* function_library() const { return context_->function_library(); @@ -206,6 +210,11 @@ class XlaOpKernelContext { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); + // Gets an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. 
+ const xla::Computation* GetOrCreateMul(const DataType type); + private: OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 02318cf7fa1d4edc12507f6b4d66a8e897cbe100..faf47434b5dc6b569ec4f9c91a8667de275a6315 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -187,22 +188,39 @@ void XlaOpRegistry::RegisterCompilationKernels() { // Constrain each type attribute to the intersection of: // a) the types supported by the backend, and - // b) the attribute's type constraints. - // TODO(phawkins): it may be necessary to also take the intersection with - // the set of types supported by the OpDef. + // b) the types allowed by the OpDef, and + // c) the type constraints. for (const string& type_attr : type_attrs) { KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); attr_constraint->set_name(type_attr); auto* allowed_values = attr_constraint->mutable_allowed_values()->mutable_list(); - auto it = op_registration->type_constraints.find(type_attr); + const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); + const auto* op_def_allowed_types = + op_def_attr.has_allowed_values() + ? &op_def_attr.allowed_values().list().type() + : nullptr; + auto constraint_it = op_registration->type_constraints.find(type_attr); + const std::set* type_constraints = + constraint_it != op_registration->type_constraints.end() + ? 
&constraint_it->second + : nullptr; for (DataType dtype : backend.second.supported_types) { - if (it == op_registration->type_constraints.end() || - (it != op_registration->type_constraints.end() && - it->second.find(dtype) != it->second.end())) { - allowed_values->add_type(dtype); + // Filter out types that aren't allowed by the OpDef. + if (op_def_allowed_types != nullptr && + std::find(op_def_allowed_types->begin(), + op_def_allowed_types->end(), + dtype) == op_def_allowed_types->end()) { + continue; } + // Filter out types based on the type constraints. + if (type_constraints != nullptr && + type_constraints->find(dtype) == type_constraints->end()) { + continue; + } + // Passed all the filters, this type is allowed. + allowed_values->add_type(dtype); } if (op_registration->allow_resource_types) { allowed_values->add_type(DT_RESOURCE); @@ -245,6 +263,22 @@ std::vector XlaOpRegistry::DeviceKernels( return kernels; } +std::vector XlaOpRegistry::BackendNames() { + std::vector names; + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + for (const auto& backend_pair : registry.backends_) { + names.push_back(backend_pair.first); + } + return names; +} + +bool XlaOpRegistry::IsBackendRegistered(const string& name) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + return registry.backends_.find(name) != registry.backends_.end(); +} + XlaOpRegistry& XlaOpRegistry::Instance() { static XlaOpRegistry* r = new XlaOpRegistry; return *r; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 6aee8c91cc01b4382ef867fa8e438eede008ac73..2959d2ab690dfb91f8f46f5cf5718a405d9e0c7f 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -97,6 +97,12 @@ class XlaOpRegistry { gtl::ArraySlice supported_types, BackendOpFilter op_filter); + // Returns the names of the registered backends. 
+ static std::vector BackendNames(); + + // Returns true iff a backend with the given name is registered. + static bool IsBackendRegistered(const string& name); + // Registers `device_name` for XLA compilation, using information from // `registration`. static void RegisterCompilationDevice(const string& device_name, @@ -116,8 +122,8 @@ class XlaOpRegistry { static void RegisterCompilationKernels(); // Returns KernelDefs for compilation ops registered on - // 'compilation_device_name'. - // Does not include kernels registered as CompilationOnly. + // 'compilation_device_name'. Does not include kernels registered as + // CompilationOnly, iff include_compilation_only_kernels=false. static std::vector DeviceKernels( const string& compilation_device_name, bool include_compilation_only_kernels); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 660f419e464936b01a3644e69c2f056f998140f5..d3f292207fee396fb4248dede5c0eeb5cd2b87c9 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -77,6 +77,7 @@ cc_library( hdrs = ["types.h"], visibility = [":friends"], deps = [ + "//tensorflow/core:framework_lite", "//tensorflow/core:lib", "//third_party/eigen3", ], @@ -174,6 +175,7 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", ], ) @@ -339,6 +341,7 @@ cc_library( name = "array", hdrs = ["array.h"], deps = [ + ":status", ":types", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index ba898d1f4e9100df59c6e4b28824895c5ae6c08a..213e0bac6c77e9972de8d4dd7dfc8c7cf3a1b865 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -23,8 +23,10 @@ limitations under the License. 
#include #include #include +#include #include +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -35,10 +37,63 @@ limitations under the License. namespace xla { +namespace array_impl { + +// conjunction +// +// Performs a compile-time logical AND operation on the passed types (which +// must have `::value` members convertible to `bool`. Short-circuits if it +// encounters any `false` members (and does not compare the `::value` members +// of any remaining arguments). +// +// This metafunction is designed to be a drop-in replacement for the C++17 +// `std::conjunction` metafunction. +template +struct conjunction; + +template +struct conjunction + : std::conditional, T>::type {}; + +template <> +struct conjunction<> : std::true_type {}; + +// A type trait that is valid when all elements in a parameter pack are of +// integral type. +template +using pack_is_integral = conjunction...>; + +// Compares three same-sized vectors elementwise. For each item in `values`, +// returns false if any of values[i] is outside the half-open range [starts[i], +// ends[i]). +template +bool all_inside_range(const C1& values, const C2& range_starts, + const C3& range_ends) { + for (size_t i = 0, e = values.size(); i < e; ++i) { + if (values[i] < range_starts[i] || values[i] >= range_ends[i]) { + return false; + } + } + return true; +} + +} // namespace array_impl + // General N dimensional array class with arbitrary value type. template class Array { public: + // Type inference can have a hard time parsing very deep initializer list + // nests, especially if one or more dimensions is one as the compiler just + // sees a single-element integer initializer. These typedefs allow casting + // explicitly with less typing. 
+ using InitializerList1D = std::initializer_list; + using InitializerList2D = std::initializer_list; + using InitializerList3D = std::initializer_list; + using InitializerList4D = std::initializer_list; + + using value_type = T; + // Creates a new array with the specified dimensions. explicit Array(tensorflow::gtl::ArraySlice sizes) : Array(sizes, T()) {} @@ -53,7 +108,7 @@ class Array { // Creates a 2D array from the given nested initializer list. The outer // initializer list is the first dimension, the inner is the second dimension. // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3. - Array(std::initializer_list> values) + Array(InitializerList2D values) : Array(ToInt64Vector({values.size(), values.begin()->size()})) { int64 idx = 0; for (const auto& it1 : values) { @@ -67,8 +122,7 @@ class Array { // Creates a 3D array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. - Array(std::initializer_list>> - values) + Array(InitializerList3D values) : Array(ToInt64Vector({values.size(), values.begin()->size(), values.begin()->begin()->size()})) { int64 idx = 0; @@ -85,9 +139,7 @@ class Array { // Creates a 4D array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. - Array(std::initializer_list< - std::initializer_list>>> - values) + Array(InitializerList4D values) : Array(ToInt64Vector({values.size(), values.begin()->size(), values.begin()->begin()->size(), values.begin()->begin()->begin()->size()})) { @@ -173,10 +225,46 @@ class Array { } } + // Invokes a callback with the (indices, value_ptr) for each cell in the + // array. If a callback returns a non-OK status, returns that else returns + // Status::OK(). 
+ Status EachStatus( + std::function, T*)> f) { + std::vector index(sizes_.size()); + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + Status s = f(index, &values_[i]); + if (!s.ok()) { + return s; + } + } + return Status::OK(); + } + + // Invokes a callback with the (indices, value) for each cell in the array. + // If a callback returns a non-OK status, returns that else returns + // Status::OK(). + Status EachStatus( + std::function, T)> f) const { + std::vector index(sizes_.size()); + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + Status s = f(index, values_[i]); + if (!s.ok()) { + return s; + } + } + return Status::OK(); + } + // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. + // + // The type trait is required to avoid this overload participating too + // eagerly; a parameter pack can take zero or more elements, so we must + // restrict this to only parameter packs that are all of integral type. template - const T& operator()(Dims... dims) const { + typename std::enable_if::value, + const T&>::type + operator()(Dims... dims) const { // We are using a std::array to avoid having to allocate memory in this // function for performance reasons. std::array indexes{{static_cast(dims)...}}; @@ -186,7 +274,9 @@ class Array { // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. template - T& operator()(Dims... dims) { + typename std::enable_if::value, + T&>::type + operator()(Dims... dims) { // We are using a std::array to avoid having to allocate memory in this // function for performance reasons. std::array indexes{{static_cast(dims)...}}; @@ -255,6 +345,59 @@ class Array { bool operator!=(const Array& other) const { return !(*this == other); } + // Performs the equivalent of a slice operation on this array. 
+ Array Slice(tensorflow::gtl::ArraySlice starts, + tensorflow::gtl::ArraySlice limits) const { + CHECK_EQ(starts.size(), num_dimensions()); + CHECK_EQ(limits.size(), num_dimensions()); + + std::vector sizes; + std::transform(starts.begin(), starts.end(), limits.begin(), + std::back_inserter(sizes), + [](int64 start, int64 limit) { return limit - start; }); + Array result(sizes); + + std::vector index(sizes_.size()); + int64 slice_i = 0; + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + if (array_impl::all_inside_range(index, starts, limits)) { + // Even though the bounds of result are different to our bounds, we're + // iterating in the same order. So we can simply write successive linear + // indices instead of recalculating a multi-dimensional index. + result.values_[slice_i++] = values_[i]; + } + } + return result; + } + + // Performs the equivalent of a DynamicUpdateSlice in-place on this array. + void UpdateSlice(const Array& from, + tensorflow::gtl::ArraySlice start_indices) { + CHECK_EQ(from.num_dimensions(), num_dimensions()); + std::vector limit_indices; + std::transform(start_indices.begin(), start_indices.end(), + from.dimensions().begin(), std::back_inserter(limit_indices), + std::plus{}); + std::vector index(sizes_.size()); + int64 from_i = 0; + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + if (array_impl::all_inside_range(index, start_indices, limit_indices)) { + // Even though the bounds of from are different to our bounds, we're + // iterating in the same order. So we can simply write successive linear + // indices instead of recalculating a multi-dimensional index. + values_[i] = from.values_[from_i++]; + } + } + } + + // Performs an in-place reshape, modifying the dimensions but not the + // underlying data. 
+ void Reshape(tensorflow::gtl::ArraySlice new_dimensions) { + int64 old_num_elements = num_elements(); + sizes_ = std::vector(new_dimensions.begin(), new_dimensions.end()); + CHECK_EQ(num_elements(), old_num_elements); + } + // Returns a string representation of the array suitable for debugging. string ToString() const { std::vector pieces; diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h index e9449f01ad69a5722f53cce09e2884e20a0def5a..a1c5840a5f3874e27043c821ed4684da2fa6c542 100644 --- a/tensorflow/compiler/xla/array3d.h +++ b/tensorflow/compiler/xla/array3d.h @@ -36,6 +36,8 @@ namespace xla { template class Array3D : public Array { public: + Array3D() : Array(std::vector{0, 0, 0}) {} + // Creates an array of dimensions n1 x n2 x n3, uninitialized values. Array3D(const int64 n1, const int64 n2, const int64 n3) : Array(std::vector{n1, n2, n3}) {} diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc index 093784f541b3bd18f4a1fc1b665cd0d17a892f28..8b9419477479d952126fd831eb44899e7649ca71 100644 --- a/tensorflow/compiler/xla/array_test.cc +++ b/tensorflow/compiler/xla/array_test.cc @@ -71,6 +71,19 @@ TEST(ArrayTest, IndexingReadWrite) { EXPECT_EQ(arr(1, 2), 61); } +TEST(ArrayTest, DynamicIndexingReadWrite) { + Array arr({2, 3}); + + std::vector index1 = {1, 1}; + std::vector index2 = {1, 2}; + EXPECT_EQ(arr(index1), 0); + EXPECT_EQ(arr(index2), 0); + arr(index1) = 51; + arr(index2) = 61; + EXPECT_EQ(arr(1, 1), 51); + EXPECT_EQ(arr(1, 2), 61); +} + TEST(ArrayTest, IndexingReadWriteBool) { Array arr{{false, true, false}, {false, true, false}}; @@ -141,5 +154,37 @@ TEST(ArrayTest, Each) { EXPECT_EQ(arr.num_elements() * (arr.num_elements() - 1) / 2, each_sum); } +TEST(ArrayTest, Slice) { + Array arr({2, 4}); + arr.FillWithMultiples(1); + + Array identity_slice = arr.Slice({0, 0}, {2, 4}); + EXPECT_EQ(identity_slice.dimensions(), arr.dimensions()); + for (auto it1 = arr.begin(), it2 = 
identity_slice.begin(), e = arr.end(); + it1 != e; ++it1, ++it2) { + EXPECT_EQ(*it1, *it2); + } + + Array sub_slice = arr.Slice({1, 0}, {2, 2}); + EXPECT_EQ(sub_slice.dimensions(), (std::vector{1, 2})); + const string expected = R"([[4, 5]])"; + EXPECT_EQ(expected, sub_slice.ToString()); +} + +TEST(ArrayTest, UpdateSlice) { + Array arr({3, 4}); + arr.FillWithMultiples(1); + + Array sub_arr({2, 2}); + sub_arr.FillWithMultiples(3); + + arr.UpdateSlice(sub_arr, {1, 1}); + + const string expected = R"([[0, 1, 2, 3], + [4, 0, 3, 7], + [8, 6, 9, 11]])"; + EXPECT_EQ(expected, arr.ToString()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 92cd8e729d659c4ff24c156d89f29275848c3cee..66937d64aff18817bbd5310e0c24e19556e9d727 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -142,8 +142,7 @@ StatusOr> Client::TransferFromOutfeed( "TransferToClient request"); } - Literal literal(response.literal()); - return MakeUnique(literal); + return MakeUnique(response.literal()); } Status Client::ResetDevice() { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index a716159f9e74041c4823ad20b46fa94c2d7b9d8c..c28380b689c7a0e16bf0bcbf15003f4aa15e42a7 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -67,6 +67,15 @@ class Client { std::vector arguments; ExecutionOptions execution_options; ExecutionProfile* execution_profile; + + ComputationInstance(const Computation& computation, + std::vector arguments, + ExecutionOptions execution_options, + ExecutionProfile* execution_profile) + : computation(computation), + arguments(std::move(arguments)), + execution_options(execution_options), + execution_profile(execution_profile) {} }; // Executes a list ComputationInstances and returns global data produced from @@ -133,7 +142,7 @@ class Client { // Returns a 
vector of global data handles that point to the tuple elements. StatusOr>> DeconstructTuple( - const GlobalData& computation); + const GlobalData& data); // Retrieves the statistics of the given computation. StatusOr GetComputationStats( diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 763d94e94c2167f47b3f0777a31815f02791aa9e..317dcb4e41723b93e7e50d911f16e48bc3505a09 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -153,6 +153,7 @@ bool ComputationBuilder::MakeWindow( } else { dim->set_window_dilation(1); } + dim->set_window_reversal(false); } return true; } @@ -624,7 +625,41 @@ ComputationDataHandle ComputationBuilder::Lt( ComputationDataHandle ComputationBuilder::Dot( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return BinaryOp(BINOP_DOT, lhs, rhs, /*broadcast_dimensions=*/{}); + StatusOr> lhs_shape_or_status = GetShape(lhs); + if (!lhs_shape_or_status.ok()) { + NoteError(lhs_shape_or_status.status()); + return ComputationDataHandle(); + } + std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); + + DotDimensionNumbers dimension_numbers; + dimension_numbers.add_lhs_contracting_dimensions( + lhs_shape->dimensions_size() == 1 ? 
0 : 1); + dimension_numbers.add_rhs_contracting_dimensions(0); + return DotGeneral(lhs, rhs, dimension_numbers); +} + +ComputationDataHandle ComputationBuilder::DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + DotRequest request; + *request.mutable_lhs() = lhs; + *request.mutable_rhs() = rhs; + *request.mutable_dimension_numbers() = dimension_numbers; + + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_dot_request() = request; + AddCommonFieldsToOpRequest(&op_request); + OpResponse response; + + VLOG(2) << "making Dot request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); } ComputationDataHandle ComputationBuilder::Conv( @@ -693,11 +728,15 @@ bool ComputationBuilder::VerifyConvolution( } return true; }; - return check_spatial_dimensions("spatial_dimensions", - dimension_numbers.spatial_dimensions()) && + return check_spatial_dimensions( + "input_spatial_dimensions", + dimension_numbers.input_spatial_dimensions()) && check_spatial_dimensions( "kernel_spatial_dimensions", - dimension_numbers.kernel_spatial_dimensions()); + dimension_numbers.kernel_spatial_dimensions()) && + check_spatial_dimensions( + "output_spatial_dimensions", + dimension_numbers.output_spatial_dimensions()); } ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( @@ -729,11 +768,11 @@ ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( } std::vector base_area_dimensions( - dimension_numbers.spatial_dimensions_size()); + dimension_numbers.input_spatial_dimensions_size()); for (std::vector::size_type i = 0; i < base_area_dimensions.size(); ++i) { base_area_dimensions[i] = - lhs_shape->dimensions(dimension_numbers.spatial_dimensions(i)); + 
lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i)); } std::vector window_dimensions( @@ -1163,6 +1202,34 @@ ComputationDataHandle ComputationBuilder::ConvertElementType( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::BitcastConvertType( + const ComputationDataHandle& operand, PrimitiveType new_element_type) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + StatusOr> shape_status = GetShape(operand); + if (!shape_status.ok()) { + first_error_ = shape_status.status(); + return ComputationDataHandle(); + } + std::unique_ptr original = shape_status.ConsumeValueOrDie(); + + ConvertRequest request; + *request.mutable_operand() = operand; + request.set_new_element_type(new_element_type); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_bitcast_convert_request() = request; + AddCommonFieldsToOpRequest(&op_request); + OpResponse response; + + VLOG(2) << "making bitcast convert request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + ComputationDataHandle ComputationBuilder::SquareF32( const ComputationDataHandle& operand) { return BinaryOp(BINOP_POW, operand, ConstantR0(2.0), @@ -1437,6 +1504,34 @@ ComputationDataHandle ComputationBuilder::While( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::Conditional( + const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const Computation& true_computation, + const ComputationDataHandle& false_operand, + const Computation& false_computation) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + ConditionalRequest request; + *request.mutable_predicate() = predicate; + *request.mutable_true_operand() = true_operand; + *request.mutable_true_computation() = true_computation.handle(); + 
*request.mutable_false_operand() = false_operand; + *request.mutable_false_computation() = false_computation.handle(); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_conditional_request() = request; + AddCommonFieldsToOpRequest(&op_request); + OpResponse response; + + VLOG(2) << "making conditional op request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + ComputationDataHandle ComputationBuilder::Reduce( const ComputationDataHandle& operand, const ComputationDataHandle& init_value, const Computation& computation, @@ -1816,25 +1911,27 @@ ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dimension_numbers.set_kernel_input_feature_dimension( kConvKernelInputDimension); for (int i = 0; i < num_spatial_dims; ++i) { - dimension_numbers.add_spatial_dimensions(i + 2); + dimension_numbers.add_input_spatial_dimensions(i + 2); dimension_numbers.add_kernel_spatial_dimensions(i + 2); + dimension_numbers.add_output_spatial_dimensions(i + 2); } return dimension_numbers; } /* static */ StatusOr ComputationBuilder::CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 output_batch, - int64 output_feature, int64 first_spatial, int64 second_spatial, + int64 input_batch, int64 input_feature, int64 input_first_spatial, + int64 input_second_spatial, int64 output_batch, int64 output_feature, + int64 output_first_spatial, int64 output_second_spatial, int64 kernel_output_feature, int64 kernel_input_feature, int64 kernel_first_spatial, int64 kernel_second_spatial) { - if (std::set( - {input_batch, input_feature, first_spatial, second_spatial}) + if (std::set({input_batch, input_feature, input_first_spatial, + input_second_spatial}) .size() != 4) { return FailedPrecondition( "dimension numbers for the input are not unique: (%lld, %lld, %lld, " "%lld)", - input_batch, input_feature, first_spatial, second_spatial); + 
input_batch, input_feature, input_first_spatial, input_second_spatial); } if (std::set({kernel_output_feature, kernel_input_feature, kernel_first_spatial, kernel_second_spatial}) @@ -1845,25 +1942,28 @@ ComputationBuilder::CreateConvDimensionNumbers( kernel_output_feature, kernel_input_feature, kernel_first_spatial, kernel_second_spatial); } - if (std::set( - {output_batch, output_feature, first_spatial, second_spatial}) + if (std::set({output_batch, output_feature, output_first_spatial, + output_second_spatial}) .size() != 4) { return FailedPrecondition( "dimension numbers for the output are not unique: (%lld, %lld, %lld, " "%lld)", - output_batch, output_feature, first_spatial, second_spatial); + output_batch, output_feature, output_first_spatial, + output_second_spatial); } ConvolutionDimensionNumbers dimension_numbers; dimension_numbers.set_input_batch_dimension(input_batch); dimension_numbers.set_input_feature_dimension(input_feature); - dimension_numbers.set_output_batch_dimension(output_batch); - dimension_numbers.set_output_feature_dimension(output_feature); - dimension_numbers.add_spatial_dimensions(first_spatial); - dimension_numbers.add_spatial_dimensions(second_spatial); + dimension_numbers.add_input_spatial_dimensions(input_first_spatial); + dimension_numbers.add_input_spatial_dimensions(input_second_spatial); dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature); dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial); dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial); + dimension_numbers.set_output_batch_dimension(output_batch); + dimension_numbers.set_output_feature_dimension(output_feature); + dimension_numbers.add_output_spatial_dimensions(output_first_spatial); + dimension_numbers.add_output_spatial_dimensions(output_second_spatial); return dimension_numbers; } diff --git 
a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 9159b2661451e73b1ccd2a2c1f01dfad61792c99..28889ece73f5da72c3eea681c9e4aea7351d3d54 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -68,6 +68,7 @@ class ShardingBuilder { const TileAssignment& tile_assignment) { OpSharding result; result.set_type(OpSharding::Type::OpSharding_Type_OTHER); + *result.mutable_tile_shape() = tile_shape; for (int64 dim : tile_assignment.dimensions()) { result.add_tile_assignment_dimensions(dim); } @@ -120,14 +121,10 @@ class ComputationBuilder { // result, OpMetadata is set on the Computation Builder. All subsequent // instructions generated via this Computation Builder will have the same // OpMetadata attached until a call to ClearOpMetdata. - void SetOpMetadata(const OpMetadata& metadata) { - metadata_ = metadata; - } + void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } // Clears the HloMetadata state. - void ClearOpMetadata() { - metadata_.Clear(); - } + void ClearOpMetadata() { metadata_.Clear(); } // Sets an OpSharding that will be attached to all instructions until cleared. void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } @@ -396,6 +393,11 @@ class ComputationBuilder { ComputationDataHandle Dot(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs); + // Enqueues a general dot instruction onto the computation. + ComputationDataHandle DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers); + // Default dimension numbers used for a 2D convolution. static constexpr int64 kConvBatchDimension = 0; static constexpr int64 kConvFeatureDimension = 1; @@ -416,8 +418,9 @@ class ComputationBuilder { // Creates a ConvolutionDimensionNumbers with the given arguments. 
Returns an // error if either the input or the weight dimension numbers have conflicts. static StatusOr CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 output_batch, - int64 output_feature, int64 first_spatial, int64 second_spatial, + int64 input_batch, int64 input_feature, int64 input_first_spatial, + int64 input_second_spatial, int64 output_batch, int64 output_feature, + int64 output_first_spatial, int64 output_second_spatial, int64 kernel_output_feature, int64 kernel_input_feature, int64 kernel_first_spatial, int64 kernel_second_spatial); @@ -672,6 +675,13 @@ class ComputationBuilder { ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, PrimitiveType new_element_type); + // Enqueues a no-op instruction onto the computation that changes + // the element type of the operand array to primitive_type. The + // bit-widths of the source and destination element types must be + // identical. + ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand, + PrimitiveType new_element_type); + // Enqueues a float32 reciprocal instruction onto the computation. // (float32 is specified as there is an implicit float32 -1.0f constant // exponent). @@ -731,6 +741,13 @@ class ComputationBuilder { const Computation& body, const ComputationDataHandle& init); + // Enqueues a conditional node onto the computation. + ComputationDataHandle Conditional(const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const Computation& true_computation, + const ComputationDataHandle& false_operand, + const Computation& false_computation); + // Enqueues a ReducePrecision node onto the computation. 
ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, const int exponent_bits, diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index ee3468208792879c3fe4ff5860e434ef5a0c0155..fca2bf2688cd21b44f099da3bae3b890cbb069ab 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -44,6 +44,7 @@ cc_library( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index e6645e4941bd04c658b67117bb689f6fdef7dfc1..5f2b55713e342aa3d0251386d57cb52481fe748d 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -48,65 +49,9 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { - std::vector> elements; - for (const Shape& element_shape : shape.tuple_shapes()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr element, - MakeFakeLiteral(element_shape)); - elements.push_back(std::move(element)); - } - return Literal::MakeTupleOwned(std::move(elements)); - } - std::unique_ptr literal = Literal::CreateFromShape(shape); - std::minstd_rand0 engine; - switch (shape.element_type()) { - case F32: { - std::uniform_real_distribution generator(0.0f, 1.0f); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); - })); - break; - } - case S32: { - std::uniform_int_distribution generator( - std::numeric_limits::lowest(), - std::numeric_limits::max()); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); - })); - break; - } - case S64: { - std::uniform_int_distribution generator( - std::numeric_limits::lowest(), - std::numeric_limits::max()); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); - })); - break; - } - case PRED: { - std::uniform_int_distribution generator(0, 1); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); - })); - break; - } - default: - return Unimplemented("Unsupported type for fake literal generation: %s", - ShapeUtil::HumanString(shape).c_str()); - } - return std::move(literal); -} - std::unique_ptr 
MakeFakeDataOrDie(const Shape& shape, Client* client) { - if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) { + if (ShapeUtil::ByteSizeOf(shape) < (1LL << 20)) { StatusOr> literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { // If we got an Unimplemented error, fall back to making the fake data via diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index b5c4393dcc3e37c03a5b0e1a806b0f8b07a132ed..7e640d1307edcc3e2c021f4391c456f578a015ee 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -26,10 +26,6 @@ limitations under the License. namespace xla { -// Generates fake data in a literal of the given shape, or returns an error -// status if the element type is currently unhandled for fake data generation. -StatusOr> MakeFakeLiteral(const Shape& shape); - // Generates fake data of the given shape on the device or dies. The fake data // is created by performing a computation on the device rather than transferring // data from the host to the device. diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 15c744ecd349e91dc703bec5708d78a896f132c3..b051955f0fd85b7ca886bc0238068aeb94427209 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -27,16 +27,6 @@ namespace se = ::perftools::gputools; namespace xla { -ExecutableBuildOptions& ExecutableBuildOptions::set_platform( - perftools::gputools::Platform* platform) { - platform_ = platform; - return *this; -} - -perftools::gputools::Platform* ExecutableBuildOptions::platform() const { - return platform_; -} - ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( int device_ordinal) { device_ordinal_ = device_ordinal; @@ -56,16 +46,6 @@ const Shape* ExecutableBuildOptions::result_layout() const { return result_layout_set_ ? 
&result_layout_ : nullptr; } -ExecutableBuildOptions& ExecutableBuildOptions::set_has_hybrid_result( - bool has_hybrid_result) { - has_hybrid_result_ = has_hybrid_result; - return *this; -} - -bool ExecutableBuildOptions::has_hybrid_result() const { - return has_hybrid_result_; -} - namespace { StatusOr BorrowStreamForDevice(int device_ordinal, Backend* backend) { @@ -230,9 +210,9 @@ tensorflow::Status LocalExecutable::RecordArguments( SessionModule* session_module) { session_module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { - Literal literal; - TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal)); - *session_module->add_arguments() = literal.ToProto(); + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + LiteralFromShapedBuffer(*argument)); + *session_module->add_arguments() = literal->ToProto(); } return Status::OK(); } @@ -240,21 +220,19 @@ tensorflow::Status LocalExecutable::RecordArguments( tensorflow::Status LocalExecutable::RecordResult( const ShapedBuffer* result, SessionModule* session_module) { session_module->clear_result(); - Literal literal(session_module->result()); - TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal)); - *session_module->mutable_result() = literal.ToProto(); + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + LiteralFromShapedBuffer(*result)); + *session_module->mutable_result() = literal->ToProto(); return Status::OK(); } -// TODO(dnovillo) Change signature to return StatusOr. 
-tensorflow::Status LocalExecutable::LiteralFromShapedBuffer( - const ShapedBuffer& shaped_buffer, Literal* literal) { +StatusOr> LocalExecutable::LiteralFromShapedBuffer( + const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, backend_->stream_executor(shaped_buffer.device_ordinal())); - return backend_->transfer_manager()->TransferLiteralFromDevice( - executor, shaped_buffer.buffer({}), shaped_buffer.shape(), - shaped_buffer.shape(), literal); + return backend_->transfer_manager()->TransferLiteralFromDevice(executor, + shaped_buffer); } se::Platform* LocalClient::platform() const { @@ -297,9 +275,6 @@ StatusOr> LocalClient::Compile( device_ordinal, options)); } -// Copy the literal data to the device with the given ordinal and return as a -// ScopedShapedBuffer. The given memory allocator is used for device memory -// allocation. StatusOr> LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, DeviceMemoryAllocator* allocator) { @@ -308,46 +283,42 @@ LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, } TF_ASSIGN_OR_RETURN( auto scoped_buffer, - ScopedShapedBuffer::Allocate(literal.shape(), allocator, device_ordinal)); + ScopedShapedBuffer::Allocate( + literal.shape(), allocator, device_ordinal, + [this](const Shape& shape) { + return backend().transfer_manager()->GetByteSizeRequirement(shape); + })); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - literal.shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(subshape)) { - // This is a leaf of the shape. Transfer the literal array data to the - // device buffer. 
- return backend().transfer_manager()->TransferLiteralToDevice( - executor, literal.GetSubliteral(index), - scoped_buffer->mutable_buffer(index)); - } - return Status::OK(); - })); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + executor, literal, *scoped_buffer)); return std::move(scoped_buffer); } -// Copy the data from the device contained in the given ShapedBuffer and -// return as a Literal. StatusOr> LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { - std::unique_ptr literal = - Literal::CreateFromShape(shaped_buffer.shape()); TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, backend().stream_executor(shaped_buffer.device_ordinal())); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - literal->shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(subshape)) { - // This is a leaf of the shape. Transfer the device buffer into the - // literal. The layout of the literal and the device buffer are - // necessarily the same so we pass 'subshape' for both device and - // literal shapes. 
- return backend().transfer_manager()->TransferLiteralFromDevice( - executor, shaped_buffer.buffer(index), - /*device_shape=*/subshape, - /*literal_shape*/ subshape, &literal->GetSubliteral(index)); - } - return Status::OK(); - })); + return backend().transfer_manager()->TransferLiteralFromDevice(executor, + shaped_buffer); +} + +Status LocalClient::TransferToInfeedLocal(const Literal& literal, + int device_ordinal) { + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, + backend().stream_executor(device_ordinal)); + return backend().transfer_manager()->TransferLiteralToInfeed(executor, + literal); +} + +StatusOr> LocalClient::TransferFromOutfeedLocal( + const Shape& shape, int device_ordinal) { + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, + backend().stream_executor(device_ordinal)); + auto literal = MakeUnique(); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( + executor, shape, literal.get())); return std::move(literal); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 9f985ed5275815de2d59f6caedbbcc8060420a13..3ca0d2ef5513cfb6b0dbfbc63b311f81a318356e 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -37,14 +37,6 @@ namespace xla { // LocalClient::Compile. class ExecutableBuildOptions { public: - // If set, this is the platform to build the computation for. This must match - // the underlying platform of the service. A value of nullptr indicates the - // option has not been set. - // - // TODO(b/28616830): Support multiple platforms. - ExecutableBuildOptions& set_platform(perftools::gputools::Platform* platform); - perftools::gputools::Platform* platform() const; - // If set, this is the device to build the computation for. Valid // device_ordinal values are: 0 to # of devices - 1. These values are // identical to the device ordinal values used by StreamExecutor. 
The built @@ -61,18 +53,10 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); const Shape* result_layout() const; - // If set, the executable will be built to output a hybrid - // ShapedBuffer with top-level tuple pointers in host memory and - // result buffers in device memory. - ExecutableBuildOptions& set_has_hybrid_result(bool has_hybrid_result); - bool has_hybrid_result() const; - private: - perftools::gputools::Platform* platform_ = nullptr; int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; - bool has_hybrid_result_ = true; }; class LocalExecutable { @@ -129,9 +113,9 @@ class LocalExecutable { tensorflow::Status RecordResult(const ShapedBuffer* result, SessionModule* session_module); - // Copies the contents of a ShapedBuffer into a Literal proto. - tensorflow::Status LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer, - Literal* literal); + // Returns a literal containing the contents of the given ShapedBuffer. + StatusOr> LiteralFromShapedBuffer( + const ShapedBuffer& shaped_buffer); // Compiled computation. std::unique_ptr executable_; @@ -178,6 +162,20 @@ class LocalClient : public Client { StatusOr> ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer); + // Transfer the given literal to the infeed queue of the given device. + // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does + // not inherit from Client and there is no possibility of confusion with + // Client::TransferToInfeed. + Status TransferToInfeedLocal(const Literal& literal, int device_ordinal); + + // Transfer and return a value of the given shape from the outfeed of the + // given device. + // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does + // not inherit from Client and there is no possibility of confusion with + // Client::TransferFromOutfeed. 
+ StatusOr> TransferFromOutfeedLocal( + const Shape& shape, int device_ordinal); + // Returns the platform that the underlying service targets. perftools::gputools::Platform* platform() const; diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index f2cdd9669c727bb778fce495ede0faaf2d9a923d..bfafef0a40f55e13ac94b2d1750df25146081784 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -31,7 +31,6 @@ std::vector* flag_objects; std::once_flag flags_init; void SetDebugOptionsDefaults(DebugOptions* flags) { - flags->set_xla_hlo_graph_path("/tmp/"); flags->set_xla_enable_fast_math(true); flags->set_xla_llvm_enable_alias_scope_metadata(true); flags->set_xla_llvm_enable_noalias_metadata(true); @@ -117,9 +116,22 @@ void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), flag_values->xla_hlo_dump_as_graphdef(), "Dump HLO graphs as TensorFlow GraphDefs."), + tensorflow::Flag( + "xla_hlo_graph_sharding_color", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), + flag_values->xla_hlo_graph_sharding_color(), + "Assign colors based on sharding assignments when generating the " + "HLO graphs."), + tensorflow::Flag( + "xla_hlo_tfgraph_device_scopes", + bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes), + flag_values->xla_hlo_tfgraph_device_scopes(), + "When generating TensorFlow HLO graphs, if the HLO instructions " + "are assigned to a specific device, prefix the name scope with " + "\"devX\" with X being the device ordinal."), tensorflow::Flag( "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), - "HLO modules matching this regex will be dumped to LOG(INFO). 
"), + "HLO modules matching this regex will be dumped to LOG(INFO)."), tensorflow::Flag( "xla_generate_hlo_text_to", flag_values->mutable_xla_generate_hlo_text_to(), diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index fda791401d567b694b3d2cabf129141a7ff2ddb2..42c9d21149a41a3d60f2cfff65d3af08d7c8b9d7 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -33,6 +33,20 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +namespace { +using tensorflow::int64; + +constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; + +// Converts between little and big endian, assuming elements in the array are 16 +// bits long. +void ConvertEndianShort(char* bytes, int64 size) { + CHECK_EQ(size % 2, 0); + for (int64 i = 0; i < size; i += 2) { + std::swap(bytes[i], bytes[i + 1]); + } +} +} // namespace namespace xla { @@ -169,6 +183,8 @@ Status Literal::Copy(const Literal& src_literal, return CopyRange(src_literal, src_base, dest_base, copy_size); case F16: return CopyRange(src_literal, src_base, dest_base, copy_size); + case BF16: + return CopyRange(src_literal, src_base, dest_base, copy_size); case F32: return CopyRange(src_literal, src_base, dest_base, copy_size); case F64: @@ -200,6 +216,8 @@ Status Literal::Copy(const Literal& src_literal, return *Literal::CreateR0(0); case F16: return *Literal::CreateR0(static_cast(0.0f)); + case BF16: + return *Literal::CreateR0(static_cast(0.0f)); case F32: return *Literal::CreateR0(0); case F64: @@ -234,6 +252,10 @@ Status Literal::Copy(const Literal& src_literal, return *Literal::CreateR0(1); case S64: return *Literal::CreateR0(1); + case F16: + return *Literal::CreateR0(static_cast(1.0f)); + case BF16: + return *Literal::CreateR0(static_cast(1.0f)); case F32: return *Literal::CreateR0(1); case F64: @@ 
-245,8 +267,6 @@ Status Literal::Copy(const Literal& src_literal, case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; - case F16: - return *Literal::CreateR0(static_cast(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; case OPAQUE: @@ -285,6 +305,9 @@ Status Literal::Copy(const Literal& src_literal, case F16: return *Literal::CreateR0( static_cast(-std::numeric_limits::infinity())); + case BF16: + return *Literal::CreateR0( + static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -321,6 +344,9 @@ Status Literal::Copy(const Literal& src_literal, case F16: return *Literal::CreateR0( static_cast(std::numeric_limits::infinity())); + case BF16: + return *Literal::CreateR0( + static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -428,6 +454,7 @@ std::unique_ptr Literal::Transpose( // The shape with affine layout resulting from that operation will be // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the // most minor. 
+ // // Essentially, given MinMaj(Di) the position of the Di dimension within the // minor to major vector, and given T(Di) the index that the original Di // dimension has within the transposed array, a layout is affine if @@ -536,6 +563,9 @@ string Literal::GetAsString( } case F16: return tensorflow::strings::StrCat(Get(multi_index)); + case BF16: + return tensorflow::strings::StrCat( + static_cast(Get(multi_index))); default: return tensorflow::strings::StrCat( "[", PrimitiveType_Name(shape().element_type()), "]"); @@ -569,9 +599,17 @@ int64 Literal::LinearIndex( return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index); } -string Literal::ToString() const { +string Literal::ToString(bool print_layout) const { std::vector pieces; + auto shape_to_string = [print_layout](const Shape& shape) { + if (print_layout) { + return ShapeUtil::HumanStringWithLayout(shape); + } else { + return ShapeUtil::HumanString(shape); + } + }; + auto element_to_string = [this](tensorflow::gtl::ArraySlice indices) -> string { PrimitiveType element_type = shape().element_type(); @@ -585,7 +623,7 @@ string Literal::ToString() const { // TODO(b/32894291): refactor this code to reduce code duplication. 
if (ShapeUtil::IsTuple(shape())) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" (\n"); pieces.push_back(tensorflow::str_util::Join( tuple_literals(), ",\n", [](string* out, const Literal& element) { @@ -601,7 +639,7 @@ string Literal::ToString() const { } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 2) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(" { "); @@ -613,7 +651,7 @@ string Literal::ToString() const { } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 3) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(i0 > 0 ? ",\n{" : "{"); @@ -628,7 +666,7 @@ string Literal::ToString() const { } pieces.push_back("\n}"); } else if (ShapeUtil::Rank(shape()) == 4) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); @@ -649,7 +687,7 @@ string Literal::ToString() const { } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 5) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); @@ -676,8 +714,14 @@ string Literal::ToString() const { } pieces.push_back("}"); } else { - pieces.push_back(ShapeUtil::HumanString(shape())); - pieces.push_back(" {...}"); + pieces.push_back(shape_to_string(shape())); + pieces.push_back(" {"); + EachCellAsString( + [&](tensorflow::gtl::ArraySlice 
indices, const string& value) { + pieces.push_back(" "); + pieces.push_back(value); + }); + pieces.push_back("}"); } return tensorflow::str_util::Join(pieces, ""); @@ -735,6 +779,8 @@ void* Literal::MutableInternalData() { return reinterpret_cast(c64s_.data()); case F16: return reinterpret_cast(f16s_.data()); + case BF16: + return reinterpret_cast(bf16s_.data()); default: LOG(FATAL) << "primitive type not supported in literals: " << PrimitiveType_Name(shape().element_type()); @@ -777,6 +823,9 @@ void Literal::Reserve(int64 num_elements) { case F16: Resize(num_elements, static_cast(0.0f)); break; + case BF16: + Resize(num_elements, static_cast(0.0f)); + break; default: LOG(FATAL) << "primitive type not supported in literals: " << PrimitiveType_Name(shape().element_type()); @@ -816,6 +865,9 @@ tensorflow::Status Literal::ValidateLiteral() const { case F16: actual = f16s().size() / sizeof(half); break; + case BF16: + actual = bf16s().size(); + break; default: return tensorflow::errors::Unimplemented( "unhandled element type for literal validation: " + @@ -912,6 +964,7 @@ StatusOr> ConvertIfDestTypeMatches( CONVERT_IF_TYPES_MATCH(F16) CONVERT_IF_TYPES_MATCH(F32) CONVERT_IF_TYPES_MATCH(F64) + CONVERT_IF_TYPES_MATCH(BF16) #undef CONVERT_IF_TYPES_MATCH case C64: return ConvertToC64(src_literal); @@ -941,8 +994,9 @@ StatusOr> Literal::Convert( CONVERT_IF_DEST_TYPE_MATCHES(F16) CONVERT_IF_DEST_TYPE_MATCHES(F32) CONVERT_IF_DEST_TYPE_MATCHES(F64) + CONVERT_IF_DEST_TYPE_MATCHES(BF16) #undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. + // Other types are not yet supported. 
default: return InvalidArgument("Unimplemented: Convert from type %s to type %s", PrimitiveType_Name(shape().element_type()).c_str(), @@ -1011,6 +1065,8 @@ bool Literal::operator==(const Literal& other) const { return EqualElements(*this, other, 0, &multi_index); case F16: return EqualElements(*this, other, 0, &multi_index); + case BF16: + return EqualElements(*this, other, 0, &multi_index); case C64: return EqualElements(*this, other, 0, &multi_index); default: @@ -1120,13 +1176,18 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - // TODO - there is an endianess problem here. fix it, or wait for uint16 - // support in protobuf auto values = mutable_f16s(); return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } +template <> +tensorflow::gtl::MutableArraySlice +Literal::GetMutableArraySlice() { + auto values = mutable_bf16s(); + return {values->data(), values->size()}; +} + template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { CHECK_EQ(shape().element_type(), PRED); @@ -1197,6 +1258,12 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { f16s().size() / sizeof(half)); } +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), BF16); + return {bf16s().data(), bf16s().size()}; +} + template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { @@ -1245,6 +1312,9 @@ bool Literal::IsAll(int8 value) const { return AllElementsEqualValue(*this, value); case F16: return AllElementsEqualValue(*this, static_cast(value)); + case BF16: + return AllElementsEqualValue(*this, + static_cast(value)); case PRED: if (value == 0) { return AllElementsEqualValue(*this, false); @@ -1266,6 +1336,9 @@ bool Literal::IsAllFloat(float value) const { return AllElementsEqualValue(*this, value); case F16: return AllElementsEqualValue(*this, static_cast(value)); + case BF16: + 
return AllElementsEqualValue(*this, + static_cast(value)); default: return false; } @@ -1302,6 +1375,8 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { return Get(indices) == complex64(0.0f, 0.0f); case F16: return Get(indices) == static_cast(0.0f); + case BF16: + return Get(indices) == static_cast(0.0f); case PRED: return Get(indices) == false; default: @@ -1369,6 +1444,12 @@ void Literal::Resize(int64 num_elements, half value) { mutable_f16s()->resize(num_elements, value); } +template <> +void Literal::Resize(int64 num_elements, bfloat16 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_bf16s()->resize(num_elements, value); +} + template <> void Literal::Resize(int64 num_elements, complex64 value) { CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); @@ -1417,6 +1498,19 @@ LiteralProto Literal::ToProto() const { *proto.mutable_f16s() = string(reinterpret_cast(f16s_.data()), f16s_.size() * sizeof(half)); + if (!kLittleEndian) { + ConvertEndianShort(const_cast(proto.mutable_f16s()->data()), + proto.f16s().size()); + } + break; + case BF16: + *proto.mutable_bf16s() = + string(reinterpret_cast(bf16s_.data()), + bf16s_.size() * sizeof(bfloat16)); + if (!kLittleEndian) { + ConvertEndianShort(const_cast(proto.mutable_bf16s()->data()), + proto.bf16s().size()); + } break; case F32: CopyToRepeatedField(proto.mutable_f32s(), f32s()); @@ -1485,6 +1579,21 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) { CHECK_EQ(0, s.size() % sizeof(half)); f16s_ = std::vector(s.size() / sizeof(half)); memcpy(f16s_.data(), s.data(), s.size()); + + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(f16s_.data()), s.size()); + } + break; + } + case BF16: { + const string& s(literal_proto.bf16s()); + CHECK_EQ(0, s.size() % sizeof(bfloat16)); + bf16s_ = std::vector(s.size() / sizeof(bfloat16)); + memcpy(bf16s_.data(), s.data(), s.size()); + + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(bf16s_.data()), 
s.size()); + } break; } case F32: diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index a1e288829f22835f94c6e3c041796f84d995211c..2981f9f8753a60f7acb7e3c6bf86f2b9da4c96d8 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -99,6 +99,7 @@ class Literal { f16s_.clear(); f32s_.clear(); f64s_.clear(); + c64s_.clear(); tuple_literals_.clear(); } @@ -163,6 +164,11 @@ class Literal { const std::vector& c64s() const { return c64s_; } std::vector* mutable_c64s() { return &c64s_; } + int bf16s_size() const { return bf16s().size(); } + bfloat16 bf16s(int i) const { return bf16s_[i]; } + const std::vector& bf16s() const { return bf16s_; } + std::vector* mutable_bf16s() { return &bf16s_; } + int tuple_literals_size() const { return tuple_literals().size(); } const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } Literal* add_tuple_literals() { @@ -280,11 +286,11 @@ class Literal { std::unique_ptr Relayout(const Layout& new_layout, const ShapeIndex& shape_index = {}) const; - // Creates a new literal by reshaping this literal to have 'shape'. Both the - // original shape and 'shape' must contain the same number of elements. The + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. StatusOr> Reshape( - tensorflow::gtl::ArraySlice shape) const; + tensorflow::gtl::ArraySlice dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers @@ -450,7 +456,7 @@ class Literal { tensorflow::Status ValidateLiteral() const; // Returns a string representation of the literal value. 
- string ToString() const; + string ToString(bool print_layout = false) const; // Invokes the "per cell" callback for each element in the provided // literal with the element's indices and a string representation of @@ -622,6 +628,7 @@ class Literal { std::vector u16s_; std::vector u32s_; std::vector u64s_; + std::vector bf16s_; std::vector f16s_; std::vector f32s_; std::vector f64s_; @@ -674,6 +681,9 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; @@ -714,6 +724,9 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); @@ -747,6 +760,9 @@ void Literal::Resize(int64 num_elements, double value); template <> void Literal::Resize(int64 num_elements, half value); +template <> +void Literal::Resize(int64 num_elements, bfloat16 value); + template <> void Literal::Resize(int64 num_elements, complex64 value); @@ -990,6 +1006,14 @@ inline half Literal::Get( return GetArraySlice()[linear_index]; } +template <> +inline bfloat16 Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(shape().element_type() == BF16); + int64 linear_index = LinearIndex(multi_index); + return GetArraySlice()[linear_index]; +} + template void Literal::Set(tensorflow::gtl::ArraySlice multi_index, NativeT value) { diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 6d596da4ada82ea67c098eeb629d1e19b77dd4c4..7ff64c4134155e7fe22ab99584970a7d6d6e8803 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ 
b/tensorflow/compiler/xla/literal_util_test.cc @@ -110,6 +110,18 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto c64_lit = Literal::CreateR0({3.14f, 2.78f}); ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); + + auto bf16_lit = Literal::CreateR0(static_cast(0.5f)); + ASSERT_EQ("0.5", bf16_lit->ToString()); + + // 3.14 will be truncated to 3.125 in bfloat16 format. + auto bf16_lit_truncated = + Literal::CreateR0(static_cast(3.14f)); + ASSERT_EQ("3.125", bf16_lit_truncated->ToString()); + + auto bf16_lit_truncated2 = + Literal::CreateR0(static_cast(9.001f)); + ASSERT_EQ("9", bf16_lit_truncated2->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -397,6 +409,18 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(Literal::CreateR2({{h8}, {h9}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2({{h9}, {h8}})->IsAll(8)); + bfloat16 b8(8.0f); + bfloat16 b9(9.0f); + + EXPECT_TRUE(Literal::CreateR2({{b8}, {b8}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{b8}, {b9}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{b9}, {b8}})->IsAll(8)); + + // 9.001 will be truncated to 9.0 + bfloat16 b91(9.001f); + bfloat16 b90(9.00f); + EXPECT_TRUE(Literal::CreateR2({{b91}, {b90}})->IsAll(9.0)); + complex64 c8_9 = {8, 9}; EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); @@ -491,7 +515,7 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { TEST_F(LiteralUtilTest, ReshapeR0) { auto original = Literal::CreateR0(1.7f); - auto reshape = original->Reshape(/*shape=*/{}).ConsumeValueOrDie(); + auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); EXPECT_EQ(*original, *reshape); } @@ -691,6 +715,30 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { EXPECT_EQ(output, *expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { + Literal output; + bfloat16 h(0.25f); + output.PopulateWithValue(h, {}); + auto expected = Literal::CreateR0(h); + EXPECT_EQ(output, *expected); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { + Literal output; + 
bfloat16 h(0.5f); + output.PopulateWithValue(h, {3}); + auto expected = Literal::CreateR1({h, h, h}); + EXPECT_EQ(output, *expected); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { + Literal output; + bfloat16 h(2.0f); + output.PopulateWithValue(h, {2, 2}); + auto expected = Literal::CreateR2({{h, h}, {h, h}}); + EXPECT_EQ(output, *expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output; output.PopulateWithValue(2.5f, {}); @@ -975,6 +1023,14 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{half(26.0), half(0.0), half(28.0), half(0.0)}, {half(0.0), half(31.0), half(0.0), half(33.0)}}, }}, layout_r4_dim0major_); + auto bf16 = Literal::CreateR4WithLayout({{ + {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}}, + {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)}, + {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}}, + {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}}, + }}, layout_r4_dim0major_); auto f32 = Literal::CreateR4WithLayout({{ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, @@ -1008,6 +1064,12 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { conv = s8->Convert(PRED).ConsumeValueOrDie(); EXPECT_EQ(*conv, *pred); + conv = bf16->Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *s32); + + conv = bf16->Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *f32); + conv = pred->Convert(S32).ConsumeValueOrDie(); EXPECT_EQ(*conv, *int32_pred); diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 2113b5e06f3eb0169be50c0ee731a903c0eece9d..2bce56b7bd2f91f20ea670d0e7ccaa432c2b5f9f 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -78,6 +78,11 @@ PrimitiveType 
NativeToPrimitiveType() { return F64; } +template <> +PrimitiveType NativeToPrimitiveType() { + return BF16; +} + template <> PrimitiveType NativeToPrimitiveType() { return F16; @@ -89,7 +94,7 @@ PrimitiveType NativeToPrimitiveType() { } bool IsFloatingPointType(PrimitiveType type) { - return type == F16 || type == F32 || type == F64; + return type == F16 || type == F32 || type == F64 || type == BF16; } bool IsComplexType(PrimitiveType type) { return type == C64; } @@ -118,6 +123,7 @@ int BitWidth(PrimitiveType type) { case S16: case U16: case F16: + case BF16: return 16; case U32: diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index a49c8b86fcfe156ea3733ce05c0fb7337cf60dce..cb4583d198b454be1432134a9f6a77dbbbe5bdd8 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -26,6 +26,13 @@ limitations under the License. namespace xla { namespace primitive_util { +// The number of exponent bits in a BF16 value. +const int kBFloat16ExponentBits = 8; + +// The number of mantissa bits in a BF16 value. There is an implicit leading +// 1, so there is an implicit additional bit of precision. +const int kBFloat16MantissaBits = 7; + // Returns the XLA primitive type (eg, F32) corresponding to the given // template parameter native type (eg, float). 
template @@ -77,6 +84,8 @@ template <> PrimitiveType NativeToPrimitiveType(); template <> PrimitiveType NativeToPrimitiveType(); +template <> +PrimitiveType NativeToPrimitiveType(); // Complex template <> @@ -167,6 +176,11 @@ struct PrimitiveTypeToNative { using type = half; }; +template <> +struct PrimitiveTypeToNative { + using type = bfloat16; +}; + // Complex template <> struct PrimitiveTypeToNative { diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/compiler/xla/ptr_util.h index fa670303136ebff0c3e0e32f5c64e879c46fe964..c58c19db2cacbe9b038160f27b9bd76aa58146eb 100644 --- a/tensorflow/compiler/xla/ptr_util.h +++ b/tensorflow/compiler/xla/ptr_util.h @@ -16,7 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ -// Utility functions for pointers. +// As this was moved to tensorflow/core/util, provide indirections here to +// maintain current functionality of the library. #include @@ -24,55 +25,27 @@ limitations under the License. #include #include -namespace xla { - -namespace internal { - -// Trait to select overloads and return types for MakeUnique. -template -struct MakeUniqueResult { - using scalar = std::unique_ptr; -}; -template -struct MakeUniqueResult { - using array = std::unique_ptr; -}; -template -struct MakeUniqueResult { - using invalid = void; -}; +#include "tensorflow/core/util/ptr_util.h" -} // namespace internal +namespace xla { -// Transfers ownership of a raw pointer to a std::unique_ptr of deduced type. -// Example: -// X* NewX(int, int); -// auto x = WrapUnique(NewX(1, 2)); // 'x' is std::unique_ptr. -// -// WrapUnique is useful for capturing the output of a raw pointer factory. -// However, prefer 'MakeUnique(args...) over 'WrapUnique(new T(args...))'. -// auto x = WrapUnique(new X(1, 2)); // works, but nonideal. -// auto x = MakeUnique(1, 2); // safer, standard, avoids raw 'new'. -// -// Note: Cannot wrap pointers to array of unknown bound (i.e. U(*)[]). 
template std::unique_ptr WrapUnique(T* ptr) { - static_assert(!std::is_array::value || std::extent::value != 0, - "types T[0] or T[] are unsupported"); - return std::unique_ptr(ptr); + return tensorflow::WrapUnique(ptr); } template -typename internal::MakeUniqueResult::scalar MakeUnique(Args&&... args) { - return std::unique_ptr(new T(std::forward(args)...)); +typename tensorflow::helper::MakeUniqueResult::scalar MakeUnique( + Args&&... args) { + return tensorflow::MakeUnique(std::forward(args)...); } // Overload for array of unknown bound. // The allocation of arrays needs to use the array form of new, // and cannot take element constructor arguments. template -typename internal::MakeUniqueResult::array MakeUnique(size_t n) { - return std::unique_ptr(new typename std::remove_extent::type[n]()); +typename tensorflow::helper::MakeUniqueResult::array MakeUnique(size_t n) { + return tensorflow::MakeUnique(n); } } // namespace xla diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 90aa9720a1e18bad06842adeead46fc3120d01dd..bdf92eaed1ff1d83cf03eec4d126677ea42c577f 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -102,7 +102,9 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( const Array3D& lhs, const Array3D& rhs, int64 kernel_stride, Padding padding, int64 lhs_dilation, int64 rhs_dilation, const ConvolutionDimensionNumbers& dnums) { - CHECK_EQ(dnums.spatial_dimensions_size(), 1); + CHECK_EQ(dnums.input_spatial_dimensions_size(), 1); + CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1); + CHECK_EQ(dnums.output_spatial_dimensions_size(), 1); // Reuse the code for Array4D-convolution by extending the 3D input into a 4D // array by adding a fourth dummy dimension of size 1 without stride, padding // and dilation. @@ -120,8 +122,9 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( }); // Add a second dummy spatial dimensions. 
ConvolutionDimensionNumbers dnums2d = dnums; - dnums2d.add_spatial_dimensions(3); + dnums2d.add_input_spatial_dimensions(3); dnums2d.add_kernel_spatial_dimensions(3); + dnums2d.add_output_spatial_dimensions(3); std::unique_ptr> convr4 = ConvArray4DGeneralDimensionsDilated( a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1}, {rhs_dilation, 1}, dnums2d); @@ -192,14 +195,26 @@ ReferenceUtil::ReduceWindow1DGeneric( const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding) { std::vector dim_lengths{static_cast(operand.size())}; - auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + return ReduceWindow1DGeneric( + operand, init, reduce_func, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding) { + std::vector dim_lengths{static_cast(operand.size())}; std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; window_counts[i] = - WindowCount(dim_lengths[i], window[i], stride[i], padding); - pad_low[i] = padding_both[i].first; + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; } auto result = MakeUnique>(window_counts[0]); @@ -465,9 +480,9 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( } ordered_input_dimensions[0] = - lhs_literal->shape().dimensions(dnums.spatial_dimensions(0)); + lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0)); ordered_input_dimensions[1] = - lhs_literal->shape().dimensions(dnums.spatial_dimensions(1)); + 
lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1)); ordered_kernel_dimensions[0] = rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)); ordered_kernel_dimensions[1] = @@ -703,137 +718,4 @@ ReferenceUtil::ReduceToRowArray2D( return result; } -/* static */ std::unique_ptr> ReferenceUtil::PadArray2D( - const Array2D& operand, const PaddingConfig& padding, - const float pad) { - int64 in0 = operand.n1(); - int64 high_padding0 = padding.dimensions(0).edge_padding_high(); - int64 low_padding0 = padding.dimensions(0).edge_padding_low(); - int64 interior_padding0 = padding.dimensions(0).interior_padding(); - int64 out0 = - in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0; - - int64 in1 = operand.n2(); - int64 high_padding1 = padding.dimensions(1).edge_padding_high(); - int64 low_padding1 = padding.dimensions(1).edge_padding_low(); - int64 interior_padding1 = padding.dimensions(1).interior_padding(); - int64 out1 = - in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; - - auto result = MakeUnique>(out0, out1); - result->Fill(pad); - int64 o0 = low_padding0; - for (int64 i0 = 0; i0 < in0; ++i0) { - int64 o1 = low_padding1; - for (int64 i1 = 0; i1 < in1; ++i1) { - if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) { - (*result)(o0, o1) = operand(i0, i1); - } - o1 += interior_padding1 + 1; - } - o0 += interior_padding0 + 1; - } - return result; -} - -/* static */ Array3D ReferenceUtil::PadArray3D( - const Array3D& operand, const PaddingConfig& padding, - const float pad) { - CHECK_EQ(padding.dimensions_size(), 3); - - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3()}; - std::vector pad_low(3); - std::vector pad_high(3); - std::vector pad_interior(3); - std::vector output_bounds(3); - for (int64 i = 0; i < 3; ++i) { - pad_low[i] = padding.dimensions(i).edge_padding_low(); - pad_high[i] = padding.dimensions(i).edge_padding_high(); - CHECK_LE(0, pad_low[i]); - CHECK_LE(0, pad_high[i]); - 
CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented"; - pad_interior[i] = padding.dimensions(i).interior_padding(); - - output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + - (input_bounds[i] - 1) * pad_interior[i]; - } - - Array3D result(output_bounds[0], output_bounds[1], output_bounds[2]); - std::vector indices = {0, 0, 0}; - for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { - for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { - for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { - float* value = &result(indices[0], indices[1], indices[2]); - bool value_padded = false; - for (int i = 0; i < 3; ++i) { - bool in_low_padding = indices[i] < pad_low[i]; - bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; - if (in_low_padding || in_high_padding) { - *value = pad; - value_padded = true; - } - if (pad_interior[i] && - (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { - *value = pad; - value_padded = true; - } - } - if (value_padded) { - continue; - } - *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), - (indices[1] - pad_low[1]) / (pad_interior[1] + 1), - (indices[2] - pad_low[2]) / (pad_interior[2] + 1)); - } - } - } - return result; -} - -/* static */ Array4D ReferenceUtil::PadArray4D( - const Array4D& operand, const PaddingConfig& padding, - const float pad) { - CHECK_EQ(padding.dimensions_size(), 4); - - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3(), operand.n4()}; - std::vector pad_low(4); - std::vector pad_high(4); - std::vector pad_interior(4); - std::vector output_bounds(4); - for (int64 i = 0; i < 4; ++i) { - pad_low[i] = padding.dimensions(i).edge_padding_low(); - pad_high[i] = padding.dimensions(i).edge_padding_high(); - CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented"; - pad_interior[i] = padding.dimensions(i).interior_padding(); - - output_bounds[i] = pad_low[i] + input_bounds[i] 
+ pad_high[i] + - (input_bounds[i] - 1) * pad_interior[i]; - } - - Array4D result(output_bounds[0], output_bounds[1], output_bounds[2], - output_bounds[3]); - result.Each([&](tensorflow::gtl::ArraySlice indices, float* value) { - for (int i = 0; i < 4; ++i) { - bool in_low_padding = indices[i] < pad_low[i]; - bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; - if (in_low_padding || in_high_padding) { - *value = pad; - return; - } - if (pad_interior[i] && - (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { - *value = pad; - return; - } - } - *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), - (indices[1] - pad_low[1]) / (pad_interior[1] + 1), - (indices[2] - pad_low[2]) / (pad_interior[2] + 1), - (indices[3] - pad_low[3]) / (pad_interior[3] + 1)); - }); - return result; -} - } // namespace xla diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 2da17307817858eea60e868f4be1ab8138784385..58e1a844610678f64677838e93f0379b63f65d39 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -70,7 +70,7 @@ class ReferenceUtil { // dilation factors. 
static std::unique_ptr> ConvArray4DGeneralDimensionsDilated( const Array4D& lhs, const Array4D& rhs, - std::pair stride, Padding padding, + std::pair kernel_stride, Padding padding, std::pair lhs_dilation, std::pair rhs_dilation, ConvolutionDimensionNumbers dnums); @@ -184,6 +184,12 @@ class ReferenceUtil { const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); + static std::unique_ptr> ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, @@ -486,19 +492,147 @@ class ReferenceUtil { } // Returns the result of a 2D pad on an input matrix. - static std::unique_ptr> PadArray2D( - const Array2D& operand, const PaddingConfig& padding, - const float pad); + template + static std::unique_ptr> PadArray2D( + const Array2D& operand, const PaddingConfig& padding, + const NativeT pad) { + int64 in0 = operand.n1(); + int64 high_padding0 = padding.dimensions(0).edge_padding_high(); + int64 low_padding0 = padding.dimensions(0).edge_padding_low(); + int64 interior_padding0 = padding.dimensions(0).interior_padding(); + int64 out0 = + in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0; + + int64 in1 = operand.n2(); + int64 high_padding1 = padding.dimensions(1).edge_padding_high(); + int64 low_padding1 = padding.dimensions(1).edge_padding_low(); + int64 interior_padding1 = padding.dimensions(1).interior_padding(); + int64 out1 = + in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; + + auto result = MakeUnique>(out0, out1); + result->Fill(pad); + int64 o0 = low_padding0; + for (int64 i0 = 0; i0 < in0; ++i0) { + int64 o1 = low_padding1; + for (int64 i1 = 0; 
i1 < in1; ++i1) { + if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) { + (*result)(o0, o1) = operand(i0, i1); + } + o1 += interior_padding1 + 1; + } + o0 += interior_padding0 + 1; + } + return result; + } // Returns the result of a 3D pad on an input matrix. - static Array3D PadArray3D(const Array3D& operand, - const PaddingConfig& padding, - const float pad); + template + static Array3D PadArray3D(const Array3D& operand, + const PaddingConfig& padding, + const NativeT pad) { + CHECK_EQ(padding.dimensions_size(), 3); + + const std::vector input_bounds = {operand.n1(), operand.n2(), + operand.n3()}; + std::vector pad_low(3); + std::vector pad_high(3); + std::vector pad_interior(3); + std::vector output_bounds(3); + for (int64 i = 0; i < 3; ++i) { + pad_low[i] = padding.dimensions(i).edge_padding_low(); + pad_high[i] = padding.dimensions(i).edge_padding_high(); + CHECK_LE(0, pad_low[i]); + CHECK_LE(0, pad_high[i]); + CHECK_LE(0, padding.dimensions(i).interior_padding()) + << "not implemented"; + pad_interior[i] = padding.dimensions(i).interior_padding(); + + output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + + (input_bounds[i] - 1) * pad_interior[i]; + } + + Array3D result(output_bounds[0], output_bounds[1], + output_bounds[2]); + std::vector indices = {0, 0, 0}; + for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { + for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { + for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { + NativeT* value = &result(indices[0], indices[1], indices[2]); + bool value_padded = false; + for (int i = 0; i < 3; ++i) { + bool in_low_padding = indices[i] < pad_low[i]; + bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; + if (in_low_padding || in_high_padding) { + *value = pad; + value_padded = true; + } + if (pad_interior[i] && + (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { + *value = pad; + value_padded = true; + } + } + if (value_padded) { + continue; + 
} + *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), + (indices[1] - pad_low[1]) / (pad_interior[1] + 1), + (indices[2] - pad_low[2]) / (pad_interior[2] + 1)); + } + } + } + return result; + } // Returns the result of a 4D pad on an input array. - static Array4D PadArray4D(const Array4D& operand, - const PaddingConfig& padding, - const float pad); + template + static Array4D PadArray4D(const Array4D& operand, + const PaddingConfig& padding, + const NativeT pad) { + CHECK_EQ(padding.dimensions_size(), 4); + + const std::vector input_bounds = {operand.n1(), operand.n2(), + operand.n3(), operand.n4()}; + std::vector pad_low(4); + std::vector pad_high(4); + std::vector pad_interior(4); + std::vector output_bounds(4); + for (int64 i = 0; i < 4; ++i) { + pad_low[i] = padding.dimensions(i).edge_padding_low(); + pad_high[i] = padding.dimensions(i).edge_padding_high(); + CHECK_LE(0, padding.dimensions(i).interior_padding()) + << "not implemented"; + pad_interior[i] = padding.dimensions(i).interior_padding(); + + output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + + (input_bounds[i] - 1) * pad_interior[i]; + } + + Array4D result(output_bounds[0], output_bounds[1], + output_bounds[2], output_bounds[3]); + result.Each( + [&](tensorflow::gtl::ArraySlice indices, NativeT* value) { + for (int i = 0; i < 4; ++i) { + bool in_low_padding = indices[i] < pad_low[i]; + bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; + if (in_low_padding || in_high_padding) { + *value = pad; + return; + } + if (pad_interior[i] && + (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { + *value = pad; + return; + } + } + *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), + (indices[1] - pad_low[1]) / (pad_interior[1] + 1), + (indices[2] - pad_low[2]) / (pad_interior[2] + 1), + (indices[3] - pad_low[3]) / (pad_interior[3] + 1)); + }); + return result; + } // ApplyElementwise2D(f, x, y, ...) 
returns the Array2D formed by running // f(x[i], y[i], ...) for each array element in the Array2Ds x, y, .... diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index eb6a71242ffa1499876b90f14f8a60ffdbdd069c..846ccdc83df900e3afedb6ababe07ebb1bd68f41 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -60,7 +60,9 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) { TEST_F(ReferenceUtilTest, MatmulArray2D) { Array2D rhs({ - {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f}, + {7.f, 8.f}, + {9.f, 10.f}, + {11.f, 12.f}, }); auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); auto actual_literal = Literal::CreateR2FromArray2D(*result); @@ -326,8 +328,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { dimension_numbers.set_input_feature_dimension(0); dimension_numbers.set_output_batch_dimension(2); dimension_numbers.set_output_feature_dimension(0); - dimension_numbers.add_spatial_dimensions(1); - dimension_numbers.add_spatial_dimensions(3); + dimension_numbers.add_input_spatial_dimensions(1); + dimension_numbers.add_output_spatial_dimensions(1); + dimension_numbers.add_input_spatial_dimensions(3); + dimension_numbers.add_output_spatial_dimensions(3); dimension_numbers.set_kernel_output_feature_dimension(0); dimension_numbers.set_kernel_input_feature_dimension(2); dimension_numbers.add_kernel_spatial_dimensions(1); @@ -380,8 +384,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { dimension_numbers.set_input_feature_dimension(0); dimension_numbers.set_output_batch_dimension(2); dimension_numbers.set_output_feature_dimension(0); - dimension_numbers.add_spatial_dimensions(1); - dimension_numbers.add_spatial_dimensions(3); + dimension_numbers.add_input_spatial_dimensions(1); + dimension_numbers.add_output_spatial_dimensions(1); + dimension_numbers.add_input_spatial_dimensions(3); + 
dimension_numbers.add_output_spatial_dimensions(3); dimension_numbers.set_kernel_output_feature_dimension(0); dimension_numbers.set_kernel_input_feature_dimension(2); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 521fe411a4beed8b075568a41bce116bb528624f..c7432aacd18215d8c561b636a8ccc0da8118398c 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -566,7 +566,6 @@ cc_library( hdrs = ["shaped_buffer.h"], deps = [ ":device_memory_allocator", - ":transfer_manager", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -630,6 +629,7 @@ cc_library( cc_library( name = "llvm_compiler", + srcs = ["llvm_compiler.cc"], hdrs = ["llvm_compiler.h"], deps = [ ":compiler", @@ -642,6 +642,7 @@ cc_library( srcs = ["transfer_manager.cc"], hdrs = ["transfer_manager.h"], deps = [ + ":shaped_buffer", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1142,6 +1143,22 @@ tf_cc_test( ], ) +cc_library( + name = "dot_decomposer", + srcs = ["dot_decomposer.cc"], + hdrs = ["dot_decomposer.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + cc_library( name = "tuple_simplifier", srcs = ["tuple_simplifier.cc"], @@ -1291,24 +1308,6 @@ cc_library( alwayslink = True, # Contains per-platform transfer manager registration ) -tf_cc_test( - name = "transfer_manager_test", - srcs = ["transfer_manager_test.cc"], - deps = [ - ":generic_transfer_manager", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - 
"//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service/cpu:cpu_transfer_manager", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", - ], -) - cc_library( name = "hlo_cost_analysis", srcs = ["hlo_cost_analysis.cc"], @@ -1321,6 +1320,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], ) @@ -1358,6 +1358,7 @@ cc_library( deps = [ ":hlo", ":hlo_cost_analysis", + ":hlo_profile_printer", ":human_readable_profile_builder", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -1366,6 +1367,18 @@ cc_library( ], ) +tf_cc_test( + name = "hlo_execution_profile_test", + srcs = ["hlo_execution_profile_test.cc"], + deps = [ + ":cpu_plugin", + ":hlo_cost_analysis", + ":hlo_execution_profile", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + tf_cc_test( name = "hlo_computation_test", srcs = ["hlo_computation_test.cc"], @@ -1642,10 +1655,14 @@ cc_library( deps = [ ":buffer_liveness", ":hlo", + ":hlo_alias_analysis", + ":hlo_dce", + ":hlo_graph_dumper", + ":hlo_ordering", ":hlo_pass", ":liveness_util", ":logical_buffer", - ":tuple_points_to_analysis", + ":tuple_simplifier", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1660,15 +1677,17 @@ tf_cc_test( deps = [ ":copy_insertion", ":hlo", + ":hlo_graph_dumper", ":hlo_matchers", - ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", 
"//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -1778,7 +1797,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], ) @@ -1849,7 +1867,6 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], ) @@ -1888,6 +1905,22 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_element_type_converter", + srcs = ["hlo_element_type_converter.cc"], + hdrs = ["hlo_element_type_converter.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + cc_library( name = "device_memory_allocator", srcs = ["device_memory_allocator.cc"], @@ -1985,6 +2018,7 @@ cc_library( ":hlo", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -2158,6 +2192,16 @@ cc_library( ], ) +cc_library( + name = "hlo_profile_printer", + srcs = ["hlo_profile_printer.cc"], + hdrs = ["hlo_profile_printer.h"], + deps = [ + ":human_readable_profile_builder", + "//tensorflow/compiler/xla:types", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 35fe0d1a5192b93c0be47ecc1b1bdb753da792af..d7bf4f37af9f9ed872cfb39995109b8104042a33 100644 --- 
a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -46,9 +46,6 @@ limitations under the License. namespace xla { namespace { -using tensorflow::gtl::nullopt; -using tensorflow::gtl::optional; - // Returns whether operand is a literal with the given value. bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { return operand->opcode() == HloOpcode::kConstant && @@ -135,7 +132,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleConvert(HloInstruction* convert) override; + Status HandleComplex(HloInstruction* complex) override; + Status HandleReal(HloInstruction* real) override; + Status HandleImag(HloInstruction* imag) override; Status HandleConvolution(HloInstruction* convolution) override; @@ -180,19 +180,46 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { static bool Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification, bool enable_conv_simplification); + bool enable_dot_strength_reduction, bool enable_conv_simplification); private: explicit AlgebraicSimplifierVisitor( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification, bool enable_conv_simplification) + bool enable_dot_strength_reduction, bool enable_conv_simplification) : computation_(computation), is_layout_sensitive_(is_layout_sensitive), valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_simplification_(enable_dot_simplification), + enable_dot_strength_reduction_(enable_dot_strength_reduction), enable_conv_simplification_(enable_conv_simplification) {} + // Transforms Dots where at least one input is a vector or has a degenerate + // dimension and converts it into a multiply and reduce. 
This should enable + // more fusion than leaving the nodes as Dot operations. + StatusOr HandleDotStrengthReduction(HloInstruction* dot); + + // Reshapes an instruction to rank 1 if it is not already rank 1. + HloInstruction* Flatten(HloInstruction* hlo) { + if (ShapeUtil::Rank(hlo->shape()) == 1) { + return hlo; + } + return computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(hlo->shape().element_type(), + {ShapeUtil::ElementsIn(hlo->shape())}), + hlo)); + } + + // Helper method to perform and add reduction in a single dimension. + HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + HloInstruction* zero = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloComputation* AddReduce_computation = CreateScalarBinaryComputation( + computation_->parent(), F32, HloOpcode::kAdd); + Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); + return computation_->AddInstruction(HloInstruction::CreateReduce( + shape, hlo, zero, {dim}, AddReduce_computation)); + } + // Convenience method for replacing an instruction with a bitcast. void ReplaceWithBitcast(HloInstruction* instruction); @@ -252,6 +279,11 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + StatusOr OptimizeDotOfConcat(HloInstruction* dot); + StatusOr OptimizeDotOfConcatHelper( + const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -265,8 +297,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Callback used to determine if a bitcast is possible. AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_; - // Disable dot simplication on platforms where it causes a slowdown. 
- bool enable_dot_simplification_; + // Disable dot strength reduction on platforms where it causes a slowdown. + bool enable_dot_strength_reduction_; // Disable convolution simplication on platforms where it causes a slowdown. bool enable_conv_simplification_; @@ -275,10 +307,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { bool AlgebraicSimplifierVisitor::Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification, bool enable_conv_simplification) { + bool enable_dot_strength_reduction, bool enable_conv_simplification) { AlgebraicSimplifierVisitor visitor( computation, is_layout_sensitive, std::move(valid_bitcast_callback), - enable_dot_simplification, enable_conv_simplification); + enable_dot_strength_reduction, enable_conv_simplification); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -574,68 +606,72 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { - auto lhs = dot->mutable_operand(0); - auto rhs = dot->mutable_operand(1); - if (!enable_dot_simplification_) { - return Status::OK(); - } - // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or - // below. - if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || - ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { - return Status::OK(); - } - - // Replace a zero element dot with a broadcast of the constant 0. 
- if (ShapeUtil::HasZeroElements(dot->shape()) || - ShapeUtil::HasZeroElements(lhs->shape()) || - ShapeUtil::HasZeroElements(rhs->shape())) { - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); - } - - // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). - if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { - auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot, - rhs->mutable_operand(0), lhs->mutable_operand(0))); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); - } +StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( + HloInstruction* dot) { + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + int64 lhs_collapsing_dim = + dot->dot_dimension_numbers().lhs_contracting_dimensions(0); + if (lhs->IsRank2Transpose()) { + lhs = lhs->mutable_operand(0); + lhs_collapsing_dim = 1 - lhs_collapsing_dim; + } + const int64 lhs_kept_dim = 1 - lhs_collapsing_dim; + + int64 rhs_collapsing_dim = + dot->dot_dimension_numbers().rhs_contracting_dimensions(0); + if (rhs->IsRank2Transpose()) { + rhs = rhs->mutable_operand(0); + rhs_collapsing_dim = 1 - rhs_collapsing_dim; + } + const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; + + auto reshape_if_necessary = [&](HloInstruction* hlo) { + if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { + return hlo; + } + return computation_->AddInstruction( + HloInstruction::CreateReshape(dot->shape(), hlo)); + }; - // Simplify outer product into multiply with implicit broadcasting. 
- // - // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) - if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) { - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, - lhs, rhs)); - } + auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, + int64 dim) { + return computation_->AddInstruction( + HloInstruction::CreateBroadcast(shape, hlo, {dim})); + }; - // The following graph transformations take Dots where at least one input is a - // vector or has a degenerate dimension and converts it into a multiply and - // reduce. This should enable more fusion than leaving the nodes as Dot - // operations. + auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) { + return computation_->AddInstruction(HloInstruction::CreateBinary( + local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs)); + }; // Strength reduce dot(a[K] , b[K]) = // reshape(result.shape, // reduce_sum(multiply(a, b), {0})) if (ShapeUtil::Rank(rhs->shape()) == 1 && ShapeUtil::Rank(lhs->shape()) == 1) { - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, lhs, rhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, - {0}, add_reduce_computation)); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + TF_RETURN_IF_ERROR( + ReplaceInstruction(dot, reshape_if_necessary(AddReduce( + multiply(Flatten(lhs), Flatten(rhs)), 0)))); + return true; + } + + if (ShapeUtil::IsEffectiveScalar(rhs->shape()) && + ShapeUtil::IsEffectiveScalar(lhs->shape())) { + 
TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs))))); + return true; + } + + // Simplify outer product into multiply with implicit broadcasting. + // + // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) + if (ShapeUtil::Rank(rhs->shape()) == 2 && + rhs->shape().dimensions(rhs_collapsing_dim) == 1) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0), + broadcast_to_dim(Flatten(rhs), dot->shape(), 1)))); + return true; } // Strength reduce dot(a[1, K], b) = @@ -646,35 +682,21 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // ) // ) if (ShapeUtil::Rank(lhs->shape()) == 1 || - (ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(0) == 1)) { - auto new_lhs = computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(lhs->shape().element_type(), - {ShapeUtil::ElementsIn(lhs->shape())}), - lhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* reduce; + (ShapeUtil::Rank(lhs->shape()) == 2 && + lhs->shape().dimensions(lhs_kept_dim) == 1)) { if (ShapeUtil::Rank(rhs->shape()) == 1) { - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); - reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, - {0}, add_reduce_computation)); - } else { - new_lhs = computation_->AddInstruction( - HloInstruction::CreateBroadcast(rhs->shape(), new_lhs, {0})); - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); - - reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - 
ShapeUtil::MakeShape(dot->shape().element_type(), - {rhs->shape().dimensions(1)}), - multiply, zero, {0}, add_reduce_computation)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, + reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0)))); + return true; } - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary( + AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(), + rhs_collapsing_dim), + rhs), + rhs_collapsing_dim)))); + return true; } // Strength reduce dot(a, b[K, 1]) = @@ -682,26 +704,208 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) // ) if (ShapeUtil::Rank(rhs->shape()) == 1 || - (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(1) == 1)) { - auto new_rhs = computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(rhs->shape().element_type(), - {ShapeUtil::ElementsIn(rhs->shape())}), - rhs)); - new_rhs = computation_->AddInstruction( - HloInstruction::CreateBroadcast(lhs->shape(), new_rhs, {1})); - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - lhs->shape(), HloOpcode::kMultiply, lhs, new_rhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); + (ShapeUtil::Rank(rhs->shape()) == 2 && + rhs->shape().dimensions(rhs_kept_dim) == 1)) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(AddReduce( + multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), + lhs_collapsing_dim)), + lhs_collapsing_dim)))); + return true; + } + return false; +} + +StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( + HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_contracting_dimensions_size() != 1 || + dnums.lhs_batch_dimensions_size() != 0) { 
+ return nullptr; + } + + const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0); + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + + TF_ASSIGN_OR_RETURN( + HloInstruction * optimized_lhs_concat, + OptimizeDotOfConcatHelper(dot->shape(), lhs, lhs_contracting_dim, rhs, + rhs_contracting_dim, /*swapped=*/false)); + if (optimized_lhs_concat) { + return optimized_lhs_concat; + } + + return OptimizeDotOfConcatHelper(dot->shape(), rhs, rhs_contracting_dim, lhs, + lhs_contracting_dim, /*swapped=*/true); +} + +StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( + const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) { + bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate && + lhs->concatenate_dimension() == lhs_contracting_dim && + rhs->opcode() == HloOpcode::kConstant; + if (!can_optimize) { + return nullptr; + } + + // We're replacing this: + // + // +-----+-----+-----+ +-------------------+ + // | | | | | | + // | | | | | R_0 | + // | | | | | | + // | | | | +-------------------+ + // | | | | | | + // | L_0 | L_1 | L_2 | * | R_1 | + // | | | | | | + // | | | | +-------------------+ + // | | | | | | + // | | | | | R_2 | + // | | | | | | + // +-----+-----+-----+ +-------------------+ + // + // with this: + // + // [Sum over i] + // + // +-----+ +-------------------+ + // | | | | + // | | * | R_i | + // | | | | + // | | +-------------------+ + // | | + // | L_i | + // | | + // | | + // | | + // | | + // | | + // +-----+ + // + // where the LHS is a concatenate operation (so we can "split" the LHS tensor + // for free) and the RHS is a constant tensor (and thus can be split at + // compile time). 
In the future, we may also want to do this when both the + // LHS and the RHS are concatenate operations that line up along the dimension + // being contracted over. + // + // We should be able to generalize this transform to work on a non-constant + // RHS when/if we have in-place slices or support input-fusing slices into + // Dots. + + // Dimension numbers for the new dot instructions we'll create (L_i * R_i in + // the diagram above). + DotDimensionNumbers new_dot_dnums; + new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim + : lhs_contracting_dim); + new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim + : rhs_contracting_dim); + + // Here we use the MKN notation, where the contracted dimension has K + // elements and the two non-contracted dimensions have M and N elements. + HloInstruction* add_result = nullptr; + int64 rhs_contracting_dim_offset = 0; + int64 n = rhs->shape().dimensions(1 - rhs_contracting_dim); + for (HloInstruction* concat_op : lhs->operands()) { + int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim); + Shape rhs_slice_shape(rhs->shape()); + rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k); + + std::array start_indices; + start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset; + start_indices[1 - rhs_contracting_dim] = 0; + + std::array limit_indices; + limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k; + limit_indices[1 - rhs_contracting_dim] = n; + + HloInstruction* rhs_slice = + computation_->AddInstruction(HloInstruction::CreateSlice( + rhs_slice_shape, rhs, /*start_indices=*/start_indices, + /*limit_indices=*/limit_indices, /*strides=*/{1, 1})); + + // TODO(b/69062148): We can get rid of `swapped` once all backends support + // "non-canonical" contraction dimensions (that contracts dimension 1 of the + // LHS with dimension 0 of the RHS). 
But for now we keep the same + // contraction dimensions as the incoming dot operation to ensure the new + // dot operations can be lowered. + HloInstruction *new_dot_lhs, *new_dot_rhs; + if (swapped) { + new_dot_lhs = rhs_slice; + new_dot_rhs = concat_op; + } else { + new_dot_lhs = concat_op; + new_dot_rhs = rhs_slice; + } + + auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( + dot_shape, new_dot_lhs, new_dot_rhs, new_dot_dnums)); + + if (add_result) { + add_result = computation_->AddInstruction(HloInstruction::CreateBinary( + dot_shape, HloOpcode::kAdd, add_result, new_dot)); + } else { + add_result = new_dot; + } + + rhs_contracting_dim_offset += sub_k; + } + + return add_result; +} + +Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { + auto lhs = dot->mutable_operand(0); + auto rhs = dot->mutable_operand(1); + + // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or + // below. + if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || + ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { + return Status::OK(); + } + + // Replace a zero element dot with a broadcast of the constant 0. 
+ if (ShapeUtil::HasZeroElements(dot->shape()) || + ShapeUtil::HasZeroElements(lhs->shape()) || + ShapeUtil::HasZeroElements(rhs->shape())) { auto zero = computation_->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), - {lhs->shape().dimensions(0)}), - multiply, zero, {1}, add_reduce_computation)); return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } + + TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, + OptimizeDotOfConcat(dot)); + if (dot_of_concat_optimized) { + VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., " + "constant)...)"; + return ReplaceInstruction(dot, dot_of_concat_optimized); + } + + if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { + TF_ASSIGN_OR_RETURN(bool did_strength_reduction, + HandleDotStrengthReduction(dot)); + if (did_strength_reduction) { + return Status::OK(); + } + } + + // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). 
+ if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( + ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), + rhs->mutable_operand(0), lhs->mutable_operand(0), + dot_dimension_numbers)); + return ReplaceWithNewInstruction( + dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); + } + return Status::OK(); } @@ -947,6 +1151,18 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { return Status::OK(); } +// Complex(Real(c), Imag(c)) -> c +Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) { + auto real = complex->mutable_operand(0); + auto imag = complex->mutable_operand(1); + if (real->opcode() == HloOpcode::kReal && + imag->opcode() == HloOpcode::kImag && + real->operand(0) == imag->operand(0)) { + return ReplaceInstruction(complex, real->mutable_operand(0)); + } + return Status::OK(); +} + // Real(Complex(r, i)) -> r Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) { auto operand = real->mutable_operand(0); @@ -1096,9 +1312,15 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( Literal::One(rhs->shape().element_type()).CloneToUnique())); + + // Explicitly broadcast scalar 1 to the output shape, to avoid implicit + // broadcast in divide HLO as we are trying to eliminate implicit + // broadcasting at HLO level. 
+ auto* broadcast_one = computation_->AddInstruction( + HloInstruction::CreateBroadcast(power->shape(), one, {})); return ReplaceWithNewInstruction( power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, - one, lhs)); + broadcast_one, lhs)); } return Status::OK(); } @@ -1386,6 +1608,15 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( auto operand = reduce_window->mutable_operand(0); const Window& window = reduce_window->window(); auto function = reduce_window->to_apply(); + if (ShapeUtil::IsScalar(operand->shape())) { + TF_RET_CHECK(ShapeUtil::IsScalar(reduce_window->shape())); + return ReplaceWithNewInstruction( + reduce_window, + HloInstruction::CreateMap(reduce_window->shape(), + {operand, reduce_window->mutable_operand(1)}, + function)); + } + VLOG(10) << "Considering folding Pad: " << operand->ToString() << "\ninto reduce-window: " << reduce_window->ToString(); @@ -1587,8 +1818,11 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( auto new_lhs = add_bitcast(new_input_shape, lhs); auto new_rhs = add_bitcast(new_filter_shape, rhs); - auto dot = computation_->AddInstruction(HloInstruction::CreateBinary( - dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs)); + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + auto dot = computation_->AddInstruction(HloInstruction::CreateDot( + dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } @@ -1676,7 +1910,7 @@ StatusOr AlgebraicSimplifier::Run(HloModule* module) { for (auto* comp : module->MakeNonfusionComputations()) { if (AlgebraicSimplifierVisitor::Run( comp, is_layout_sensitive_, valid_bitcast_callback_, - enable_dot_simplification_, enable_conv_simplification_)) { + enable_dot_strength_reduction_, enable_conv_simplification_)) { changed = true; } } diff --git 
a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index a9f476178c7af74c275a10de7727ea64e17d590f..43315f5cdc7afbe79039420320f4a0d0535e11f1 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -40,11 +40,11 @@ class AlgebraicSimplifier : public HloPassInterface { // bitcasts. AlgebraicSimplifier(bool is_layout_sensitive, ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification = true, + bool enable_dot_strength_reduction = true, bool enable_conv_simplification = true) : is_layout_sensitive_(is_layout_sensitive), valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_simplification_(enable_dot_simplification), + enable_dot_strength_reduction_(enable_dot_strength_reduction), enable_conv_simplification_(enable_conv_simplification) {} ~AlgebraicSimplifier() override = default; tensorflow::StringPiece name() const override { return "algsimp"; } @@ -58,7 +58,7 @@ class AlgebraicSimplifier : public HloPassInterface { ValidBitcastCallback valid_bitcast_callback_; // Enable dot simplication on platforms where it is profitable. - bool enable_dot_simplification_; + bool enable_dot_strength_reduction_; // Enable convolution simplication on platforms where it is profitable. bool enable_conv_simplification_; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index c06e330bc12ec73ae46b84505b34c16e3591aaa5..d0b659eec306464dae944cf7758b89f89d60974c 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -371,6 +371,31 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { EXPECT_EQ(root, param0); } +// Test that complex(real(c), imag(c)) is simplified to c. 
+TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2c64, "param0")); + HloInstruction* real = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, param0)); + HloInstruction* imag = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, param0)); + HloInstruction* cplx = builder.AddInstruction( + HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, cplx); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + // Test that real(complex(r,i)) is simplified to r. 
TEST_F(AlgebraicSimplifierTest, RealOfComplex) { Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); @@ -736,8 +761,10 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Divide(op::Constant(), param0)); - EXPECT_EQ(root->operand(0)->literal().GetFirstElement(), 1); + EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); + EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast); + EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement(), + 1); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { @@ -1597,8 +1624,11 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { ConvolutionDimensionNumbers dnums; std::vector in_dims; int in_channel_idx = -1; - dnums.add_spatial_dimensions(-1); // filled in later - dnums.add_spatial_dimensions(-1); // filled in later + // filled in later + dnums.add_input_spatial_dimensions(-1); + dnums.add_output_spatial_dimensions(-1); + dnums.add_input_spatial_dimensions(-1); + dnums.add_output_spatial_dimensions(-1); for (int i = 0; i < strlen(options.dim_order); ++i) { char ch = options.dim_order[i]; if (ch == 'N') { @@ -1606,10 +1636,12 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { dnums.set_output_batch_dimension(i); in_dims.push_back(options.in_batch); } else if (ch == 'H') { - dnums.set_spatial_dimensions(0, i); + dnums.set_input_spatial_dimensions(0, i); + dnums.set_output_spatial_dimensions(0, i); in_dims.push_back(options.in_height); } else if (ch == 'W') { - dnums.set_spatial_dimensions(1, i); + dnums.set_input_spatial_dimensions(1, i); + dnums.set_output_spatial_dimensions(1, i); in_dims.push_back(options.in_width); } else if (ch == 'C') { dnums.set_input_feature_dimension(i); @@ -2106,8 +2138,10 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); HloInstruction* y = 
builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); - builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kDot, x, y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums)); std::unique_ptr dot_computation(builder.Build()); HloComputation::Builder call_builder(TestName() + ".Call"); @@ -2204,5 +2238,210 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { op::DynamicSlice(op::Parameter(), op::Parameter())); } +class DotStrengthReductionTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface< + ::testing::tuple> {}; +TEST_P(DotStrengthReductionTest, DotStrengthReduction) { + int m, k, n; + bool transpose_lhs, transpose_rhs; + std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam(); + + Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); + Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k}); + HloComputation::Builder builder(TestName()); + + auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( + 0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs")); + if (transpose_lhs) { + lhs = builder.AddInstruction( + HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0})); + } + auto rhs = builder.AddInstruction(HloInstruction::CreateParameter( + 1, transpose_rhs ? 
transposed_rhs_shape : rhs_shape, "rhs")); + if (transpose_rhs) { + rhs = builder.AddInstruction( + HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0})); + } + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); + const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; + const bool computation_should_be_modified = + dot_should_be_transformed || (transpose_lhs && transpose_rhs); + EXPECT_EQ(changed, computation_should_be_modified); + bool has_no_dot = true; + for (const auto& hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kDot) { + has_no_dot = false; + break; + } + } + EXPECT_EQ(has_no_dot, dot_should_be_transformed); +} + +INSTANTIATE_TEST_CASE_P( + DotStrengthReductionTestInstantiation, DotStrengthReductionTest, + ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), + ::testing::Values(1, 2), ::testing::Bool(), + ::testing::Bool())); + +struct DotOfConcatTestSpec { + int64 m; + int64 k; + int64 n; +}; + +class DotOfConcatSimplificationTest + : public HloTestBase, + public ::testing::WithParamInterface {}; + +// Test that we transform +// dot(const, concat(A, B, C)) +// to +// add(dot(const_0, A), dot(const_1, B), dot(const_2, C)) +TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { + HloComputation::Builder builder(TestName()); + + DotOfConcatTestSpec spec = GetParam(); + + ASSERT_GE(spec.k, 3); + + int64 k0 = spec.k / 3; + int64 k1 = spec.k / 3; + int64 k2 = spec.k - k0 - k1; + + Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); + auto* lhs = builder.AddInstruction( + 
HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k))); + + Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n}); + Shape rhs1_shape = ShapeUtil::MakeShape(F32, {k1, spec.n}); + Shape rhs2_shape = ShapeUtil::MakeShape(F32, {k2, spec.n}); + + HloInstruction* rhs0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, rhs0_shape, "rhs0")); + HloInstruction* rhs1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, rhs1_shape, "rhs1")); + HloInstruction* rhs2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, rhs2_shape, "rhs2")); + + Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n}); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateConcatenate(rhs_shape, {rhs0, rhs1, rhs2}, 0)); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + + Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get())); + ASSERT_TRUE(run_successful); + + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + auto match_dot_0 = op::Dot(op::Slice(op::Constant()), op::Parameter(0)); + auto match_dot_1 = op::Dot(op::Slice(op::Constant()), op::Parameter(1)); + auto match_dot_2 = op::Dot(op::Slice(op::Constant()), op::Parameter(2)); + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2)); +} + +// Test that we transform +// dot(concat(A, B, C), const) +// to +// add(dot(A, const_0), dot(B, const_1), dot(C, const_2)) 
+TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { + HloComputation::Builder builder(TestName()); + + DotOfConcatTestSpec spec = GetParam(); + + ASSERT_GE(spec.k, 4); + + int64 k0 = spec.k / 4; + int64 k1 = spec.k / 4; + int64 k2 = spec.k / 4; + int64 k3 = spec.k - k0 - k1 - k2; + + Shape lhs0_shape = ShapeUtil::MakeShape(F32, {spec.m, k0}); + Shape lhs1_shape = ShapeUtil::MakeShape(F32, {spec.m, k1}); + Shape lhs2_shape = ShapeUtil::MakeShape(F32, {spec.m, k2}); + Shape lhs3_shape = ShapeUtil::MakeShape(F32, {spec.m, k3}); + + HloInstruction* lhs0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, lhs0_shape, "lhs0")); + HloInstruction* lhs1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, lhs1_shape, "lhs1")); + HloInstruction* lhs2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, lhs2_shape, "lhs2")); + HloInstruction* lhs3 = builder.AddInstruction( + HloInstruction::CreateParameter(3, lhs2_shape, "lhs3")); + + Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); + HloInstruction* lhs = + builder.AddInstruction(HloInstruction::CreateConcatenate( + lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1)); + + Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.m}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.m))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + + Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get())); + ASSERT_TRUE(run_successful); + 
EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + auto match_dot_0 = op::Dot(op::Parameter(0), op::Slice(op::Constant())); + auto match_dot_1 = op::Dot(op::Parameter(1), op::Slice(op::Constant())); + auto match_dot_2 = op::Dot(op::Parameter(2), op::Slice(op::Constant())); + auto match_dot_3 = op::Dot(op::Parameter(3), op::Slice(op::Constant())); + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2), + match_dot_3)); +} + +DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { + {/*m=*/3, /*k=*/9, /*n=*/3}, // + {/*m=*/3, /*k=*/20, /*n=*/3}, // + {/*m=*/1, /*k=*/18, /*n=*/5}, // + {/*m=*/20, /*k=*/20, /*n=*/1}, // + {/*m=*/1, /*k=*/16, /*n=*/1}, // +}; + +INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, + DotOfConcatSimplificationTest, + ::testing::ValuesIn(kDotOfConcatTestSpecs)); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 9abe30e3f371cc294c36c1dcd743224b11b0c4f5..05f2d062784147108a94ffb7bb0ca42ddfe4f010 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#define EIGEN_USE_THREADS + #include "tensorflow/compiler/xla/service/backend.h" #include #include #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/platform_util.h" diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc index abe881cd1a58a6173b9b93f10a7308d70106c889..2bbae25aee3db95406fd247deb788d2976207ba3 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc @@ -85,9 +85,9 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { HloOpcode opcode) { HloComputation::Builder b("scalar_computation"); auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); + 0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs")); auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); + 1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs")); auto scalar_op = b.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), opcode, scalar_lhs, scalar_rhs)); @@ -149,26 +149,41 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( if (!rewrite_training_op_) { return Status::OK(); } + + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); + // Expand batch norm training into smaller HLO ops. 
HloInstruction* operand = batch_norm->mutable_operand(0); const Shape operand_shape = operand->shape(); + PrimitiveType ptype = operand_shape.element_type(); int64 feature_index = batch_norm->feature_index(); const int64 feature_count = operand_shape.dimensions(feature_index); const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - auto elements_per_feature = - computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(size_in_elements / feature_count))); + auto elements_per_feature_literal = + Literal::CreateR0(size_in_elements / feature_count); + TF_ASSIGN_OR_RETURN(elements_per_feature_literal, + elements_per_feature_literal->Convert(ptype)); + auto elements_per_feature = add( + HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); const Shape feature_shape = scale->shape(); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + auto zero_literal = Literal::CreateR0(0.0f); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); + auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + auto epsilon = + add(HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector dimensions_without_feature; for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { @@ -177,103 +192,110 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( } } - auto scale_broadcasted = computation_->AddInstruction( + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); - auto offset_broadcasted = 
computation_->AddInstruction( + auto offset_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(F32, HloOpcode::kAdd); + GetScalarBinaryComputation(ptype, HloOpcode::kAdd); // X^2. - auto operand_squared = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, operand, operand)); + auto operand_squared = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, operand, operand)); // Sum[X]. - auto sum = computation_->AddInstruction(HloInstruction::CreateReduce( - feature_shape, operand, zero, dimensions_without_feature, - add_reduce_computation)); + auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero, + dimensions_without_feature, + add_reduce_computation)); // Sum[X^2]. - auto squared_sum = computation_->AddInstruction(HloInstruction::CreateReduce( + auto squared_sum = add(HloInstruction::CreateReduce( feature_shape, operand_squared, zero, dimensions_without_feature, add_reduce_computation)); // Fuse two parallel reduces together to improve performance. - if (use_fusion_) { - auto tuple = computation_->AddInstruction( - HloInstruction::CreateTuple({sum, squared_sum})); + if (use_fusion_ && !batch_norm->has_sharding()) { + auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum})); auto fused = computation_->CreateFusionInstruction( {tuple, sum, squared_sum, operand_squared}, HloInstruction::FusionKind::kInput); - sum = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); + sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - squared_sum = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); + squared_sum = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); } // E[X]. 
- auto mean = computation_->AddInstruction(HloInstruction::CreateBinary( + auto mean = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kDivide, sum, elements_per_feature)); - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = computation_->AddInstruction(HloInstruction::CreateBinary( + auto square_mean = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature)); // E^2[X]. - auto mean_square = computation_->AddInstruction(HloInstruction::CreateBinary( + auto mean_square = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kMultiply, mean, mean)); // Var[X]. - auto var = computation_->AddInstruction(HloInstruction::CreateBinary( + auto var = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kSubtract, square_mean, mean_square)); - auto var_broadcasted = computation_->AddInstruction( - HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + auto var_broadcasted = + add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + auto var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-0.5f))); + auto neg_half_literal = Literal::CreateR0(-0.5f); + TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); // 1 / Sqrt[Var[X] + epsilon]. 
- auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); // X - E[X]. - auto operand_minus_mean = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = computation_->AddInstruction( + auto normalized = add( HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, operand_minus_mean, rsqrt_var_add_epsilon)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. 
- auto shifted_normalized = computation_->AddInstruction( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kAdd, - scaled_normalized, offset_broadcasted)); - - TF_CHECK_OK(ReplaceWithNewInstruction( - batch_norm, - HloInstruction::CreateTuple({shifted_normalized, mean, var}))); + auto shifted_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted)); + + auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var}); + + if (batch_norm->has_sharding()) { + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), operand_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + tuple->set_sharding(batch_norm->sharding()); + } + TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); } @@ -286,6 +308,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference( HloInstruction* operand = batch_norm->mutable_operand(0); const Shape operand_shape = operand->shape(); int64 feature_index = batch_norm->feature_index(); + PrimitiveType ptype = operand_shape.element_type(); HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); @@ -293,8 +316,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference( HloInstruction* var = batch_norm->mutable_operand(4); const Shape feature_shape = scale->shape(); + auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector 
dimensions_without_feature; @@ -304,50 +329,69 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference( } } - auto scale_broadcasted = computation_->AddInstruction( + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); + + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); - auto offset_broadcasted = computation_->AddInstruction( + auto offset_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); - auto var_broadcasted = computation_->AddInstruction( - HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + auto var_broadcasted = + add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + auto var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-0.5f))); + auto neg_half_literal = Literal::CreateR0(-0.5f); + TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); // 1 / Sqrt[Var[X] + epsilon]. 
- auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); // X - E[X]. - auto operand_minus_mean = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = computation_->AddInstruction( + auto normalized = add( HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, operand_minus_mean, rsqrt_var_add_epsilon)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. 
auto shifted_normalized = HloInstruction::CreateBinary( operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted); + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + if (batch_norm->has_sharding()) { + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), operand_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + shifted_normalized->set_sharding(batch_norm->sharding()); + } TF_CHECK_OK( ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized))); return Status::OK(); @@ -370,9 +414,17 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( if (!rewrite_grad_op_) { return Status::OK(); } + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); HloInstruction* activation = batch_norm->mutable_operand(0); const Shape activation_shape = activation->shape(); + PrimitiveType ptype = activation_shape.element_type(); HloInstruction* scale = batch_norm->mutable_operand(1); const Shape feature_shape = scale->shape(); HloInstruction* mean = batch_norm->mutable_operand(2); @@ -383,18 +435,26 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); const int64 feature_count = activation_shape.dimensions(feature_index); - auto elements_per_feature = - computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(size_in_elements / feature_count))); - - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - - auto neg_half = computation_->AddInstruction( - 
HloInstruction::CreateConstant(Literal::CreateR0(-0.5f))); - - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + auto elements_per_feature_literal = + Literal::CreateR0(size_in_elements / feature_count); + TF_ASSIGN_OR_RETURN(elements_per_feature_literal, + elements_per_feature_literal->Convert(ptype)); + auto elements_per_feature = add( + HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + + auto zero_literal = Literal::CreateR0(0.0f); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); + + auto neg_half_literal = Literal::CreateR0(-0.5f); + TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + + auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + auto epsilon = + add(HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector dimensions_without_feature; @@ -404,126 +464,131 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( } } - auto scale_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, scale, {feature_index})); - auto variance_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, variance, {feature_index})); + auto scale_broadcasted = add(HloInstruction::CreateBroadcast( + activation_shape, scale, {feature_index})); + auto variance_broadcasted = add(HloInstruction::CreateBroadcast( + activation_shape, variance, {feature_index})); // E[X]. - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index})); // rsqrt[Var[X] + epsilon]. 
- auto rsqrt_var_add_epsilon_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kPower, - computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)), - neg_half)); - - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kPower, - computation_->AddInstruction(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kAdd, variance, epsilon)), - neg_half)); + auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kPower, + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon)), + neg_half)); + + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + feature_shape, HloOpcode::kPower, + add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance, + epsilon)), + neg_half)); // X - E[X]. - auto activation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, - activation, mean_broadcasted)); + auto activation_minus_mean = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted)); // Grad[Y] * (X - E[X]). - auto grad_output_times_activiation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, activation_minus_mean)); + auto grad_output_times_activiation_minus_mean = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + grad_output, activation_minus_mean)); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(F32, HloOpcode::kAdd); + GetScalarBinaryComputation(ptype, HloOpcode::kAdd); // sum(Grad[Y] * (X - E[X])). 
auto sum_grad_output_times_activiation_minus_mean = - computation_->AddInstruction(HloInstruction::CreateReduce( + add(HloInstruction::CreateReduce( feature_shape, grad_output_times_activiation_minus_mean, zero, dimensions_without_feature, add_reduce_computation)); // Grad[beta] = Sum(Grad[Y]). - auto grad_beta = computation_->AddInstruction(HloInstruction::CreateReduce( + auto grad_beta = add(HloInstruction::CreateReduce( feature_shape, grad_output, zero, dimensions_without_feature, add_reduce_computation)); - if (use_fusion_) { - auto tuple = computation_->AddInstruction(HloInstruction::CreateTuple( + if (use_fusion_ && !batch_norm->has_sharding()) { + auto tuple = add(HloInstruction::CreateTuple( {sum_grad_output_times_activiation_minus_mean, grad_beta})); auto fused = computation_->CreateFusionInstruction( {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta}, HloInstruction::FusionKind::kInput); - sum_grad_output_times_activiation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); + sum_grad_output_times_activiation_minus_mean = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - grad_beta = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); + grad_beta = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); } // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). 
- auto grad_scale = computation_->AddInstruction(HloInstruction::CreateBinary( + auto grad_scale = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kMultiply, sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon)); // I2 = Sum(Grad[Y]) - auto I2 = computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, grad_beta, {feature_index})); + auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta, + {feature_index})); // I3 = Sum(Grad[Y] * (X - E[X])) - auto I3 = computation_->AddInstruction(HloInstruction::CreateBroadcast( + auto i3 = add(HloInstruction::CreateBroadcast( activation_shape, sum_grad_output_times_activiation_minus_mean, {feature_index})); // I4 = (X - E[X]) * I3 - auto I4 = computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, I3, activation_minus_mean)); + auto i4 = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean)); // I5 = I4 / (Var[X] + epsilon) - auto I5 = computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, I4, - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kAdd, variance_broadcasted, epsilon)))); + auto i5 = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kDivide, i4, + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon)))); // scale * rsqrt[Var[X] + epsilon] * 1/N - auto scale_times_rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, scale_broadcasted, - rsqrt_var_add_epsilon_broadcasted)); + auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kMultiply, scale_broadcasted, + rsqrt_var_add_epsilon_broadcasted)); - scale_times_rsqrt_var_add_epsilon = - 
computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, - scale_times_rsqrt_var_add_epsilon, elements_per_feature)); + scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon, + elements_per_feature)); - auto I1 = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, elements_per_feature)); + auto i1 = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + grad_output, elements_per_feature)); // I6 = I1 - I2 - I5 - auto I6 = computation_->AddInstruction(HloInstruction::CreateBinary( + auto i6 = add(HloInstruction::CreateBinary( activation_shape, HloOpcode::kSubtract, - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kSubtract, I1, I2)), - I5)); + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, + i1, i2)), + i5)); // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6. 
- auto grad_activation = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - scale_times_rsqrt_var_add_epsilon, I6)); + auto grad_activation = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + scale_times_rsqrt_var_add_epsilon, i6)); + auto tuple = + HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); + if (batch_norm->has_sharding()) { + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), activation_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + tuple->set_sharding(batch_norm->sharding()); + } - TF_CHECK_OK(ReplaceWithNewInstruction( - batch_norm, - HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}))); + TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index b422b22df9cfbefb6611fcb229ed42e67fe3a0d8..7ece79d781acfaffc21d6a29e8a12e68622a1617 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -265,6 +265,42 @@ bool BufferAssignment::SharesSliceAtIndex( GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie(); } +bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, + const HloInstruction* hlo_b) const { + using SliceSet = + FlatSet; + // Gets the slices all of instr's subshapes. If any subshape doesn't have an + // assigned slice, returns the empty set. 
+ auto collect_slices = [&](const HloInstruction* instr) -> SliceSet { + SliceSet slices; + Status status = ShapeUtil::ForEachSubshapeWithStatus( + instr->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) { + auto shape_slices = GetAllSlices(instr, index); + if (shape_slices.empty()) { + return InvalidArgument("No slices assigned to part of instr."); + } + slices.insert(shape_slices.begin(), shape_slices.end()); + return Status::OK(); + }); + if (!status.ok()) { + return {}; + } + return slices; + }; + + SliceSet slices_a = collect_slices(hlo_a); + SliceSet slices_b = collect_slices(hlo_b); + // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e. + // didn't return the empty set) for both HLOs, and the two resulting sets of + // slices are disjoint. + return !slices_a.empty() && !slices_b.empty() && + std::none_of(slices_a.begin(), slices_a.end(), + [&](const BufferAllocation::Slice& slice) { + return slices_b.count(slice) > 0; + }); +} + StatusOr BufferAssignment::GetUniqueTopLevelOutputSlice() const { return GetUniqueTopLevelSlice( @@ -497,19 +533,19 @@ Status GatherComputationsByAllocationType( std::vector* global_computations) { // Create a worklist of computations paired with whether the allocation must // be thread-local. - std::deque> worklist; + std::deque> worklist; worklist.push_back(std::make_pair(module->entry_computation(), /*is_thread_local*/ false)); // Sets for quickly checking membership. Computations are returned in vectors // for stable iteration. 
- FlatSet thread_local_set; - FlatSet global_set; + FlatSet thread_local_set; + FlatSet global_set; while (!worklist.empty()) { auto worklist_front = worklist.front(); worklist.pop_front(); - HloComputation* computation = worklist_front.first; + const HloComputation* computation = worklist_front.first; bool is_thread_local = worklist_front.second; bool in_thread_local_set = thread_local_set.count(computation) > 0; bool in_global_set = global_set.count(computation) > 0; @@ -545,6 +581,7 @@ Status GatherComputationsByAllocationType( instruction->called_computations()) { switch (instruction->opcode()) { case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kWhile: // Call and while must be called from a computation with global // allocations as they may return references to buffers inside the @@ -653,7 +690,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } if (allow_input_output_aliasing_ && allocation->maybe_live_out()) { - HloComputation* entry_computation = + const HloComputation* entry_computation = assignment->module_->entry_computation(); for (auto param : entry_computation->parameter_instructions()) { for (auto& param_buffer : @@ -819,17 +856,6 @@ Status BufferAssigner::AssignBuffersForComputation( continue; } - if (instruction->opcode() == HloOpcode::kRecv) { - // Make sure that recv operations get a new unique allocation so that - // don't share their buffer with any other operations. - BufferAllocation* allocation = assignment->NewAllocation( - *buffer, buffer_size, is_thread_local, /*is_reusable=*/false); - allocation_indices.push_back(allocation->index()); - VLOG(3) << "New allocation #" << allocation->index() - << " for recv: " << *buffer; - continue; - } - if (ShapeUtil::IsTuple(buffer->shape())) { // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend // assumes longer buffer liveness than indicated by the analysis. 
@@ -951,8 +977,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); if (run_whole_module_heap_simulation) { // Run the heap simulation over the whole module. This reduces memory usage, - // since buffers for kCall and kWhile sub-computations are only live for the - // duration of their calling instructions. + // since buffers for kCall, kWhile, and kConditional sub-computations are + // only live for the duration of their calling instructions. VLOG(1) << "Running whole-module heap simulation"; SequentialHloOrdering::HloModuleSequence module_sequence; FlatSet all_buffers_to_assign; @@ -1240,7 +1266,6 @@ const LogicalBuffer* AddBufferToColocatedSet( // CopyInsertion ensures root points-to set is unambiguous and distinct. const auto& points_to = points_to_analysis.GetPointsToSet(instruction); DCHECK(!points_to.IsAmbiguous()); - DCHECK(points_to.IsDistinct()); colocated_set->push_back(points_to.element(index)[0]); return colocated_set->back(); } @@ -1248,7 +1273,8 @@ const LogicalBuffer* AddBufferToColocatedSet( } // namespace // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated -// in the same allocation (currently just supports kWhile and kCall). +// in the same allocation (currently just supports kWhile, kCall, and +// kConditional). 
void BufferAssigner::BuildColocatedBufferSets( const HloModule* module, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, @@ -1312,6 +1338,26 @@ void BufferAssigner::BuildColocatedBufferSets( &colocated_set); AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); + } else if (opcode == HloOpcode::kConditional) { + const HloInstruction* conditional_hlo = instruction; + ShapeUtil::ForEachSubshape( + conditional_hlo->shape(), + [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( + const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector colocated_set; + // Add conditional.result. + AddBufferToColocatedSet(conditional_hlo, index, + points_to_analysis, &colocated_set); + // Add conditional.true_computation.root. + AddBufferToColocatedSet( + conditional_hlo->true_computation()->root_instruction(), + index, points_to_analysis, &colocated_set); + // Add conditional.false_computation.root. + AddBufferToColocatedSet( + conditional_hlo->false_computation()->root_instruction(), + index, points_to_analysis, &colocated_set); + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + }); } } } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 08a53af8baa3f250919517c87c023c329b129024..08a40bfeb2a2a78c25805308e73154c6cc667f21 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -327,6 +327,12 @@ class BufferAssignment { return SharesSliceAtIndex(hlo_a, {}, hlo_b, {}); } + // Returns true if hlo_a and hlo_b both have at least one buffer assigned for + // their top-level and each of their nested shape indices, and if hlo_a's + // buffers are all different from hlo_b's buffers. + bool HaveDisjointSlices(const HloInstruction* hlo_a, + const HloInstruction* hlo_b) const; + // Returns the underlying points-to analysis used for this assignment. 
const TuplePointsToAnalysis& points_to_analysis() const { return liveness_->points_to_analysis(); diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 89410f42bd7b5fa8f9b380c868fcd4fedb54576c..6fc9d783f1b34de8c0f93c6aa342591891d08eaf 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -85,7 +85,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, MakeUnique(module), + module, xla::MakeUnique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }) .ConsumeValueOrDie(); @@ -94,7 +94,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunColoredBufferAssignment( HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) { return BufferAssigner::Run( - module, MakeUnique(module), + module, xla::MakeUnique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, false, std::move(colorer)) @@ -166,6 +166,15 @@ class BufferAssignmentTest : public HloTestBase { return builder.Build(); } + std::unique_ptr BuildR0F32UnaryOpComputation( + HloOpcode opcode, const string& name) { + auto builder = HloComputation::Builder(name); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param)); + return builder.Build(); + } + // Verifies that the given instruction hlo has a valid input buffer assigned, // i.e., the parameter number matches the op's. 
const BufferAllocation& GetAssignedInputAllocation( @@ -740,6 +749,56 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { << " instructions; total buffer size " << size0 + sizec + sizeb; } +TEST_F(BufferAssignmentTest, ExampleConditional) { + auto module = CreateNewModule(); + auto true_computation = module->AddEmbeddedComputation( + BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil")); + auto false_computation = module->AddEmbeddedComputation( + BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor")); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + auto const2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.4f))); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + r0f32_, pred, const1, true_computation, const2, false_computation)); + module->AddEntryComputation(builder.Build()); + + const std::vector conditional_instrs = + GetInstructions(conditional); + const std::vector true_instrs = + GetInstructions(true_computation->root_instruction()); + const std::vector false_instrs = + GetInstructions(false_computation->root_instruction()); + EXPECT_EQ(4, conditional_instrs.size()); + EXPECT_EQ(2, true_instrs.size()); + EXPECT_EQ(2, false_instrs.size()); + + auto buffers = RunBufferAssignment(module.get()); + ValidateBuffers(conditional_instrs, *buffers); + ValidateBuffers(true_instrs, *buffers); + ValidateBuffers(false_instrs, *buffers); + + EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers)) + << "Should be reuse between conditional and true computation."; + EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers)) + << "Should be reuse between conditional and false computation."; + EXPECT_FALSE(BuffersDistinct(true_instrs, false_instrs, *buffers)) + << "Should be 
reuse between true and false computations."; + + const BufferAllocation& conditional_buffer = + GetTopLevelAllocation(*buffers, conditional); + const BufferAllocation& true_buffer = + GetTopLevelAllocation(*buffers, true_computation->root_instruction()); + const BufferAllocation& false_buffer = + GetTopLevelAllocation(*buffers, false_computation->root_instruction()); + EXPECT_EQ(conditional_buffer.size(), true_buffer.size()); + EXPECT_EQ(conditional_buffer.size(), false_buffer.size()); +} + TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) auto builder = HloComputation::Builder(TestName()); @@ -1360,10 +1419,13 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { HloInstruction::CreateParameter(1, shape_3x4, "param_b")); auto param_c = builder.AddInstruction( HloInstruction::CreateParameter(2, shape_4x4, "param_c")); - auto dot_ab = builder.AddInstruction(HloInstruction::CreateBinary( - shape_2x4, HloOpcode::kDot, param_a, param_b)); - auto dot_bc = builder.AddInstruction(HloInstruction::CreateBinary( - shape_3x4, HloOpcode::kDot, param_b, param_c)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot_ab = builder.AddInstruction( + HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums)); + auto dot_bc = builder.AddInstruction( + HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums)); builder.AddInstruction( HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1)); @@ -1448,7 +1510,7 @@ class WhileBufferAssignmentTest : public HloTestBase { auto sequence = CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, MakeUnique(module, sequence), + module, xla::MakeUnique(module, sequence), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }) .ConsumeValueOrDie(); @@ -1469,7 +1531,7 @@ static void RunCopyInsertion(HloModule* 
module) { } TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1526,7 +1588,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { } TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1538,8 +1600,6 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateConstant(Literal::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); - auto output1 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto cond0 = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -1556,10 +1616,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { auto body1 = module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); - auto tuple1 = builder.AddInstruction( - HloInstruction::CreateTuple({input0, weights0, output1})); auto while1 = builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); + HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0)); module->AddEntryComputation(builder.Build()); RunCopyInsertion(module.get()); @@ -1575,7 +1633,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { } TEST_F(BufferAssignmentTest, TwoCalls) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); HloComputation* sub_computation; { @@ -1640,7 +1698,7 @@ static bool IsPostOrderTraversal( } TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { - auto module = MakeUnique(TestName()); 
+ auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( @@ -1676,11 +1734,14 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto while1 = builder.AddInstruction( HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1)); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while0, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while1, 1)); auto root_add = builder.AddInstruction(HloInstruction::CreateBinary( - while0->shape(), HloOpcode::kAdd, while0, while1)); - module->AddEntryComputation(builder.Build()); + while0->shape(), HloOpcode::kAdd, gte0, gte1)); - RunCopyInsertion(module.get()); + module->AddEntryComputation(builder.Build()); { FlattenCallGraph flatten; @@ -1688,84 +1749,35 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { EXPECT_TRUE(result); } + RunCopyInsertion(module.get()); + auto sequence = CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo sequence for the // root computation, so we overwrite that entry with a manually // crafted sequence. - std::vector sequence_for_buffer_assigment = { - input1, weights1, one, output1, tuple1, while1, input0, - weights0, zero, output0, tuple0, while0, root_add}; + sequence[module->entry_computation()] = { + input1, weights1, one, output1, while1->operand(0), while1, + input0, weights0, zero, output0, while0->operand(0), while0, + gte0, gte1, root_add}; // If this ASSERT_TRUE fails, we constructed a bogus sequence above // and this test itself is buggy. 
- ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assigment)); - - sequence[module->entry_computation()] = - std::move(sequence_for_buffer_assigment); + ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()])); auto assignment = BufferAssigner::Run( module.get(), - MakeUnique(module.get(), sequence), ByteSizeOf, - [](LogicalBuffer::Color) { return 1; }) + xla::MakeUnique(module.get(), sequence), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } -// Test buffer assignment for while nodes with multiple uses. -// TODO(b/37245345): Fix buffer assignment for this case. -TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { - auto module = MakeUnique(TestName()); - auto builder = HloComputation::Builder(TestName()); - - auto input0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape_, "input0")); - auto weights0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, data_shape_, "weights0")); - - auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); - auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); - - auto cond0 = - module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); - auto body0 = - module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); - - auto tuple0 = builder.AddInstruction( - HloInstruction::CreateTuple({input0, weights0, output0})); - auto while0 = builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); - auto while1 = builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, while0)); - - auto get0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, while0, 2)); - auto get1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); - builder.AddInstruction( - 
HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, get0, get1)); - module->AddEntryComputation(builder.Build()); - - RunCopyInsertion(module.get()); - - { - FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); - EXPECT_TRUE(result); - } - - auto assignment = RunBufferAssignment(module.get()); - - EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); -} - TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 56600b583803e23324db778959de620440fce5cf..13825fe05bb1b98045f1a3dac3d7272a2d1151fb 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -120,7 +120,7 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); @@ -167,10 +167,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = BufferLiveness::Run( - module.get(), - MakeUnique(module.get(), sequence)) - .ConsumeValueOrDie(); + auto liveness = + BufferLiveness::Run(module.get(), xla::MakeUnique( + module.get(), sequence)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. 
@@ -216,7 +216,7 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -250,7 +250,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -294,7 +294,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { std::vector order = {param, negate, exp, add}; module_sequence.emplace(computation, order); auto liveness = - BufferLiveness::Run(module.get(), MakeUnique( + BufferLiveness::Run(module.get(), xla::MakeUnique( module.get(), module_sequence)) .ConsumeValueOrDie(); @@ -334,7 +334,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // All buffers should be live out except the param @@ -370,7 +370,7 @@ TEST_F(BufferLivenessTest, EmbeddedComputation) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Buffers in different computations should always interfere. 
@@ -409,7 +409,7 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Only the element buffers of the tuple constant which are pointed to by @@ -474,7 +474,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -536,7 +536,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -624,8 +624,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { // Run BufferLiveness on 'module'. auto liveness = - BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. @@ -736,8 +736,8 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. auto liveness = - BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. 
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 1adecdb939cb2c1259003d3be2c90b5a299b0f30..13eb02ca012f44b2b5ed7c6f5becb7d54b07c33c 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -54,6 +54,7 @@ std::ostream& operator<<(std::ostream& out, const CallContext& context) { CallContext GetInstructionCallContext(const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; case HloOpcode::kMap: diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 0395ea8c8b52315f7ca2221f412750ebadda2dd8..1ea7d538cd515c3098b6a1f03c6146d288330406 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -34,12 +34,13 @@ using ::testing::UnorderedElementsAre; class CallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. - std::unique_ptr MakeScalarComputation() { + std::unique_ptr MakeScalarComputation( + HloOpcode opcode = HloOpcode::kNegate) { HloComputation::Builder builder(TestName() + ".ScalarComputation"); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); builder.AddInstruction( - HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0)); + HloInstruction::CreateUnary(kScalarShape, opcode, param0)); return builder.Build(); } @@ -236,6 +237,54 @@ TEST_F(CallGraphTest, ContextBothComputations) { EXPECT_EQ(CallContext::kBoth, sub_node.context()); } +TEST_F(CallGraphTest, ComputationWithConditional) { + // Test a call graph of a module with a conditional. 
+ auto module = CreateNewModule(); + HloComputation* true_computation = + module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil)); + HloComputation* false_computation = + module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kFloor)); + + HloComputation::Builder builder(TestName()); + HloInstruction* pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction* const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + HloInstruction* const2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.6f))); + HloInstruction* conditional = + builder.AddInstruction(HloInstruction::CreateConditional( + kScalarShape, pred, const1, true_computation, const2, + false_computation)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + std::unique_ptr call_graph = CallGraph::Build(module.get()); + + EXPECT_EQ(3, call_graph->nodes().size()); + + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(1, entry_node.callsites().size()); + + const CallSite& conditional_callsite = entry_node.callsites()[0]; + EXPECT_EQ(conditional, conditional_callsite.instruction()); + EXPECT_THAT(conditional_callsite.called_computations(), + UnorderedElementsAre(true_computation, false_computation)); + EXPECT_EQ(CallContext::kSequential, conditional_callsite.context()); + EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite); + + const CallGraphNode& true_node = call_graph->GetNode(true_computation); + EXPECT_TRUE(true_node.callees().empty()); + EXPECT_EQ(1, true_node.callers().size()); + EXPECT_EQ(entry_computation, true_node.callers()[0]); + + const CallGraphNode& false_node = call_graph->GetNode(false_computation); + EXPECT_TRUE(false_node.callees().empty()); + EXPECT_EQ(1, false_node.callers().size()); + 
EXPECT_EQ(entry_computation, false_node.callers()[0]); +} + TEST_F(CallGraphTest, ComplexGraph) { // Test a call graph of a module with several computation called in various // contexts. The call graph looks like: diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 3b1900428af1863c73efe67c27061d979557b3a4..e2e9d2a0c048fec6c6ffbeef1223ae0e6aef50d1 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -27,14 +27,8 @@ namespace se = ::perftools::gputools; namespace xla { -/* static */ tensorflow::mutex* Compiler::platform_compiler_mutex_; - -/* static */ void Compiler::LazyInitMutex() { - static std::once_flag mutex_init_flag; - std::call_once(mutex_init_flag, []() { - Compiler::platform_compiler_mutex_ = new tensorflow::mutex; - }); -} +/* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( + tensorflow::LINKER_INITIALIZED); /* static */ std::map* @@ -55,8 +49,7 @@ Compiler::GetPlatformCompilers() { /* static */ void Compiler::RegisterCompilerFactory( se::Platform::Id platform_id, std::function()> compiler_factory) { - LazyInitMutex(); - tensorflow::mutex_lock lock(*platform_compiler_mutex_); + tensorflow::mutex_lock lock(platform_compiler_mutex_); auto* factories = GetPlatformCompilerFactories(); CHECK(factories->find(platform_id) == factories->end()) << "Compiler factory already registered for platform"; @@ -65,8 +58,7 @@ Compiler::GetPlatformCompilers() { /* static */ StatusOr Compiler::GetForPlatform( const se::Platform* platform) { - LazyInitMutex(); - tensorflow::mutex_lock lock(*platform_compiler_mutex_); + tensorflow::mutex_lock lock(platform_compiler_mutex_); auto* compilers = GetPlatformCompilers(); // See if we already instantiated a compiler for this platform. 
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 4c2d9600d909e82dcb62f508a10445c08c1cdee6..fc67330f5cbdbcb0d1a259d284599916a908d1fe 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -97,21 +97,32 @@ class Compiler { // Returns the ID of the platform that this compiler targets. virtual perftools::gputools::Platform::Id PlatformId() const = 0; + // Runs Hlo passes to optimize the given Hlo module, returns the optimized + // module. + virtual StatusOr> RunHloPasses( + std::unique_ptr module, + perftools::gputools::StreamExecutor* executor) = 0; + // Compiles the HLO module for execution on a device given by the executor, - // and returns an executable object or an error status. Takes ownership of the - // HLO module and is free to transform it. + // and returns an executable object or an error status. No HLO passes are + // applied to module. Generally a module should be passed through RunHloPasses + // prior to calling this method because some HLO passes are required for + // correctness. Takes ownership of the HLO module and is free to transform it. // // The compiler may optionally specialize to the individual device // (not just type of device) indicated by the executor. // // Use the overload below to compile computations that run in parallel. - virtual StatusOr> Compile( + virtual StatusOr> RunBackend( std::unique_ptr module, perftools::gputools::StreamExecutor* executor) = 0; // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding // sequence of executable objects. + // + // TODO(b/68666782): Remove this method after adding support for multiple + // modules to RunHloPasses and RunBackends. virtual StatusOr>> Compile( std::vector> modules, std::vector> @@ -157,8 +168,7 @@ class Compiler { private: // Mutex that guards the platform-compiler map.
- static tensorflow::mutex* platform_compiler_mutex_; - static void LazyInitMutex(); + static tensorflow::mutex platform_compiler_mutex_; // Map from platform kind to compiler factory. static std::map* diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index cdfa30dd9a7b6a5b9e58087491a9d99caaa1b998..657fba6b6231104bf47f9dec80f7cd36a0ba3efd 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -52,6 +52,12 @@ Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { /* static */ StatusOr> DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count()); + if (proto.replica_count() <= 0 || proto.computation_count() <= 0) { + return InvalidArgument( + "Invalid device assignment topology: replica_count=%d, " + "computation_count=%d", + proto.replica_count(), proto.computation_count()); + } auto assignment = MakeUnique(proto.replica_count(), proto.computation_count()); for (int computation = 0; computation < proto.computation_count(); @@ -94,7 +100,7 @@ StatusOr ComputationPlacer::AssignDevices( se::Platform::Id platform_id, ComputationPlacerCreationFunction creation_function) { tensorflow::mutex_lock lock( - *ComputationPlacer::platform_computation_placer_mutex()); + ComputationPlacer::platform_computation_placer_mutex_); auto* computation_placers = GetPlatformComputationPlacers(); CHECK(computation_placers->find(platform_id) == computation_placers->end()); (*computation_placers)[platform_id].creation_function = creation_function; @@ -103,7 +109,7 @@ StatusOr ComputationPlacer::AssignDevices( /* static */ StatusOr ComputationPlacer::GetForPlatform( const se::Platform* platform) { tensorflow::mutex_lock lock( - *ComputationPlacer::platform_computation_placer_mutex()); + ComputationPlacer::platform_computation_placer_mutex_); auto* 
computation_placers = GetPlatformComputationPlacers(); auto it = computation_placers->find(platform->id()); @@ -122,11 +128,9 @@ StatusOr ComputationPlacer::AssignDevices( return it->second.placer.get(); } -/* static */ tensorflow::mutex* -ComputationPlacer::platform_computation_placer_mutex() { - static tensorflow::mutex* m = new tensorflow::mutex; - return m; -} +/* static */ tensorflow::mutex + ComputationPlacer::platform_computation_placer_mutex_( + tensorflow::LINKER_INITIALIZED); /* static */ std::map* diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index 7d9abcd100dd9e878da885110bc1bd1ac65e3f84..737ccabaa7a61931b6e2787f75b02857562d4820 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -89,11 +89,8 @@ class ComputationPlacer { const perftools::gputools::Platform* platform); private: - // Routine that returns the mutex that guards the platform-to-computation - // placer map. Done as a routine to ensure correct initialization ordering, - // since RegisterComputationPlacer can be called during program initialization - // time. - static tensorflow::mutex* platform_computation_placer_mutex(); + // The mutex that guards the platform-to-computation placer map. + static tensorflow::mutex platform_computation_placer_mutex_; // State kept for each kind of ComputationPlacer. Registration functions set // up creation_function, and then we use that to lazily create "placer" the diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 0453a698a09b740d68b35258ede7c537fcf290d4..cd983bc03e993caed883916de01d75dffdbc4bab 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,15 +15,17 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/copy_insertion.h" -#include - +#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -31,597 +33,1174 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + namespace { -using tensorflow::gtl::FlatMap; -using tensorflow::gtl::FlatSet; +bool IsEntryParameterValue(const HloValue& value) { + const HloComputation* computation = value.defining_instruction()->parent(); + return value.defining_instruction()->opcode() == HloOpcode::kParameter && + computation == computation->parent()->entry_computation(); +} + +bool IsConstantValue(const HloValue& value) { + return value.defining_instruction()->opcode() == HloOpcode::kConstant; +} + +bool ValueIsReadOnly(const HloValue& value) { + return IsConstantValue(value) || IsEntryParameterValue(value); +} -// InstructionCopier encapsulates indices at 
which to copy 'instruction'. -// All 'instruction' users in 'copy_users' are updated to use the copy. +// Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in +// 'indices_to_copy'. Add control edges from the respective kCopy instructions +// in deep copy of 'from' to the respective kCopy instruction in the deep copy +// of 'to'. // -// Instruction copies are generated in two phases: -// 1) Recording buffer indices at which 'instruction' requires copies (i.e. -// setting 'indices_to_copy_[index]'=true). -// 2) Inserting kCopy instructions based on indices recorded in phase 1). -// *) Array instructions are copied by inserting a single kCopy instruction. -// *) Tuple-shaped instructions are copied by recursively expanding tuples -// (and tuple-shaped elements), and inserting kCopy instructions for any -// tuple elements which require a copy. As the recursion unwinds, new tuple -// instructions are added to gather the copied (and uncopied) references -// into the output tuple (i.e. the copy of the tuple-shaped instruction). +// Requirements: 'from' and 'to' must have compatible shapes. // -// Example two-element tuple with one element that needs a copy: +// For example, suppose 'from' and 'to' are two-element tuples where index 0 is +// the only index to copy. Prior to deep-copying we have: // -// original-instruction -// / \ -// GTE(0) GTE(1) -// | | -// Copy | -// \ / -// Tuple // copied-instruction // -// As an optimization, if the original instruction is itself a Tuple -// instruction, we elide the unnecessary extra GTE and Tuple instructions, -// and just insert the copy into a new Tuple instruction, with control -// dependencies to ensure the copy occurs after any possible interference. 
-class InstructionCopier { - public: - InstructionCopier(HloInstruction* instruction, - const std::vector& copy_users) - : instruction_(instruction), - copy_users_(copy_users), - indices_to_copy_(instruction->shape()), - control_predecessors_(instruction->shape()) {} - - // Sets indices that are read-only, and thus do not need to be copied. - void SetReadOnlyIndices(const ShapeTree& read_only_indices) { - read_only_indices_ = read_only_indices; - } +// 'from' +// | +// ... +// | +// 'to' +// +// DeepCopyAndAddControlEdges produces: +// +// 'from' +// / \ +// GTE GTE +// | | +// Copy | +// / \ / +// | Tuple +// | | +// ctrl ... +// edge | +// | | +// | 'to' +// | / \ +// | GTE GTE +// \ | | +// Copy | +// \ / +// Tuple +// +StatusOr> +DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to, + const ShapeTree& indices_to_copy) { + DCHECK(ShapeUtil::Compatible(from->shape(), to->shape())); + // to/from_copy_tree hold the kCopy instructions produced by the deep + // copies. Elements which are not copied (indices_to_copy.element(index) == + // false) have nullptr at that index. + ShapeTree from_copy_tree(from->shape(), + /*init_value=*/nullptr); + TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy, + from->parent()->DeepCopyInstruction( + from, &indices_to_copy, &from_copy_tree)); - // Sets copy overrides, which are copy instructions to use at each index. This - // is used to share a single copy of read-only entry parameters and constants - // between multiple While loops. - void SetCopyOverrides(const ShapeTree& copy_overrides) { - copy_overrides_ = copy_overrides; + ShapeTree to_copy_tree(to->shape(), /*init_value=*/nullptr); + TF_ASSIGN_OR_RETURN( + HloInstruction * to_deep_copy, + to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree)); + + // Add control edges between the respective kCopy instructions.
+ for (const auto& pair : from_copy_tree) { + const ShapeIndex& index = pair.first; + HloInstruction* from_copy = pair.second; + HloInstruction* to_copy = to_copy_tree.element(index); + if (from_copy == nullptr) { + TF_RET_CHECK(to_copy == nullptr); + continue; + } + TF_RET_CHECK(to_copy != nullptr); + TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy)); } - // Returns true if all recorded indices are false (returns true otherwise). - bool HasAllIndicesFalse() const; + return std::make_pair(from_deep_copy, to_deep_copy); +} - // Records instruction buffer indices which point-to a Parameter or Constant. - Status RecordIndicesWhichPointToParamOrConstant( - const TuplePointsToAnalysis& points_to_analysis); +// Compute the indices of the loop state which need copies in order to avoid +// live range interference. Generally, an element in the loop state does not +// need to be copied if the element is passed through transparently through the +// body. +// +// Returns whether any indices need to be copied. +bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow, + const HloInstruction* xla_while, + ShapeTree* indices_to_copy) { + DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape())); - // Records instruction buffer indices to copy which are necessary to ensure: - // *) PointsToSet of 'instruction_' is unambiguous and distinct. - // *) No liveness interference between 'instruction_' and 'other_instruction'. - // - // If 'read_only_indices_out' is non-null, read-only indices are set to true. - Status RecordIndicesToCopyForColocatingBuffers( - const BufferLiveness& liveness, const HloInstruction* other_instruction, - ShapeTree* read_only_indices_out); + bool any_copies = false; + const HloInstruction* init = xla_while->operand(0); + for (auto& pair : *indices_to_copy) { + const ShapeIndex& index = pair.first; + bool& should_copy = pair.second; + // If there is any ambiguity, then loop state must be copied. 
+ if (dataflow.GetValueSet(init, index).values().size() > 1 || + dataflow.GetValueSet(xla_while, index).values().size() > 1) { + should_copy = true; + } else { + // If the output of the while instruction is not the same as the init + // value of the while, then this element is not passed through the body + // transparently and must be copied. + should_copy = dataflow.GetUniqueValueAt(xla_while, index) != + dataflow.GetUniqueValueAt(init, index); + } + any_copies |= should_copy; + } + return any_copies; +} - // Records control predecessors to add for inserted copy instructions. - // 'parameter' must have the same shape as the instruction that will be - // copied, and must define all buffers in the shape. Control predecessors are - // only recorded for indices that have already been marked for copying. - Status RecordControlPredecessors( - const TuplePointsToAnalysis& points_to_analysis, - HloInstruction* parameter); +// Add kCopy instructions around the given kWhile instruction to eliminate any +// possible live range interference of HLO values assuming a dependency-based +// ordering (HloDependencyOrdering). Copies are added conservatively. There +// likely are copies which are not strictly necessary, but these are removed +// later in the pass via CopyRemover. +// +// +// Elements (each ShapeIndex) in the loop state are considered independently. A +// copy is added to each element of the loop state which is modified in the +// while body. For each such element, a total of three kCopy instructions are +// added at the following locations: +// +// (1) The init value is copied before the kWhile instruction. Before: +// +// (Init) +// | +// kWhile +// | +// ... +// +// After: +// +// (Init) +// | +// kCopy +// | +// kWhile +// | +// ... +// +// This copy is necessary in case the init value is simultaneously live +// with the kWhile. +// +// (2) Copies are added to the parameter and root of the while body +// computation. Before: +// +// kParameter +// | +// ...
+// | +// (body root) +// +// After: +// +// kParameter +// | +// kCopy ----------+ +// | | +// ... ctrl +// | edge +// (body root) | +// | | +// kCopy <---------+ +// +// The root kCopy becomes the new root of the computation. Both copies are +// necessary to prevent any potential interference between the parameter value and +// the root value. The control edge prevents potential interference +// between the copies themselves. +// +// If the loop state is a tuple then the above kCopy instructions are a deep +// copy constructed of kCopy, kGetTupleElement, and kTuple instructions as +// constructed by HloInstruction::DeepCopyInstruction. +Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, + HloInstruction* xla_while) { + VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name(); + TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile); - // Inserts copies of 'instruction' buffers at indices in 'indices_to_copy', - // and replaces all uses for instructions in 'copy_users_' with copy. - // Returns the instruction which is a copy 'instruction'. - HloInstruction* Copy(); + ShapeTree indices_to_copy(xla_while->shape()); + if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while, + &indices_to_copy)) { + VLOG(2) << "No copies necessary for kWhile instruction " + << xla_while->name(); + return Status::OK(); + } - HloInstruction* instruction() { return instruction_; } + VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:"; + for (auto& pair : indices_to_copy) { + if (pair.second) { + VLOG(2) << " " << pair.first; + } + } - const std::vector& copy_users() const { return copy_users_; } + // Deep copy init.
+ HloInstruction* while_init = xla_while->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * while_init_copy, + xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy)); + TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy)); - private: - // Does the given index represent a read-only buffer? - bool IsReadOnlyIndex(const ShapeIndex& index) const { - return !ShapeUtil::IsNil(read_only_indices_.shape()) && - read_only_indices_.element(index); - } + // Deep copy the parameter and the root. Extend a control edge from the copy + // of the parameter value to the corresponding copy value of the root. + HloComputation* body = xla_while->while_body(); + HloInstruction* param = body->parameter_instruction(0); + HloInstruction* root = body->root_instruction(); - // Returns the copy override at the given index, or nullptr. - HloInstruction* GetCopyOverride(const ShapeIndex& index) const { - return ShapeUtil::IsNil(copy_overrides_.shape()) - ? nullptr - : copy_overrides_.element(index); - } + // If param is the root then all indices should have been passed through the + // while body and we should have returned early above. + TF_RET_CHECK(param != root); - // Records instruction buffer indices which have ambiguous or non-distinct - // points-to sets. - Status RecordAmbiguousOrNonDistinctIndices( - const TuplePointsToAnalysis& points_to_analysis); + // Copy users before making a deep copy of the parameter as the deep copy + // will create new users of the parameter (eg, the GTE instructions of the + // deep copy). + std::vector param_users = param->users(); - // Records instruction buffer indices which have interfering live ranges - // with 'other_instruction' buffers at same index. 
- Status RecordIndicesWhichInterfereWithOtherInstruction( - const BufferLiveness& liveness, const HloInstruction* other_instruction, - ShapeTree* read_only_indices_out); + ShapeIndex current_index; + TF_ASSIGN_OR_RETURN(auto pair, + DeepCopyAndAddControlEdges(param, root, indices_to_copy)); - // Recursively inserts copies of 'instruction' tuple elements at indices - // specified in 'indices_to_copy', and returns the copy of 'instruction'. - HloInstruction* CopyTuple(HloInstruction* instruction, ShapeIndex* index); + HloInstruction* param_copy = pair.first; + HloInstruction* root_copy = pair.second; - void RecordIndex(const ShapeIndex& index) { - *indices_to_copy_.mutable_element(index) = true; + for (HloInstruction* user : param_users) { + TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy)); } - HloInstruction* instruction_; - const std::vector copy_users_; - ShapeTree indices_to_copy_; - ShapeTree> control_predecessors_; - ShapeTree read_only_indices_; - ShapeTree copy_overrides_; -}; + body->set_root_instruction(root_copy); -bool InstructionCopier::HasAllIndicesFalse() const { - bool all_indices_false = true; - indices_to_copy_.ForEachElement( - [&all_indices_false](const ShapeIndex& /*index*/, bool data) { - if (data) { - all_indices_false = false; - } - }); - return all_indices_false; + return Status::OK(); } -Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant( - const TuplePointsToAnalysis& points_to_analysis) { - const PointsToSet& points_to = - points_to_analysis.GetPointsToSet(instruction_); - // Shallow copy the instruction if the points-to set of the top-level - // buffer is ambiguous. This is necessary because the backends must know - // statically what the top-level buffer of the result is. - if (points_to.element(/*index=*/{}).size() > 1) { - RecordIndex({}); +// Removes any control dependencies to or from the given instruction. 
+Status StripControlDependenciesFrom(HloInstruction* instruction) { + while (!instruction->control_successors().empty()) { + TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo( + instruction->control_successors().front())); + } + + while (!instruction->control_predecessors().empty()) { + TF_RETURN_IF_ERROR( + instruction->control_predecessors().front()->RemoveControlDependencyTo( + instruction)); } - // Multiple buffers within a parameter/constant may be live out, so collect - // a set of indices at which to copy first. - points_to.ForEachElement([this](const ShapeIndex& index, - const PointsToSet::BufferList& buffers) { - if (IsReadOnlyIndex(index)) { - return; - } - for (const LogicalBuffer* buffer : buffers) { - // pointee is the HloInstruction producing the buffer which may be - // liveout. - HloInstruction* pointee = buffer->instruction(); - if (pointee->opcode() == HloOpcode::kParameter || - pointee->opcode() == HloOpcode::kConstant) { - VLOG(2) << "Parameter or constant buffer " << buffer->ToString() - << " index: " << tensorflow::str_util::Join(index, ",") - << " may be live out of computation: " << pointee->ToString(); - RecordIndex(index); - break; - } - } - }); return Status::OK(); } -Status InstructionCopier::RecordIndicesToCopyForColocatingBuffers( - const BufferLiveness& liveness, const HloInstruction* other_instruction, - ShapeTree* read_only_indices_out) { - TF_RETURN_IF_ERROR( - RecordAmbiguousOrNonDistinctIndices(liveness.points_to_analysis())); - TF_RETURN_IF_ERROR(RecordIndicesWhichInterfereWithOtherInstruction( - liveness, other_instruction, read_only_indices_out)); +// Add kCopy instructions to the given module to guarantee there is no +// live-range interference. Generally interference can only occur around kWhile +// instructions which have update-in-place semantics. 
+Status AddCopiesToResolveInterference(HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile) { + TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); + } + } + } return Status::OK(); } -Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( - const TuplePointsToAnalysis& points_to_analysis) { - const PointsToSet& points_to = - points_to_analysis.GetPointsToSet(instruction_); - // Mapping from LogicalBuffer to index (used to detect non-distinct indices). - FlatMap> - buffer_to_source_indices; - points_to.ForEachElement( - [this, &buffer_to_source_indices]( - const ShapeIndex& index, const PointsToSet::BufferList& buffers) { - if (buffers.size() > 1) { - // Record ambiguous points-to set at 'index'. - if (!indices_to_copy_.element(index)) { - VLOG(2) << "Adding copy of buffer for instruction: " - << instruction_->name() - << " at index: " << tensorflow::str_util::Join(index, ",") - << " with ambiguous points-to set."; - RecordIndex(index); +// Class for removing unnecessary copies from the module. +// +// kCopy instructions are added conservatively to guarantee no live range +// interference between HLO values. This class uses a more fine-grained analysis +// to remove some of these added copies which are not strictly necessary. +class CopyRemover { + public: + CopyRemover(const HloAliasAnalysis& alias_analysis, + const HloOrdering& ordering, HloModule* module) + : module_(module), + alias_analysis_(alias_analysis), + ordering_(ordering), + buffer_value_tracker_(*module, alias_analysis, ordering) {} + + // Try to elide the given copy. The copy is elided if the instruction is not + // necessary to prevent live-range interference of HLO values. Returns true if + // copy was elided. 
+ // + // The copy instruction is not actually removed here. Instead it is left for + // dead in the graph. Later calls to DCE will remove the instruction. + StatusOr TryElideCopy(HloInstruction* copy) { + if (buffer_value_tracker_.TryElideCopy(copy)) { + TF_RETURN_IF_ERROR(StripControlDependenciesFrom(copy)); + TF_RETURN_IF_ERROR(copy->ReplaceAllUsesWith(copy->mutable_operand(0))); + return true; + } + return false; + } + + string ToString() const { + string out = StrCat("CopyRemover, module ", module_->name(), "\n"); + StrAppend(&out, " Buffer values, in dependency order:\n"); + for (const HloBuffer& buffer : alias_analysis_.buffers()) { + StrAppend(&out, " HloBuffer ", buffer.id(), ":\n"); + } + return out; + } + + private: + // Class which tracks the HLO values within each HLO buffer in the module + // during copy removal. + // + // The values are held in a linked list where there is one list for each + // buffer. Removing a copy instruction merges together the values in the + // source buffer of the copy to the destination buffer of the copy. This class + // tracks these value lists as copies are removed from the graph (and value + // lists are merged). + // + // The BufferValueTracker object is initialized to match the state of + // HloAliasAnalysis. However, as copies are removed this state diverges. The + // values-to-buffer mapping is maintained outside of HloAliasAnalysis because + // a fully updatable alias analysis is very slow. + class BufferValueTracker { + public: + // The values held in a single HLO buffer are represented using a linked + // list. An element type in this list is ValueNode. + // + // This linked list is hand-rolled to enable efficient splicing of lists + // using only references to list elements without knowing which lists are + // being spliced. std::list requires a reference to the list object to + // splice. 
+ struct ValueNode { + explicit ValueNode(const HloValue* v) : value(v) {} + + const HloValue* value; + + // The uses are maintained outside of HloValue::uses() because + // HloValue::uses() is not updatable (a fully updatable dataflow analysis + // is slow). + std::vector uses; + + // next/prev elements in the linked list. The list is circularly linked so + // these values are never null for elements in the list. + ValueNode* prev = nullptr; + ValueNode* next = nullptr; + }; + + BufferValueTracker(const HloModule& module, + const HloAliasAnalysis& alias_analysis, + const HloOrdering& ordering) + : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) { + // Construct a list for each HLO buffer in the alias analysis. Maintain a + // map from HloValue to the respective list element representing that + // value. The map is used to construct the copy info map below. + tensorflow::gtl::FlatMap value_to_node; + for (const HloBuffer& buffer : alias_analysis.buffers()) { + // Verify values contained in the buffer are strictly ordered. This + // should always be the case after adding copies to eliminate + // interference. Specifically, the addition of the control flow edges + // between copies added around aliased operations (kWhile) guarantees + // this strict order. + for (const HloValue* value_a : buffer.values()) { + for (const HloValue* value_b : buffer.values()) { + if (value_a != value_b) { + DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, + dataflow_) || + ordering_.LiveRangeStrictlyBefore(*value_b, *value_a, + dataflow_)) + << value_a->ToShortString() << " and " + << value_b->ToShortString() << " are not ordered"; + } } } - // For each 'buffer': record a mapping from 'buffer' to 'index'. - for (const LogicalBuffer* buffer : buffers) { - buffer_to_source_indices[buffer].push_back(index); - } - }); - // Record all non-distinct indices detected in 'buffer_to_source_indices'. 
- for (const auto& buff_to_src : buffer_to_source_indices) { - if (buff_to_src.second.size() == 1) { - continue; + std::vector values = buffer.values(); + std::sort(values.begin(), values.end(), + [this](const HloValue* a, const HloValue* b) { + return ordering_.IsDefinedBefore(*a, *b); + }); + + // Create a list containing all of the values in the buffer. + AddValueList(values, &value_to_node); + } + + // Create copy_map_ which contains the source and destination values + // of all copies. + CreateCopyMap(module, value_to_node); + + XLA_VLOG_LINES(3, ToString()); + TF_DCHECK_OK(Verify()); } - for (const ShapeIndex& src_index : buff_to_src.second) { - // Record non-distinct points-to set at 'src_index'. - if (!indices_to_copy_.element(src_index)) { - VLOG(2) << "Adding copy of buffer for instruction: " - << instruction_->name() - << " at index: " << tensorflow::str_util::Join(src_index, ",") - << " because of non-distinct points-to set."; - RecordIndex(src_index); + + // Add a list containing the given values to BufferValueTracker. This + // represents the values contained in a single buffer. For each value in + // 'values' an entry is created in value_to_node which indicates the + // respective ValueNode representing that value. + void AddValueList( + tensorflow::gtl::ArraySlice values, + tensorflow::gtl::FlatMap* value_to_node) { + ValueNode* tail = nullptr; + ValueNode* head = nullptr; + for (const HloValue* value : values) { + auto new_node = new ValueNode(value); + (*value_to_node)[value] = new_node; + + // Copy the HLO values's uses into the ValueNode for the value. These + // uses in ValueNode are updated as copies are removed. + new_node->uses.reserve(value->uses().size()); + for (const HloUse& use : value->uses()) { + new_node->uses.push_back(&use); + } + + // Connect the new node into the linked list. 
+ if (tail == nullptr) { + head = new_node; + } else { + tail->next = new_node; + new_node->prev = tail; + } + tail = new_node; } + + // The linked list is circular so connect the head and tail. + tail->next = head; + head->prev = tail; + value_lists_.insert(head); } - } - return Status::OK(); -} -Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( - const BufferLiveness& liveness, const HloInstruction* other_instruction, - ShapeTree* read_only_indices_out) { - // Record all buffer indices for 'instruction_', which interfere with - // 'other_instruction' at the same index. - ShapeUtil::ForEachSubshape( - instruction_->shape(), - [this, &liveness, other_instruction, read_only_indices_out]( - const Shape& /*subshape*/, const ShapeIndex& index) { - if (IsReadOnlyIndex(index)) { - return; + // This method also fills in copy_map_ which indicates which nodes + // in the value lists corresponding to the source and destination values of + // kCopy instructions. value_to_node should map each HloValue to its + // respective ValueNode. + void CreateCopyMap( + const HloModule& module, + const tensorflow::gtl::FlatMap& + value_to_node) { + for (HloComputation* computation : module.computations()) { + for (HloInstruction* instruction : computation->instructions()) { + // Add copies with unambiguous source values to the map. Copies with + // ambiguous sources are not removable. + if (instruction->opcode() == HloOpcode::kCopy) { + const HloValueSet& src_value_set = + dataflow_.GetValueSet(instruction->operand(0)); + if (src_value_set.values().size() == 1) { + CopyNodes& copy_node = copy_map_[instruction]; + copy_node.dest = + value_to_node.at(&dataflow_.GetUniqueValueAt(instruction)); + copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue()); + } + } } - if (indices_to_copy_.element(index)) { - // Return if previous pass already set index. 
- return; + } + } + + ~BufferValueTracker() { + for (const ValueNode* head : value_lists_) { + const ValueNode* p = head; + do { + const ValueNode* tmp = p->next; + delete p; + p = tmp; + } while (p != head); + } + } + + // Verify invariants within the linked lists. + Status Verify() const { + for (const ValueNode* head : value_lists_) { + const ValueNode* p = head; + do { + // Verify links between elements are consistent. + TF_RET_CHECK(p->prev->next == p); + TF_RET_CHECK(p->next->prev == p); + + const HloInstruction* def = p->value->defining_instruction(); + if (def->opcode() == HloOpcode::kCopy && + ContainsKey(copy_map_, def)) { + TF_RET_CHECK(copy_map_.at(def).dest == p); + } + for (const HloUse* use : p->uses) { + if (use->instruction->opcode() == HloOpcode::kCopy && + ContainsKey(copy_map_, use->instruction)) { + TF_RET_CHECK(copy_map_.at(use->instruction).src == p); + } + } + + p = p->next; + } while (p != head); + } + return Status::OK(); + } + + // Try to elide the given copy. Elision of a copy is possible only if no + // live range interference is introduced by the copy's elimination. If + // elision is possible, then the internal state (value lists) are updated, + // and true is returned. Returns false otherwise. 
+ bool TryElideCopy(const HloInstruction* copy) { + VLOG(2) << "Trying to remove " << copy->name(); + + if (!ContainsKey(copy_map_, copy)) { + VLOG(2) << copy->name() << " is not removable"; + return false; + } + + const CopyNodes& copy_node = copy_map_.at(copy); + ValueNode* src = copy_node.src; + ValueNode* dest = copy_node.dest; + DCHECK(src != nullptr); + DCHECK(dest != nullptr); + + auto is_live_range_before = [this](const ValueNode& a, + const ValueNode& b) { + if (LiveRangeBefore(a, b)) { + VLOG(2) << " Live range of " << a.value->ToShortString() + << " is before " << b.value->ToShortString(); + return true; + } else { + VLOG(2) << " Live range of " << a.value->ToShortString() + << " is not before " << b.value->ToShortString(); + return false; } - const auto& points_to_analysis = liveness.points_to_analysis(); - // Lookup buffers for 'instruction_' and 'other_instruction'. - const auto instruction_buffers = - points_to_analysis.GetPointsToSet(instruction_).element(index); - // If 'instruction_' has ambiguous points-to-set at 'index', it would - // have been recorded in a previous pass (and we would have returned - // early at the entry to this function). As a result, here we know that - // 'instruction_' has just one buffer in its points-to-set. - CHECK_EQ(1, instruction_buffers.size()); - const LogicalBuffer* instruction_buffer = instruction_buffers[0]; - - const auto other_instruction_buffers = - points_to_analysis.GetPointsToSet(other_instruction).element(index); - // Do not insert a copy if both instructions point at the same buffer. - // This eliminates unnecessary copies of read-only tuple elements. - // If 'instruction_' and 'other_instruction' point to the same buffer, - // then that buffer is not updated on the path between the two - // instructions. 
Therefore, any other (possibly interference-causing) - // users of that buffer from 'other_instruction' will see the same data, - // irrespective of whether we insert a copy of this buffer at - // 'instruction_' or not. - if (other_instruction_buffers.size() == 1 && - other_instruction_buffers[0]->id() == instruction_buffer->id()) { - if (read_only_indices_out != nullptr) { - *read_only_indices_out->mutable_element(index) = true; + }; + + VLOG(3) << copy->name() << " copies value " + << src->value->ToShortString(); + VLOG(3) << "Source buffer values: " << ValueListToString(src); + VLOG(3) << "Dest buffer values: " << ValueListToString(dest); + + // A kCopy instruction copies an HLO value from a source buffer and + // defines an HLO value in a destination buffer. Most generally, the + // source and destination buffers may each hold more than one value at + // different points in the computation so we define the following: + // + // Values in source buffer: {s_0, ..., s_n} + // Values in destination buffer: {d_0, ..., d_m} + // + // A kCopy instruction between these buffers copies a value s_x in the + // source buffer and defines a value d_y in the destination buffer. The + // elision of a copy merges the source and destination buffers together, + // so the list of values for the source and destination buffers are + // merged. + // + // We handle two different cases for copy elision: + // + // (1) the kCopy defines the first value in the destination buffer (d_0). + // + // (2) the kCopy copies the last value in the source buffer (s_n). + // + // For the remaining case where the kCopy copies a not-last value from the + // source buffer to a not-first value of the destination buffer, the kCopy + // instruction cannot be removed. This case is generated, for example, if + // the kCopy copies a while body parameter of the loop state at one tuple + // index to a different tuple index in the while body root.
Removal of the + // copy necessarily results in live range interference of values in the + // loop state at the two different tuple indices. + // + // We can only perform copy elision if the resulting merged values have + // totally ordered live ranges; otherwise the merged buffer would have + // live range interference. + if (IsHead(*dest)) { + // The copy copies an arbitrary value in the source buffer (call it s_x) + // and defines d_0, the first value in the destination buffer. After + // merging, the values in the combined buffer must be strictly ordered + // as follows** to elide the copy: + // + // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n} + // + // Removing the copy eliminates d_0, and uses of d_0 become uses of + // s_x. In the above ordering, the live range of d_m must be ordered + // before the live range of s_{x+1} and the definition and all uses of + // s_x must be ordered before the definition of d_1. These conditions + // are checked below prior to elision. + // + // ** Technically it might be possible to have a non-interfering + // non-trivial interleaving of the values of the source and + // destination buffers in the resulting order. However, this case is + // slow and complicated to check and likely not worth it. So instead + // we simply check for the case where *all* values of the destination + // buffer (d_1 through d_m) are spliced into the point where the copy + // used to be. + VLOG(2) << copy->name() << " defines the first value in its buffer"; + ValueNode* next_dest = Next(*dest); + if (next_dest != nullptr) { + // Live range of 'from' value (s_x) must be before 'next_dest' (d_1); + if (!is_live_range_before(*src, *next_dest)) { + return false; } - return; } - // We can't say anything about the ambiguity of 'other_instruction' at - // this point, so we need to check interference between the single - // buffer in the points-to set of 'instruction_' and all buffers in - // 'other_instruction_buffers'. 
- for (const LogicalBuffer* other_buffer : other_instruction_buffers) { - if (liveness.MayInterfere(*instruction_buffer, *other_buffer)) { - VLOG(2) << "Adding copy of buffer for instruction: " - << instruction_->name() - << " instruction_buffer: " << instruction_buffer->ToString() - << " at index: " << tensorflow::str_util::Join(index, ",") - << " because of interference with buffer: " - << other_buffer->ToString(); - RecordIndex(index); - break; + ValueNode* next_src = Next(*src); + + if (next_src != nullptr) { + // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}. + ValueNode* last_dest = dest->prev; + DCHECK(IsTail(*last_dest)); + if (!is_live_range_before(*last_dest, *next_src)) { + return false; } } - }); - return Status::OK(); -} -// This is called when 'instruction_' is a while body root, and 'parameter' is -// the while body parameter. We record all users of all aliases of 'parameter' -// as control predecessors, so that when we add a copy of 'instruction_', we can -// mark the control dependencies. This is necessary because points-to and -// liveness analysis doesn't know about the aliasing between the while body root -// and param. Without these control dependencies, the copy might get scheduled -// to run at a point that interferes with users of the buffer. 
-Status InstructionCopier::RecordControlPredecessors( - const TuplePointsToAnalysis& points_to_analysis, - HloInstruction* parameter) { - return indices_to_copy_.ForEachElementWithStatus( - [this, &points_to_analysis, parameter](const ShapeIndex& index, - bool will_copy) { - if (will_copy) { - TF_ASSIGN_OR_RETURN( - const LogicalBuffer* buffer, - points_to_analysis.GetBufferDefinedAt(parameter, index)); - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), - user, points_to_analysis)) { - continue; - } - - if (user != instruction_) { - control_predecessors_.mutable_element(index)->push_back(user); - } - } + // Splice in destination buffer values list right after 'src'. + SpliceAfter(dest, src); + } else if (IsTail(*src)) { + // The copy copies the last value in the source buffer, s_n, and defines + // an arbitrary value in the destination buffer, d_y. After + // merging, the values in the combined buffer must be strictly ordered + // as follows** to elide the copy: + // + // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m} + // + // Removing the copy eliminates d_y, and uses of d_y become uses of + // s_n. To enforce the above order, the live range of d_{y-1} must be + // before the live range of s_0, and the live range of s_n must be + // before the live range of d_{y+1}. + // + // ** See comment above in the code handling Case (1). + VLOG(2) << copy->name() << " copies the last value (" + << src->value->ToShortString() << ") in its buffer"; + + ValueNode* prev_dest = Prev(*dest); + // nullptr condition handled above in the first 'if' case. + DCHECK(prev_dest != nullptr); + ValueNode* first_src = src->next; + DCHECK(IsHead(*first_src)); + if (!is_live_range_before(*prev_dest, *first_src)) { + // Live range of value d_{y-1} is not before s_0. 
+ return false; + } + ValueNode* next_dest = Next(*dest); + if (next_dest != nullptr) { + if (!is_live_range_before(*src, *next_dest)) { + // Live range of value s_n is not before d_{y+1}. + return false; } } - return Status::OK(); - }); -} -// Recursively inserts copies of 'instruction' tuple element buffers at -// indices in 'indices_to_copy_', expanding tuples as needed. -HloInstruction* InstructionCopier::CopyTuple(HloInstruction* instruction, - ShapeIndex* index) { - const int64 num_tuple_elements = - ShapeUtil::TupleElementCount(instruction->shape()); - std::vector elem_copies(num_tuple_elements); - for (int64 i = 0; i < num_tuple_elements; ++i) { - HloInstruction* elem; - if (instruction->opcode() == HloOpcode::kTuple) { - // If the instruction is already a Tuple instruction, we know that the - // element buffers are aliased, so we can just grab the operand directly. - elem = instruction->mutable_operand(i); - } else { - // Otherwise we need to add a GTE to unpack the element out of the tuple. - elem = instruction->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, - i)); - } - index->push_back(i); - if (ShapeUtil::IsTuple(elem->shape())) { - elem_copies[i] = CopyTuple(elem, index); - } else if (!indices_to_copy_.element(*index)) { - elem_copies[i] = elem; - } else if (HloInstruction* copy_override = GetCopyOverride(*index)) { - elem_copies[i] = copy_override; - } else { - HloInstruction* elem_copy = elem->parent()->AddInstruction( - HloInstruction::CreateUnary(elem->shape(), HloOpcode::kCopy, elem)); - for (HloInstruction* control_predecessor : - control_predecessors_.element(*index)) { - VLOG(2) << "Adding control dependency from " - << control_predecessor->ToString() << " to " - << elem_copy->ToString(); - TF_CHECK_OK(control_predecessor->AddControlDependencyTo(elem_copy)); + // Splice source buffer values list right after 'prev_dest'. 
+ SpliceAfter(first_src, prev_dest); + } else { + VLOG(2) + << copy->name() + << " copies value in middle of source buffer to value in middle " + "of destination buffer"; + return false; } - elem_copies[i] = elem_copy; + + RemoveCopyValue(dest); + + XLA_VLOG_LINES(4, ToString()); + TF_DCHECK_OK(Verify()); + + return true; } - index->pop_back(); - } - return instruction->parent()->AddInstruction( - HloInstruction::CreateTuple(elem_copies)); -} -// Inserts copies of 'instruction_' buffers at indices in 'indices_to_copy_'. -HloInstruction* InstructionCopier::Copy() { - ShapeIndex index; - HloInstruction* copy; - if (ShapeUtil::IsTuple(instruction_->shape())) { - copy = CopyTuple(instruction_, &index); - } else { - copy = instruction_->parent()->AddInstruction(HloInstruction::CreateUnary( - instruction_->shape(), HloOpcode::kCopy, instruction_)); - } - for (HloInstruction* user : copy_users_) { - VLOG(2) << "Adding copy between instruction: " << instruction_->name() - << " and user: " << user->name(); - TF_CHECK_OK(instruction_->ReplaceUseWith(user, copy)); + // Delete the given ValueNode associated with a elided kCopy + // instruction. This should be called after splicing the value lists of the + // source and destination buffers together. + void RemoveCopyValue(ValueNode* copy_value_node) { + CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(), + HloOpcode::kCopy); + ValueNode* operand_node = copy_value_node->prev; + CHECK(operand_node != copy_value_node); + + VLOG(2) << "Removing copy " << operand_node->value->ToShortString() + << " => " << copy_value_node->value->ToShortString(); + + // Splice out the copy value node. + operand_node->next = copy_value_node->next; + copy_value_node->next->prev = operand_node; + + // Patch up uses. Remove use of copy from operand_node uses. 
+ auto it = + std::find_if(operand_node->uses.begin(), operand_node->uses.end(), + [copy_value_node](const HloUse* use) { + return use->instruction == + copy_value_node->value->defining_instruction(); + }); + CHECK(it != operand_node->uses.end()); + operand_node->uses.erase(it); + + // If the elided copy has any uses which are themselves kCopy instructions + // then patch up the copy info to reflect the that this kCopy instruction + // has a different operand (the operand of the elided copy). + for (const HloUse* copy_use : copy_value_node->uses) { + operand_node->uses.push_back(copy_use); + if (copy_use->instruction->opcode() == HloOpcode::kCopy) { + copy_map_.at(copy_use->instruction).src = operand_node; + } + } + + // Delete the copy info and the value node. + copy_map_.erase(copy_value_node->value->defining_instruction()); + delete copy_value_node; + } + + // Returns true if the live range of given value 'a' is before the live + // range of 'b'. + // + // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not + // updated as copies are removed. + bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { + if (a.uses.empty()) { + VLOG(2) << "Empty uses"; + return ordering_.IsDefinedBefore(*a.value, *b.value); + } + for (const HloUse* use : a.uses) { + VLOG(2) << "use: " << *use; + VLOG(2) << "is before:" << *b.value; + if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { + VLOG(2) << "Not before"; + return false; + } + } + return true; + } + + // Returns whether 'node' is the last node in its list. + bool IsTail(const ValueNode& node) const { + return ContainsKey(value_lists_, node.next); + } + + // Returns whether 'node' is the first node in its list. + bool IsHead(const ValueNode& node) const { + return ContainsKey(value_lists_, &node); + } + + // Returns the next node in the list after 'node'. If 'node' is the + // tail, then nullptr is returned. 
+ ValueNode* Next(const ValueNode& node) const { + if (IsTail(node)) { + return nullptr; + } else { + return node.next; + } + } + + // Returns the previous node in the list before 'node'. If 'node' + // is the head, then nullptr is returned. + ValueNode* Prev(const ValueNode& node) const { + if (IsHead(node)) { + return nullptr; + } else { + return node.prev; + } + } + + // Splices the entire linked list with 'head' as its head right after the + // node 'insert_after' in another linked list. + void SpliceAfter(ValueNode* head, ValueNode* insert_after) { + DCHECK(IsHead(*head)); + value_lists_.erase(head); + + ValueNode* tail = head->prev; + tail->next = insert_after->next; + insert_after->next->prev = tail; + + insert_after->next = head; + head->prev = insert_after; + } + + string ValueListToString(const ValueNode* element) { + const ValueNode* head = element; + while (!IsHead(*head)) { + head = Prev(*head); + } + std::vector values; + for (const ValueNode* p = head; p != nullptr; p = Next(*p)) { + values.push_back(p->value); + } + return StrCat("{", + Join(values, ", ", + [](string* s, const HloValue* value) { + StrAppend(s, value->ToShortString()); + }), + "}"); + } + + string ToString() const { + string out = StrCat("BufferValueTracker:\n"); + StrAppend(&out, " Def-use chains in each buffer:\n"); + for (const ValueNode* head : value_lists_) { + StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), + ":\n"); + const ValueNode* p = head; + do { + StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", + Join(p->uses, "; ", + [](string* s, const HloUse* use) { + StrAppend(s, use->ToString()); + }), + "\n"); + + p = p->next; + } while (p != head); + } + StrAppend(&out, " Potentially removable copies:\n"); + for (const auto& pair : copy_map_) { + const HloInstruction* copy = pair.first; + const CopyNodes& copy_info = pair.second; + + StrAppend(&out, " ", copy->name(), " : ", + copy_info.src->value->ToShortString(), " => ", + 
copy_info.dest->value->ToShortString(), "\n"); + } + return out; + } + + private: + const HloDataflowAnalysis& dataflow_; + const HloOrdering& ordering_; + + // The heads of all the value lists. Each value list represents the HLO + // values contained in a particular HLO buffer. The values in the list are + // in dependency order. + tensorflow::gtl::FlatSet value_lists_; + + // Copy removal requires fast access to the value list elements + // corresponding to the source and destination values of the kCopy + // instruction. This data structure holds pointers to these elements for + // each kCopy instruction in the graph. + struct CopyNodes { + // The source and destinations values of the kCopy instruction. + ValueNode* src = nullptr; + ValueNode* dest = nullptr; + }; + tensorflow::gtl::FlatMap copy_map_; + }; + + HloModule* module_; + const HloAliasAnalysis& alias_analysis_; + const HloOrdering& ordering_; + + // Object tracking the HLO values contained in each HLO buffer. + BufferValueTracker buffer_value_tracker_; +}; + +// Try to remove as many copies from the module as possible without introducing +// live range interference. Copy instructions (identified by their unique id) in +// the set copies_to_exclude are not considered for removal. 
+Status RemoveUnnecessaryCopies( + const HloOrdering& ordering, + const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + CopyRemover copy_remover(*alias_analysis, ordering, module); + XLA_VLOG_LINES(3, copy_remover.ToString()); + + tensorflow::gtl::FlatSet existing_copies; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy && + !ContainsKey(copies_to_exclude, instruction->unique_id())) { + TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); + } + } } - return copy; + + return Status::OK(); } -// The 'read_only_indices' are initialized based on points-to analysis on the -// while body corresponding to 'while_hlo'. If the init buffer corresponding to -// a read-only index aliases with a constant, it cannot be considered read-only, -// and must be copied. This is necessary because BufferAssignment does not -// currently assign an allocation for constants (b/32248867). -// This function performs this fix-up of 'read_only_indices'. +// Add copies to address special constraints on the roots of computations not +// related to live range interference: // -// Returns a ShapeTree of copy_overrides, which implements an optimization to -// allow multiple while loops that share the same read-only constants to -// share a single copy. -StatusOr> RevertReadOnlyIndicesForConstants( - const HloInstruction* while_hlo, - const TuplePointsToAnalysis& points_to_analysis, - ShapeTree* read_only_indices, - FlatMap* shared_copies) { - const HloInstruction* init_hlo = while_hlo->operand(0); - const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo); - - // Mapping from LogicalBuffer to index (used to detect non-distinct indices). 
- FlatSet buffer_set; - - ShapeTree copy_overrides(init_hlo->shape()); - points_to.ForEachElement([init_hlo, read_only_indices, shared_copies, - &buffer_set, ©_overrides]( - const ShapeIndex& index, - const PointsToSet::BufferList& buffers) { - // Look for read-only entry parameters. - if (!read_only_indices->element(index)) { - return; - } - for (const LogicalBuffer* buffer : buffers) { - HloInstruction* pointee = buffer->instruction(); - const bool is_constant = pointee->opcode() == HloOpcode::kConstant; - if (!is_constant) { - continue; - } +// (1) Entry computation root must be unambiguous and distinct. +// +// (2) Any computation called by a kCall instruction must have an +// unambiguous root. +// +// (3) Constants and parameters cannot be live out of the entry computation +// +Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + + // Identify which shape indices of which instructions need to be copied. Store + // these results in 'instructions_to_copy'. + std::unordered_map> instructions_to_copy; + auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction, + const ShapeIndex& index) { + auto it = instructions_to_copy.find(instruction); + if (it == instructions_to_copy.end()) { + auto it_added = instructions_to_copy.emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/false)); + it = it_added.first; + } + *it->second.mutable_element(index) = true; + }; - // We have found an constant that is read-only in - // the while body. These buffers are managed by the caller, and cannot - // be aliased with HLO buffers. Revert this read-only index, - // to allow it to be copied. - *read_only_indices->mutable_element(index) = false; - - // Optimization to allow multiple while loops that share the same - // read-only entry constants to share a single copy. 
- // Only unambiguous and distinct array-shaped buffers are allowed, to - // reduce code complexity. The shape of the entry parameter must be - // identical to the shape of the init_hlo at this index, to ensure - // there were no intervening bitcast or GTE instructions, which are - // also hard to handle. - const Shape& pointee_shape = pointee->shape(); - const Shape& init_shape = - ShapeUtil::GetSubshape(init_hlo->shape(), index); - if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) && - ShapeUtil::Equal(pointee_shape, init_shape) && - buffer_set.count(buffer) < 1) { - HloInstruction** copy = &(*shared_copies)[pointee]; - if (*copy == nullptr) { - *copy = pointee->parent()->AddInstruction(HloInstruction::CreateUnary( - pointee_shape, HloOpcode::kCopy, pointee)); + // Iterate through values of all constants and entry parameters. These values + // are special because they are held in read-only buffers. If any of these + // values share a buffer with other values (for example, the init value of a + // while is a constant) then copy the value at its definition and replace all + // its uses with the copy. + for (const HloValue* value : alias_analysis->dataflow_analysis().values()) { + if (ValueIsReadOnly(*value) && + alias_analysis->GetBufferContainingValue(*value).values().size() > 1) { + VLOG(2) << "Value " << value->ToShortString() + << " is read only, but its buffer contains more than one value. 
" + "Copying."; + add_index_to_copy(value->defining_instruction(), value->defining_index()); + } + } + + // Identify copies which must be added at root instructions + for (HloComputation* computation : module->computations()) { + const CallGraphNode& node = call_graph.GetNode(computation); + if (node.context() == CallContext::kParallel) { + continue; + } + TF_RET_CHECK(node.context() == CallContext::kSequential); + + const bool is_entry = computation == module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + + // Mark nondistinct/ambiguous indices. + tensorflow::gtl::FlatSet seen; + ShapeUtil::ForEachSubshape( + root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector buffers_at_index = + alias_analysis->ComputeBuffersAt(root, index); + bool buffer_seen_before = false; + for (const HloBuffer* buffer : buffers_at_index) { + buffer_seen_before |= !seen.insert(buffer).second; + } + if (buffers_at_index.size() > 1 || (buffer_seen_before && is_entry)) { + VLOG(2) << "Index " << index << " of root of computation " + << computation->name() << " (" << root->name() + << ") has ambiguous or non-distinct buffer. Copying."; + add_index_to_copy(root, index); + } + }); + + // For entry instructions, mark any parameter or constant values. + if (is_entry) { + for (const auto& pair : + alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { + const ShapeIndex& index = pair.first; + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + if (ValueIsReadOnly(*value)) { + VLOG(2) << "Root of entry computation (" << root->name() + << ") has constant or entry parameter value at index " + << index << ". Copying."; + add_index_to_copy(root, index); + } } - // Add the copy as an override. - *copy_overrides.mutable_element(index) = *copy; } + } + } - // Tracks whether this current buffer is distinct. 
- buffer_set.insert(buffer); + // Add copy instructions indicated in 'instructions_to_copy' to the module. + for (const auto& pair : instructions_to_copy) { + HloInstruction* instruction = pair.first; + const ShapeTree& indices_to_copy = pair.second; - // We've already reverted the read-only index and handled the - // single-copy optimization above, so there's nothing more to do. - break; + std::vector users = instruction->users(); + TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, + instruction->parent()->DeepCopyInstruction( + instruction, &indices_to_copy)); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); + } + if (instruction == instruction->parent()->root_instruction()) { + instruction->parent()->set_root_instruction(deep_copy); } - }); - return copy_overrides; + } + + return Status::OK(); +} + +Status VerifyNoLiveRangeInterference(HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + DependencyHloOrdering ordering(module); + TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering)); + return Status::OK(); } -} // anonymous namespace - -// NOTE: This is only called by gpu::CopyInsertion. It's not called here in the -// base class, since the regular CopyInsertion logic above selectively copies -// tuple elements, while this method assumes all buffers need to be deep copied. 
-StatusOr CopyInsertion::FindOrInsertCopy(HloInstruction* hlo) { - auto copy_it = inserted_copies_.find(hlo); - if (copy_it == inserted_copies_.end()) { - HloInstruction* copy = hlo->parent()->DeepCopyInstruction(hlo).ValueOrDie(); - inserted_copies_.insert({hlo, copy}); - return copy; - } else { - return copy_it->second; +void MaybeDumpModule(const string& message, const HloModule& module) { + if (VLOG_IS_ON(3)) { + VLOG(3) << message; + XLA_VLOG_LINES(3, module.ToString()); + hlo_graph_dumper::MaybeDumpHloModule(module, message); } } +} // namespace + StatusOr CopyInsertion::Run(HloModule* module) { - bool changed = false; - VLOG(2) << "CopyInsertion for module " << module->name(); + // Copy insertion is performed in three steps: + // + // (1) Add copies conservatively to guarantee that there is no live-range + // interference. This is done simplistically and usually results in more + // copies than is strictly necessary. + // + // (2) Using a more fine-grained analysis, remove as many copies that were + // added in (1) as possible while ensuring no live-range interference. + // + // (3) Add copies to resolve issues not related to live range interference + // such as parameters and constants live out of the entry computation. + // + // We add copies then remove them (step (1) then (2)) rather than simply + // adding only the copies that are necessary because, in general, it is + // difficult to figure out the minimal set of copies to add once there is + // interference. On the other hand, it is easy to determine if removing a copy + // will introduce interference. + // + // The final copy insertion in (3) is done separately to simplify the + // implementation of copy removal in (2) which is the most complicated part of + // the pass. As is, copy removal only has to reason about live range + // interference. 
If all copies were added in step (1) then copy removal would + // also have to reason about things like constants and parameters live out of + // the computation. + MaybeDumpModule("before copy insertion", *module); - TF_ASSIGN_OR_RETURN( - std::unique_ptr liveness, - BufferLiveness::Run(module, MakeUnique(module))); - const auto& points_to_analysis = liveness->points_to_analysis(); - XLA_VLOG_LINES(2, points_to_analysis.ToString()); - XLA_VLOG_LINES(2, module->ToString()); - - // Gather all while body computations and while instructions. - FlatSet while_body_computations; - std::vector while_instructions; - for (auto* computation : module->computations()) { + std::unique_ptr call_graph = CallGraph::Build(module); + if (!call_graph->IsFlattened()) { + return FailedPrecondition( + "Call graph must be flattened before copy insertion."); + } + + // Gather Ids of existing kCopy instructions in the module. We avoid removing + // these copies (except via DCE in TupleSimplifier) because they may have been + // added for reasons not considered by copy insertion (eg, layout assignment). + // Instruction id is used instead of HloInstruction* because the pointer + // values may be recycled. + tensorflow::gtl::FlatSet existing_copies; + for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kWhile) { - while_body_computations.insert(instruction->while_body()); - while_instructions.push_back(instruction); + if (instruction->opcode() == HloOpcode::kCopy) { + existing_copies.insert(instruction->unique_id()); } } } - // Collect instruction buffer indices to copy in 'instructions_to_copy'. - std::vector instructions_to_copy; - - // Add copies of computation root instructions, if needed. 
- FlatMap> while_body_read_only_indices; - for (auto* computation : module->MakeNonfusionComputations()) { - VLOG(2) << "computation " << computation->name(); - InstructionCopier root_copier(computation->root_instruction(), - /*copy_users=*/{}); - if (while_body_computations.count(computation) > 0) { - // Record root indices to copy for while body sub-computations. We do not - // need to call RecordIndicesWhichPointToParamOrConstant for the while - // body root instruction here, because any necessary copies needed to - // avoid constants or parameters in the output are handled by while.init - // operand copy insertion below (which will share an allocation). - HloInstruction* while_body_param = computation->parameter_instruction(0); - ShapeTree read_only_indices(while_body_param->shape()); - TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers( - *liveness, while_body_param, &read_only_indices)); - while_body_read_only_indices[computation] = read_only_indices; - - // Mark control predecessors, based on the body param, for any copies - // we'll be inserting. This ensures the copy doesn't run too early. - TF_RETURN_IF_ERROR(root_copier.RecordControlPredecessors( - points_to_analysis, while_body_param)); - } else { - // Record root indices to copy for general computations. - TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant( - points_to_analysis)); + TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module)); + + // Simplify the tuple structures introduced by the deep copies. This should be + // done before removing copies (RemoveUnnecessaryCopies) because tuple + // simplification changes dependencies in the graph which changes live range + // interference in the graph. Also run DCE to remove the dead Tuple/GTE + // instructions introduced by tuple simplification. 
+ TupleSimplifier tuple_simplifier; + HloDCE dce; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + TF_RETURN_IF_ERROR(dce.Run(module).status()); + + TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + + MaybeDumpModule("after adding copies to resolve interference", *module); + + DependencyHloOrdering ordering(module); + TF_RETURN_IF_ERROR( + RemoveUnnecessaryCopies(ordering, existing_copies, module)); + + MaybeDumpModule("after removing unnecessary copies", *module); + + TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); + + MaybeDumpModule("after adding special-case copies", *module); + + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + TF_RETURN_IF_ERROR(dce.Run(module).status()); + TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + + MaybeDumpModule("after copy insertion", *module); + + if (VLOG_IS_ON(1)) { + int64 num_total_copies = 0; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + num_total_copies++; + } + } } - instructions_to_copy.push_back(root_copier); + VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size(); + VLOG(1) << "Num copies after copy-insertion: " << num_total_copies; } - // Add copies of while 'init' operand instructions, if needed. 'shared_copies' - // is used to ensure that multiple while loops can share a single copy of the - // same entry parameter or constant, if all loops use it read-only. - // - // TODO(b/33301720) Remove redundant while instruction copies. - FlatMap shared_copies; - for (HloInstruction* while_hlo : while_instructions) { - // Fix read_only_indices to account for entry constants. Also - // initialize copy_overrides, which ensures a single copy for each read-only - // constant that is used in multiple while loops. 
- ShapeTree* read_only_indices = - &while_body_read_only_indices[while_hlo->while_body()]; - TF_ASSIGN_OR_RETURN( - const ShapeTree copy_overrides, - RevertReadOnlyIndicesForConstants(while_hlo, points_to_analysis, - read_only_indices, &shared_copies)); - // Create InstructionCopier for init operand of while instruction. - HloInstruction* init_hlo = while_hlo->mutable_operand(0); - InstructionCopier init_copier(init_hlo, {while_hlo}); - init_copier.SetReadOnlyIndices(*read_only_indices); - init_copier.SetCopyOverrides(copy_overrides); - // Record 'init' buffer indices which point-to a Constant or Parameter. - TF_RETURN_IF_ERROR(init_copier.RecordIndicesWhichPointToParamOrConstant( - points_to_analysis)); - // Record indices necessary to colocate while and init operand buffers. - TF_RETURN_IF_ERROR(init_copier.RecordIndicesToCopyForColocatingBuffers( - *liveness, while_hlo, /*read_only_indices_out=*/nullptr)); - instructions_to_copy.push_back(init_copier); + return true; +} + +namespace { + +bool IsWhileBody(const HloComputation* computation, + const CallGraph& call_graph) { + const CallGraphNode& node = call_graph.GetNode(computation); + + if (node.context() == CallContext::kSequential && + !node.caller_callsites().empty()) { + // Callgraph should be flattened so sequential context computations can + // have at most one caller. 
+ CHECK_EQ(node.caller_callsites().size(), 1); + const HloInstruction* calling_instruction = + node.caller_callsites()[0].instruction(); + if (calling_instruction->opcode() == HloOpcode::kWhile && + calling_instruction->while_body() == node.computation()) { + return true; + } } + return false; +} - for (InstructionCopier& to_copy : instructions_to_copy) { - if (to_copy.HasAllIndicesFalse()) { +} // namespace + +/* static */ StatusOr CopyInsertion::AddCopiesForBufferAssignment( + HloModule* module) { + std::unique_ptr call_graph = CallGraph::Build(module); + TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, + HloDataflowAnalysis::Run(module)); + + bool changed = false; + + // If a buffer live out of a computation is a constant, a parameter, or not + // defined in the computation, then copy it to account for the limited + // computation-scoped analysis in buffer assignment. An exception to this rule + // is the while body which is handled properly without copies. + for (HloComputation* computation : module->computations()) { + if (computation == module->entry_computation() || + IsWhileBody(computation, *call_graph)) { continue; } - changed = true; - // Copy instruction at recorded buffer indices. 
- HloComputation* computation = to_copy.instruction()->parent(); - HloInstruction* copy = to_copy.Copy(); - if (to_copy.instruction() == computation->root_instruction()) { - computation->set_root_instruction(copy); + HloInstruction* root = computation->root_instruction(); + ShapeTree indices_to_copy(root->shape(), /*init_value=*/false); + bool copy_root = false; + for (const auto& pair : dataflow->GetInstructionValueSet(root)) { + const ShapeIndex& index = pair.first; + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + HloInstruction* def = value->defining_instruction(); + if (def->parent() != computation || + def->opcode() == HloOpcode::kConstant || + def->opcode() == HloOpcode::kParameter) { + *indices_to_copy.mutable_element(index) = true; + copy_root = true; + } + } + } + if (copy_root) { + TF_ASSIGN_OR_RETURN( + HloInstruction * root_copy, + computation->DeepCopyInstruction(root, &indices_to_copy)); + computation->set_root_instruction(root_copy); + changed = true; } } - VLOG(3) << "After copy insertion for module " << module->name(); - XLA_VLOG_LINES(3, module->ToString()); + TupleSimplifier tuple_simplifier; + HloDCE dce; + TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed, + tuple_simplifier.Run(module)); + TF_ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module)); - return changed; + return changed || tuple_simplifier_changed || dce_changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 28bb62e40c7674960dbb1bb63dc8967b06956028..65e3d31e347e2cb249a072e7d06ca10c55401748 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -25,12 +25,25 @@ limitations under the License. namespace xla { -// HLO pass which inserts a copy of the root instruction (creating a new root) -// if the root is or points-to any constant or parameter instruction. 
-// If the root instruction is a Tuple, only tuple elements which point to -// constant or parameter instructions will be copied. -// Copy insertion is necessary because constant and parameter arrays have -// different lifetimes than computation results. +// Copy insertion is a legalization HLO pass which inserts copies (kCopy +// instructions) to eliminate several kinds of problems in the HLO module. +// +// (1) Entry parameter or a constant live out of the entry computation. Entry +// computation arguments and constants have different lifetimes than the +// computation result and cannot share the same allocation. Parameters and +// constants live out of non-entry computations do not need copies. +// +// (2) Different values which are simultaneously live and which must be held +// in the same buffer. This can occur in while bodies. Specifically, the +// while loop state (the arguments to the while instruction) is updated +// in-place and the update may clobber the value from the previous +// iteration before the previous value is dead. Computations called from +// kCall instructions do not need such copies because kCall has no update +// in-place semantics. +// +// (3) The buffer set of the root instruction of the entry computation must be +// unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous returns +// false and InstructionAliasSet::IsDistinct returns true. class CopyInsertion : public HloPassInterface { public: tensorflow::StringPiece name() const override { return "copy-insertion"; } @@ -39,14 +52,16 @@ class CopyInsertion : public HloPassInterface { // (copies were inserted). StatusOr Run(HloModule* module) override; - protected: - // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making - // duplicate copies. - StatusOr FindOrInsertCopy(HloInstruction* hlo); - - // A map containing all copies inserted during the copy insertion pass. The - // key is the copied instruction and the value is the copy. 
- tensorflow::gtl::FlatMap inserted_copies_; + // The CPU and GPU backends need additional copies added due to deficiencies in + // buffer assignment. Specifically, copies are needed for constants live-out + // of computations, and for values which are live-in and live-out of the same + // computation. These copies are needed because buffer-assignment uses a + // computation-scoped analysis (TuplePointsToAnalysis) and has limited + // visibility across computation boundaries. This method adds these necessary + // copies. Returns whether the module was modified. + // + // TODO(b/62548313): Remove this when buffer assignment is module-scoped. + static StatusOr AddCopiesForBufferAssignment(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index a2eacc5c7dae2424e01fdd49d82546b5488d4312..8388574716ad1b78eb8868a8cd732005050b3310 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -17,18 +17,19 @@ limitations under the License. 
#include +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace op = xla::testing::opcode_matchers; @@ -37,35 +38,53 @@ namespace { using ::testing::UnorderedElementsAre; +int64 CountCopies(const HloComputation& computation) { + int64 count = 0; + for (const auto& instruction : computation.instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + count++; + } + } + return count; +} + +int64 CountCopies(const HloModule& module) { + int64 count = 0; + for (const auto& computation : module.computations()) { + count += CountCopies(*computation); + } + return count; +} + +int64 CountControlEdges(const HloComputation& computation) { + int64 count = 0; + for (const auto& instruction : computation.instructions()) { + count += instruction->control_successors().size(); + } + return count; +} + +int64 CountControlEdges(const HloModule& module) { + int64 count = 0; + for (const auto& computation : module.computations()) { + count += CountControlEdges(*computation); + } + return count; +} + class CopyInsertionTest : public HloTestBase { protected: void InsertCopies(HloModule* module) { CopyInsertion copy_insertion; - EXPECT_IS_OK(copy_insertion.Run(module).status()); - - // Verify the points to set of the root of the computation after copy 
- // insertion contains no constants or parameters, and is distinct and - // non-ambiguous. - auto points_to_analysis = - TuplePointsToAnalysis::Run(module).ConsumeValueOrDie(); - const auto& points_to = points_to_analysis->GetPointsToSet( - module->entry_computation()->root_instruction()); - EXPECT_TRUE(points_to.IsDistinct()); - EXPECT_TRUE(!points_to.IsAmbiguous()); - - auto maybe_live_out_buffers = - points_to_analysis - ->GetPointsToSet(module->entry_computation()->root_instruction()) - .CreateFlattenedSet(); - - for (const LogicalBuffer* buffer : maybe_live_out_buffers) { - EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant); - EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter); - } + ASSERT_IS_OK(copy_insertion.Run(module).status()); } + + const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); }; TEST_F(CopyInsertionTest, SingleParameter) { + // Computation is a single parameter passed into a tuple. The parameter should + // be copied before entering the tuple. auto builder = HloComputation::Builder(TestName()); HloInstruction* x = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -77,14 +96,15 @@ TEST_F(CopyInsertionTest, SingleParameter) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(old_root->operand(0)))); + op::Tuple(op::Copy(x))); } TEST_F(CopyInsertionTest, SingleConstant) { + // Computation is a single constant passed into a tuple. The constant should + // be copied before entering the tuple. 
auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); @@ -96,11 +116,42 @@ TEST_F(CopyInsertionTest, SingleConstant) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(old_root->operand(0)))); + op::Tuple(op::Copy(constant))); +} + +TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { + // Verify that any kCopy instructions which exist in the graph before + // copy-insertion remain in the graph after copy-insertion. + auto module = CreateNewModule(); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + constant->shape(), HloOpcode::kAdd, copy_1, copy_2)); + HloInstruction* add_copy = builder.AddInstruction( + HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add)); + + module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(CountCopies(*module), 3); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 3); + + EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())))); } TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { @@ -127,12 +178,12 @@ TEST_F(CopyInsertionTest,
MultipleConstantsAndParameters) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 2); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(old_root->operand(0)), - op::Copy(old_root->operand(1)), old_root->operand(2))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(constant2), op::Copy(x), op::Add(constant1, y))); } TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { @@ -165,6 +216,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 2); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Tuple(op::Copy(op::GetTupleElement(old_root)), @@ -187,6 +239,7 @@ TEST_F(CopyInsertionTest, BitcastParameter) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Copy(old_root)); @@ -208,6 +261,7 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Copy(old_root)); @@ -227,11 +281,11 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(old_root->operand(0)))); + op::Tuple(op::Copy(bitcast))); } TEST_F(CopyInsertionTest, NestedTupleParameter) { @@ -257,6 
+311,8 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 3); + HloInstruction* new_root = module->entry_computation()->root_instruction(); EXPECT_NE(old_root, new_root); @@ -283,7 +339,7 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { ShapeUtil::MakeShape(F32, {42})}), "param0")); - // The return value of the computation is the zero-th elemnt of the nested + // The return value of the computation is the zero-th element of the nested // tuple. This element is itself a tuple. auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); @@ -293,12 +349,13 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { EXPECT_EQ(gte, module->entry_computation()->root_instruction()); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 2); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(op::GetTupleElement(old_root)), - op::Copy(op::GetTupleElement(old_root)))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(op::GetTupleElement(op::GetTupleElement(param))), + op::Copy(op::GetTupleElement(op::GetTupleElement(param))))); } TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { @@ -331,6 +388,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Copy(old_root)); @@ -346,12 +404,10 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // The parameter 'nested' specifies the loop state shape from which to // read the induction variable. 
std::unique_ptr BuildConditionComputation( - bool nested = false) { + const Shape& loop_state_shape) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(10))); - const Shape& loop_state_shape = - nested ? nested_loop_state_shape_ : loop_state_shape_; auto loop_state = builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); auto induction_variable = @@ -582,7 +638,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto loop_state_init = builder.AddInstruction( HloInstruction::CreateTuple({induction_var_init, inner_init})); auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( - loop_state_shape_, condition, body, loop_state_init)); + loop_state_init->shape(), condition, body, loop_state_init)); module_->AddEntryComputation(builder.Build()); return while_hlo; } @@ -658,11 +714,28 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // Take a reference to 'data_init' to make it interfere with while result. - builder.AddInstruction(HloInstruction::CreateBinary( + auto add = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data_init, one_vec)); - return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, - &builder); + auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_, + data_init, &builder); + + // Add an additional binary operation operating on the while and the + // interfering add so that neither operation is dead. 
+ auto gte = xla_while->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1)); + auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary( + data_shape_, HloOpcode::kSubtract, add, gte)); + auto gte0 = xla_while->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0)); + auto tuple = xla_while->parent()->AddInstruction( + HloInstruction::CreateTuple({gte0, sub})); + + xla_while->parent()->set_root_instruction(tuple); + + return xla_while; } HloInstruction* BuildWhileInstructionWithCustomInit( @@ -672,8 +745,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest { ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); auto induction_var_init = builder->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0))); - auto condition = - module_->AddEmbeddedComputation(BuildConditionComputation(nested)); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape)); auto body = module_->AddEmbeddedComputation( BuildIndependentBodyComputation(nested)); auto loop_state_init = builder->AddInstruction( @@ -706,23 +779,21 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // CopyInsertion pass should not generate any copies. 
// TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { - auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body = module_->AddEmbeddedComputation(BuildIndependentBodyComputation()); auto while_hlo = BuildWhileInstruction(condition, body); - const HloInstruction* old_init = while_hlo->operand(0); - HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); - HloInstruction* new_root = body->root_instruction(); - const HloInstruction* new_init = while_hlo->operand(0); - // No copies should be inserted so root should not be updated. - EXPECT_EQ(old_root, new_root); + // Body should have no copies as the adds can be done inplace. + EXPECT_EQ(CountCopies(*body), 0); + EXPECT_EQ(CountControlEdges(*module_), 0); - // Both init indices need copies. - EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + // Both init indices need copies as they are constants. 
+ EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); } // Tests while body computation with dependent tuple elements: @@ -737,20 +808,33 @@ TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { // Tuple(Copy(out0), out1) // TEST_F(WhileCopyInsertionTest, DependentTupleElements) { - auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation()); auto while_hlo = BuildWhileInstruction(condition, body); - const HloInstruction* old_init = while_hlo->operand(0); - HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); - HloInstruction* new_root = body->root_instruction(); - const HloInstruction* new_init = while_hlo->operand(0); - EXPECT_THAT(new_root, - op::Tuple(op::Copy(old_root->operand(0)), old_root->operand(1))); - EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + EXPECT_EQ(CountCopies(*body), 1); + EXPECT_EQ(CountControlEdges(*body), 0); + + EXPECT_THAT( + body->root_instruction(), + op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast()))); + + auto add = body->root_instruction()->operand(0); + auto bcast = body->root_instruction()->operand(1)->operand(1); + ASSERT_EQ(add->opcode(), HloOpcode::kAdd); + ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); + + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(op::Copy(), op::Constant()), + op::Add(op::GetTupleElement(), op::Broadcast(op::Copy())))); + + // Both init indices need copies as they are constants. 
+ EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); } // Tests while body computation with read-only tuple element 0: @@ -768,33 +852,26 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements) { // // CopyInsertion pass should not generate any copies for the while body. TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) { - auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body = module_->AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); - auto while_hlo = BuildWhileInstruction(condition, body); + BuildWhileInstruction(condition, body); - const HloInstruction* old_init = while_hlo->operand(0); - HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); - HloInstruction* new_root = body->root_instruction(); - const HloInstruction* new_init = while_hlo->operand(0); - - // No copies should be inserted in the body, so root should not be updated. - EXPECT_EQ(old_root, new_root); - // Both indices need copies, even though Index 0 is read-only, since both are - // constants, which must be copied. - EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + // No copies or control edges should be inserted. The body is legal as is. + EXPECT_EQ(CountCopies(*body), 0); + EXPECT_EQ(CountControlEdges(*body), 0); } // Same as above, but with two while loops, sharing entry parameters. 
TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) { - auto condition1 = - module_->AddEmbeddedComputation(BuildConditionComputation()); - auto condition2 = - module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition1 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); + auto condition2 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body1 = module_->AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); auto body2 = module_->AddEmbeddedComputation( @@ -812,30 +889,46 @@ TEST_F(WhileCopyInsertionTest, loop_state_shape_, condition1, body1, loop_init)); auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition2, body2, loop_init)); - module_->AddEntryComputation(builder.Build()); + + // Add a couple elements from each of the while so both whiles are live. + auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); + auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); + builder.AddInstruction( + HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); + + auto entry = module_->AddEntryComputation(builder.Build()); InsertCopies(module_.get()); - // Both while loops alias iter_param, since index 0 is read-only in the body. - EXPECT_EQ(while_hlo1->operand(0)->operand(0), - while_hlo2->operand(0)->operand(0)); - EXPECT_EQ(while_hlo1->operand(0)->operand(0), iter_param); + // Neither body should have any copies or control edges in them. + EXPECT_EQ(CountCopies(*body1), 0); + EXPECT_EQ(CountCopies(*body2), 0); + EXPECT_EQ(CountControlEdges(*body1), 0); + EXPECT_EQ(CountControlEdges(*body2), 0); - // Each while loop gets its own copy of data_param, since index 1 is not - // read-only in the body. 
+ // Only two copies should be necessary. Each of the whiles should have + // a copy of tuple element 1 (init value is a parameter, and the element is + // not read-only) so each of the while bodies gets its own buffer to write + // element 1 into. + EXPECT_EQ(CountCopies(*entry), 2); + + EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); + EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); + + // The two copies of element 1 should be different. EXPECT_NE(while_hlo1->operand(0)->operand(1), while_hlo2->operand(0)->operand(1)); - EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_param)); - EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_param)); } // Same as above, but with two while loops, sharing non-parameters. TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly_TwoLoops_NonParams) { - auto condition1 = - module_->AddEmbeddedComputation(BuildConditionComputation()); - auto condition2 = - module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition1 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); + auto condition2 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body1 = module_->AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); auto body2 = module_->AddEmbeddedComputation( @@ -858,21 +951,28 @@ TEST_F(WhileCopyInsertionTest, loop_state_shape_, condition1, body1, loop_init)); auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition2, body2, loop_init)); - module_->AddEntryComputation(builder.Build()); + + // Add a couple elements from each of the while so both whiles are not dead.
+ auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); + auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); + builder.AddInstruction( + HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); + auto entry = module_->AddEntryComputation(builder.Build()); InsertCopies(module_.get()); - // No copies of iter_value are necessary, since index 0 is read-only in both - // while bodies. - EXPECT_EQ(while_hlo1->operand(0)->operand(0), iter_value); - EXPECT_EQ(while_hlo2->operand(0)->operand(0), iter_value); + // Ideally only one copy should be necessary. One of the whiles should + // have a copy of tuple element 1 (the non-read-only element) so each of the + // while bodies gets its own buffer to write element 1 into. However, the + // analysis isn't perfect and adds an additional copy of element 0. + EXPECT_EQ(CountCopies(*entry), 2); - // Each while loop gets its own copy of data_value, since index 1 is not - // read-only in the body. 
- EXPECT_NE(while_hlo1->operand(0)->operand(1), - while_hlo2->operand(0)->operand(1)); - EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_value)); - EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_value)); + EXPECT_THAT(while_hlo1->operand(0), + op::Tuple(op::Exp(), op::Copy(op::Exp()))); + EXPECT_THAT(while_hlo2->operand(0), + op::Tuple(op::Exp(), op::Copy(op::Exp()))); } // Tests while body computation with nested tuple elements: @@ -905,18 +1005,34 @@ TEST_F(WhileCopyInsertionTest, // Tuple // new root // TEST_F(WhileCopyInsertionTest, NestedTupleElements) { - auto condition = - module_->AddEmbeddedComputation(BuildConditionComputation(true)); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(nested_loop_state_shape_)); auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation()); BuildWhileInstruction(condition, body, true); - HloInstruction* old_root = body->root_instruction(); + // HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); - EXPECT_THAT(body->root_instruction(), - op::Tuple(old_root->operand(0), - op::Tuple(old_root->operand(1)->operand(0), - op::Copy(old_root->operand(1)->operand(1))))); + // The only copy necessary is for the kReverse as it cannot be done + // in-place (instruction can share buffer with operand). The other elements of + // the loop state are kAdd instructions which can be done in-place. + EXPECT_EQ(CountCopies(*body), 1); + + // Each element of the init needs a copy as all are constants. + EXPECT_EQ(CountCopies(*module_), 4); + + // Either the kReverse itself must be copied or the operand of the kReverse + // must be copied. 
+ if (body->root_instruction()->operand(1)->operand(1)->opcode() == + HloOpcode::kCopy) { + EXPECT_THAT( + body->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse())))); + } else { + EXPECT_THAT( + body->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy())))); + } } // Tests while init instruction which points-to a constant. @@ -927,11 +1043,13 @@ TEST_F(WhileCopyInsertionTest, NestedTupleElements) { // TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { auto while_hlo = BuildWhileInstruction_InitPointsToConstant(); - auto old_init = while_hlo->operand(0); + InsertCopies(module_.get()); + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); + EXPECT_EQ(CountCopies(*module_), 2); - EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); } // Tests while init instruction which points-to a parameter. @@ -942,11 +1060,13 @@ TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { // TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { auto while_hlo = BuildWhileInstruction_InitPointsToParameter(); - auto old_init = while_hlo->operand(0); + InsertCopies(module_.get()); + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); + EXPECT_EQ(CountCopies(*module_), 2); - EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter()))); } // Tests while init instruction which has an ambiguous points-to set. 
@@ -975,15 +1095,34 @@ TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { // TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous(); - auto old_init = while_hlo->operand(0); - InsertCopies(module_.get()); - EXPECT_THAT( - while_hlo->operand(0), - op::Tuple( - op::Copy(old_init->operand(0)), - op::Tuple(op::Copy(op::GetTupleElement(old_init->operand(1))), - op::Copy(op::GetTupleElement(old_init->operand(1)))))); + InsertCopies(module_.get()); + EXPECT_EQ(CountCopies(*module_), 4); + // The entry computation requires three copies to resolve the ambiguity of two + // init elements and the constant passed in as one of the init elements. + EXPECT_EQ(CountCopies(*module_->entry_computation()), 3); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), + op::Tuple(op::Copy(op::GetTupleElement()), + op::Copy(op::GetTupleElement())))); + + // The body requires one copy because the buffer set is not distinct: the + // result of one of the adds is written into two elements of the output of the + // loop body. Either element might be copied. + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); + if (while_hlo->while_body() + ->root_instruction() + ->operand(1) + ->operand(0) + ->opcode() == HloOpcode::kCopy) { + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); + } else { + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); + } } // Tests while init instruction which has a non-distinct points-to set. 
@@ -1011,13 +1150,43 @@ TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { // TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct(); - auto old_init = while_hlo->operand(0); + InsertCopies(module_.get()); - EXPECT_THAT(while_hlo->operand(0), - op::Tuple(op::Copy(old_init->operand(0)), - op::Tuple(op::Copy(old_init->operand(1)->operand(0)), - op::Copy(old_init->operand(1)->operand(0))))); + // The entry computation requires two copies to resolve the non-distinctness + // of two init elements and the constant passed in as one of the init + // elements. Either element can be copied for the distinctness issue. + EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); + if (while_hlo->operand(0)->operand(1)->operand(0)->opcode() == + HloOpcode::kCopy) { + EXPECT_THAT( + while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), + op::Tuple(op::Copy(op::Broadcast()), op::Broadcast()))); + } else { + EXPECT_THAT( + while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), + op::Tuple(op::Broadcast(), op::Copy(op::Broadcast())))); + } + + // The body requires one copy because the buffer set is not distinct: the + // result of one of the adds is written into two elements of the output of the + // loop body. Either element might be copied.
+ EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); + if (while_hlo->while_body() + ->root_instruction() + ->operand(1) + ->operand(0) + ->opcode() == HloOpcode::kCopy) { + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); + } else { + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); + } } // Tests while init instruction buffer which interferes with while result @@ -1031,11 +1200,13 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { // TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { auto while_hlo = BuildWhileInstruction_InitPointsToInterfering(); - auto old_init = while_hlo->operand(0); + InsertCopies(module_.get()); + EXPECT_EQ(CountCopies(*module_), 2); + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); - EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast()))); } // Tests while init instruction buffer which has a non-distinct points-to set: @@ -1044,18 +1215,21 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { // Parameter(F32, {8}))) // // where the second and third parameters are identical *and* the tuple shared -// by another while instruction.. +// by another while instruction. // // Verifies that the resulting point-to set is distinct in the resulting Tuple // (non-identical Copys). In other words, verifies that copy sharing does not // insert identical copies to the resulting tuple. TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { - auto condition1 = - module_->AddEmbeddedComputation(BuildConditionComputation()); - auto condition2 = - module_->AddEmbeddedComputation(BuildConditionComputation()); // Loop body that outputs tuple comprises two elements dependent on the init // tuple. 
+ const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_, data_shape_}); + + auto condition1 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape)); + auto condition2 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape)); auto body1 = module_->AddEmbeddedComputation(BuildDependentBodyComputation2()); auto body2 = @@ -1072,8 +1246,6 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { auto loop_init = builder.AddInstruction( HloInstruction::CreateTuple({iter_param, data_param, data_param})); - const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( - {induction_variable_shape_, data_shape_, data_shape_}); // Two while loops shares the same loop init tuple. auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( @@ -1081,43 +1253,478 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape, condition2, body2, loop_init)); - module_->AddEntryComputation(builder.Build()); + // Add add instruction so neither while is dead. + auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); + auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo2, 0)); + builder.AddInstruction( + HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); - auto points_to_analysis = - TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + module_->AddEntryComputation(builder.Build()); - // Asserts that the init tuples before copy insertion is non-distinct. 
- ASSERT_FALSE( - points_to_analysis->GetPointsToSet(while_hlo1->operand(0)).IsDistinct()); - ASSERT_FALSE( - points_to_analysis->GetPointsToSet(while_hlo2->operand(0)).IsDistinct()); + InsertCopies(module_.get()); - auto old_init1 = while_hlo1->operand(0); - auto old_init2 = while_hlo2->operand(0); + // None of the bodies should have copies or control flow edges. + EXPECT_EQ(CountCopies(*body1), 0); + EXPECT_EQ(CountCopies(*body2), 0); - InsertCopies(module_.get()); + // The loop bodies pass through elements 1 and 2 in the init tuple, so ideally + // these should not need to be copied before either while. However, copy + // insertion is not able to reason about the transparency of elements through + // while bodies in all circumstances so extra copies are added (b/xxx). + EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); EXPECT_THAT(while_hlo1->operand(0), - op::Tuple(op::Copy(old_init1->operand(0)), - op::Copy(old_init1->operand(1)), - op::Copy(old_init1->operand(2)))); - + op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); EXPECT_THAT(while_hlo2->operand(0), - op::Tuple(op::Copy(old_init2->operand(0)), - op::Copy(old_init2->operand(1)), - op::Copy(old_init2->operand(2)))); - - // Verifies the init tuples after copy insertion is distinct. - points_to_analysis = - TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); - const auto& points_to1 = - points_to_analysis->GetPointsToSet(while_hlo1->operand(0)); - EXPECT_TRUE(points_to1.IsDistinct()); - - const auto& points_to2 = - points_to_analysis->GetPointsToSet(while_hlo2->operand(0)); - EXPECT_TRUE(points_to2.IsDistinct()); + op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); } +TEST_F(CopyInsertionTest, SwizzlingWhile) { + // Test a while instruction with a body which permutes its tuple parameter + // elements. 
+ auto module = CreateNewModule(); + const Shape loop_state_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Body simply interchanges the two tuple elements in the loop state. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_1, body_element_0})); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + cond_builder.AddInstruction(HloInstruction::CreateUnary( + cond_constant->shape(), HloOpcode::kNot, cond_constant)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 6); + + // The loop state elements should be copied at the parameter and at the root + // with a control edge in between (see 
DeepCopyAndAddControlEdges). This is + // technically one more copy than is strictly necessary, but in order to have + // only three copies the copies of different loop state elements must be + // ordered with a control edge. + EXPECT_EQ(CountCopies(*body), 4); + EXPECT_EQ(CountControlEdges(*body), 2); + + EXPECT_THAT(body->root_instruction(), + op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy()))); + + EXPECT_EQ(CountCopies(*module->entry_computation()), 2); + EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); +} + +TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { + // Test a while instruction with a body which permutes its tuple parameter + // elements and applies one operation to one of the elements. The addition of + // the operation (instruction) on the element makes the live range of the + // respective input and output elements different than if the instruction were + // not there (as in the SwizzlingWhile test above). + auto module = CreateNewModule(); + const Shape loop_state_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Body interchanges the two tuple elements in the loop state and negates one + // of them. 
+ auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, body_element_1)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({negate, body_element_0})); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + cond_builder.AddInstruction(HloInstruction::CreateUnary( + cond_constant->shape(), HloOpcode::kNot, cond_constant)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 6); + + // The loop state elements should be copied at the parameter and at the root + // with a control edge in between (see DeepCopyAndAddControlEdges). 
+ EXPECT_EQ(CountCopies(*body), 4); + EXPECT_EQ(CountControlEdges(*body), 2); + + EXPECT_THAT( + body->root_instruction(), + op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy()))); + + EXPECT_EQ(CountCopies(*module->entry_computation()), 2); + EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); +} + +TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { + // Test a while instruction with a body which permutes it's tuple parameter + // elements similar to SwizzlinWhile above. However, in this test the input to + // the while body is a single constant (both loop state elements are the same + // constant). This means no copies are necessary because both loop state + // elements are the same so interchanging them is a no-op. + auto module = CreateNewModule(); + const Shape loop_state_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Body simply interchanges the two tuple elements in the loop state. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_1, body_element_0})); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + cond_builder.AddInstruction(HloInstruction::CreateUnary( + cond_constant->shape(), HloOpcode::kNot, cond_constant)); + HloComputation* condition = + 
module->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); + builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 2); + EXPECT_EQ(CountCopies(*body), 0); + + EXPECT_EQ(CountCopies(*module->entry_computation()), 2); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(), op::Copy())); +} + +TEST_F(CopyInsertionTest, SequentialWhiles) { + // Construct a computation with a series of sequential while instructions + // containing four loop state elements: + // + // element 0 is passed to each while directly from an entry parameter. + // + // element 1 is passed transparently in series through all the while bodies. + // + // element 2 is negated in each while body. (in-place possible) + // + // element 3 is reversed in each while body. (in-place not possible) + // + const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); + const Shape loop_state_shape = ShapeUtil::MakeTupleShape( + {element_shape, element_shape, element_shape, element_shape}); + + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param_0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, element_shape, "param_0")); + auto param_1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, element_shape, "param_1")); + auto param_2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, element_shape, "param_2")); + auto param_3 = builder.AddInstruction( + HloInstruction::CreateParameter(3, element_shape, "param_3")); + + // The number of sequential kWhile instructions. 
+ const int kNumWhiles = 3; + + HloInstruction* prev_element_1 = param_1; + HloInstruction* prev_element_2 = param_2; + HloInstruction* prev_element_3 = param_3; + + // Vector containing all of the while instructions. + std::vector whiles; + for (int i = 0; i < kNumWhiles; ++i) { + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, body_param, 1)); + auto body_element_2 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, body_param, 2)); + auto body_element_3 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, body_param, 3)); + auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + element_shape, HloOpcode::kNegate, body_element_2)); + auto reverse = body_builder.AddInstruction( + HloInstruction::CreateReverse(element_shape, body_element_3, {0})); + body_builder.AddInstruction(HloInstruction::CreateTuple( + {body_element_0, body_element_1, negate, reverse})); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + cond_builder.AddInstruction(HloInstruction::CreateUnary( + cond_constant->shape(), HloOpcode::kNot, cond_constant)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto while_init = builder.AddInstruction(HloInstruction::CreateTuple( + {param_0, prev_element_1, prev_element_2, 
prev_element_3})); + + auto xla_while = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape, condition, body, while_init)); + whiles.push_back(xla_while); + if (i != kNumWhiles - 1) { + prev_element_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, xla_while, 1)); + prev_element_2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, xla_while, 2)); + prev_element_3 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, xla_while, 3)); + } + } + + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + // Each while body has one copy. And each loop state element is copied once in + // the entry computation. + EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles); + + // Each while body should have exactly one copy for element three which is an + // op (kReverse) which cannot be done in place. + for (const HloInstruction* xla_while : whiles) { + EXPECT_EQ(CountCopies(*xla_while->while_body()), 1); + } + + EXPECT_THAT(whiles[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(), + op::Copy(), op::Copy())); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(), + op::GetTupleElement())); +} + +TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { + // Test a while body and condition which are each simply a constant (root of + // computation is a constant). The body constant should be copied. 
+ auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param_0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); + + auto body_builder = HloComputation::Builder("body"); + body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0)); + + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 2); + + EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter())); + EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant())); + EXPECT_THAT(condition->root_instruction(), op::Constant()); +} + +std::unique_ptr MakeTrivialCondition(const Shape& shape) { + auto builder = HloComputation::Builder("trivial_condition"); + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "loop_state")); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNot, constant)); + return builder.Build(); +} + +std::unique_ptr MakeBenchmarkWhileBody() { + auto builder = HloComputation::Builder("benchmark_loop_body"); + const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); + const Shape loop_state_shape = + ShapeUtil::MakeTupleShape({element_shape, 
element_shape, element_shape}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); + HloInstruction* element_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, param, 0)); + HloInstruction* element_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, param, 1)); + HloInstruction* element_2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, param, 2)); + + HloInstruction* rev_1 = builder.AddInstruction( + HloInstruction::CreateReverse(element_shape, element_1, {0})); + HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary( + element_shape, HloOpcode::kAdd, element_1, element_2)); + + builder.AddInstruction( + HloInstruction::CreateTuple({element_0, rev_1, add_1_2})); + return builder.Build(); +} + +void BM_SequentialWhiles(int num_iters, int num_whiles) { + // This benchmark constructs a chain of sequential while instructions. 
+ tensorflow::testing::StopTiming(); + for (int i = 0; i < num_iters; ++i) { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), + config); + + auto builder = HloComputation::Builder("BM_SequentialWhiles"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {42}), "x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {42}), "y")); + HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(F32, {42}), "z")); + HloInstruction* init = + builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); + + HloInstruction* prev_loop_state = init; + for (int w = 0; w < num_whiles; ++w) { + HloComputation* condition = + module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); + HloComputation* body = + module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); + prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile( + init->shape(), condition, body, prev_loop_state)); + } + module.AddEntryComputation(builder.Build()); + + CopyInsertion copy_insertion; + + tensorflow::testing::StartTiming(); + ASSERT_IS_OK(copy_insertion.Run(&module).status()); + tensorflow::testing::StopTiming(); + + // The entry computation should have three copies, and each body has one. + ASSERT_EQ(CountCopies(module), 3 + num_whiles); + } +} + +void BM_ParallelWhiles(int num_iters, int num_whiles) { + // This benchmark constructs a fan-out of parallel while instructions. 
+ tensorflow::testing::StopTiming(); + for (int i = 0; i < num_iters; ++i) { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), + config); + + auto builder = HloComputation::Builder("BM_ParallelWhiles"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {42}), "x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {42}), "y")); + HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(F32, {42}), "z")); + HloInstruction* init = + builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); + + HloInstruction* sum = nullptr; + for (int w = 0; w < num_whiles; ++w) { + HloComputation* condition = + module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); + HloComputation* body = + module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); + + HloInstruction* xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(init->shape(), condition, body, init)); + + if (sum == nullptr) { + sum = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); + } else { + HloInstruction* element_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); + sum = builder.AddInstruction(HloInstruction::CreateBinary( + x->shape(), HloOpcode::kAdd, sum, element_0)); + } + } + module.AddEntryComputation(builder.Build()); + + CopyInsertion copy_insertion; + + tensorflow::testing::StartTiming(); + ASSERT_IS_OK(copy_insertion.Run(&module).status()); + tensorflow::testing::StopTiming(); + + // Each body receives of copy of two of the parameters (the corresponding + // elements in the body are modifed), and there is one copy in each body. 
+ ASSERT_EQ(CountCopies(module), 3 * num_whiles); + } +} + +BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); +BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 6213baee2fa5c4af7c650d0be4af619deba2709a..b43597dca983151d59ec7aaba9887313191fc9bd 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -17,6 +17,7 @@ package_group( load(":build_defs.bzl", "runtime_copts") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") # Filegroup used to collect source files for dependency checking. filegroup( @@ -78,14 +79,16 @@ cc_library( deps = [ ":compiler_functor", ":conv_canonicalization", + ":cpu_copy_insertion", ":cpu_executable", ":cpu_instruction_fusion", + ":cpu_layout_assignment", ":cpu_options", ":cpu_parallelization_preparation", ":disassembler", + ":dot_op_emitter", ":ir_emission_utils", ":ir_emitter", - ":layout_assignment", ":parallel_cpu_executable", ":parallel_task_assignment", ":simple_orc_jit", @@ -101,13 +104,14 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", - "//tensorflow/compiler/xla/service:copy_insertion", + "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/service:hlo_ordering", 
"//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", @@ -156,21 +160,23 @@ cc_library( ":custom_call_target_registry", ":disassembler", ":external_constant_pool", + ":orc_jit_memory_mapper", ":runtime_conv2d", ":runtime_fork_join", ":runtime_matmul", ":runtime_single_threaded_conv2d", ":runtime_single_threaded_matmul", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - "@llvm//:core", "@llvm//:execution_engine", + "@llvm//:core", "@llvm//:mc", # fixdeps: keep "@llvm//:orc_jit", "@llvm//:support", "@llvm//:target", # fixdeps: keep - ], + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ] + ORC_JIT_MEMORY_MAPPER_TARGETS, ) cc_library( @@ -246,6 +252,8 @@ cc_library( ":dot_op_emitter", ":external_constant_pool", ":ir_emission_utils", + ":ir_function", + ":parallel_loop_emitter", ":shape_partition", ":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", @@ -269,19 +277,54 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:ops", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@llvm//:code_gen", "@llvm//:core", "@llvm//:support", "@llvm//:target", ], ) +cc_library( + name = "ir_function", + srcs = ["ir_function.cc"], + hdrs = ["ir_function.h"], + deps = [ + ":ir_emission_utils", + ":shape_partition", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service/cpu:cpu_runtime", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "parallel_loop_emitter", + srcs = ["parallel_loop_emitter.cc"], + hdrs = ["parallel_loop_emitter.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:xla_data_proto", + 
"//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + cc_library( name = "dot_op_emitter", srcs = ["dot_op_emitter.cc"], hdrs = ["dot_op_emitter.h"], deps = [ + ":cpu_options", ":cpu_runtime", - ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", @@ -290,8 +333,10 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:vector_support_library", "//tensorflow/core:lib", "@llvm//:core", ], @@ -608,14 +653,16 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/service:hlo", + "@llvm//:core", ], ) cc_library( - name = "layout_assignment", - srcs = ["layout_assignment.cc"], - hdrs = ["layout_assignment.h"], + name = "cpu_layout_assignment", + srcs = ["cpu_layout_assignment.cc"], + hdrs = ["cpu_layout_assignment.h"], deps = [ + ":dot_op_emitter", ":ir_emission_utils", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:computation_layout", @@ -625,11 +672,11 @@ cc_library( ) tf_cc_test( - name = "layout_assignment_test", + name = "cpu_layout_assignment_test", size = "small", - srcs = ["layout_assignment_test.cc"], + srcs = ["cpu_layout_assignment_test.cc"], deps = [ - ":layout_assignment", + ":cpu_layout_assignment", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ 
-703,6 +750,7 @@ cc_library( srcs = ["parallel_task_assignment.cc"], hdrs = ["parallel_task_assignment.h"], deps = [ + ":dot_op_emitter", ":ir_emission_utils", ":shape_partition", "//tensorflow/compiler/xla/service:hlo", @@ -717,6 +765,7 @@ cc_library( hdrs = ["cpu_options.h"], deps = [ "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core:lib", ], ) @@ -731,6 +780,48 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "orc_jit_memory_mapper", + srcs = ["orc_jit_memory_mapper.cc"], + hdrs = ["orc_jit_memory_mapper.h"], + deps = [ + "//tensorflow/core:lib", + "@llvm//:execution_engine", + ], +) + +cc_library( + name = "cpu_copy_insertion", + srcs = ["cpu_copy_insertion.cc"], + hdrs = ["cpu_copy_insertion.h"], + deps = [ + "//tensorflow/compiler/xla/service:copy_insertion", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "cpu_copy_insertion_test", + srcs = ["cpu_copy_insertion_test.cc"], + deps = [ + ":cpu_copy_insertion", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 44cd2171afdc6eecc22f3f920276a4d95f930573..2136aeb3877685373efaf5bf702a42b39a63f082 100644 --- 
a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -41,19 +41,17 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { auto kernel_input_feature_dim = dnums.kernel_input_feature_dimension(); auto kernel_output_feature_dim = dnums.kernel_output_feature_dimension(); - int num_spatial_dims = dnums.spatial_dimensions_size(); - int num_dims = num_spatial_dims + 2; + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + const int64 num_dims = num_spatial_dims + 2; // A canonical convolution's dimension numbers need to satisfy the // following conditions (see cs/PotentiallyImplementedAsEigenConvolution). // - // - the input is in NHWC or NWHC order. - // - the kernel is in HWIO or WHIO order. - // - the spatial dimensions are in the same relative order in the input, - // kernel and output. + // - the input is in NHWC order. + // - the kernel is in HWIO order. // // For simplicity, as a first step, we reshape the input and filter to - // NHWC and HWIO order, respectively. This may lose precision but not + // NHWC and HWIO order, respectively. This may lose precision but won't // break the soundness. 
HloInstruction* input = hlo->mutable_operand(0); @@ -61,10 +59,10 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { std::vector new_input_dims(num_dims); new_input_dim_order[0] = input_batch_dim; new_input_dims[0] = input->shape().dimensions(input_batch_dim); - for (int i = 0; i < num_spatial_dims; ++i) { - new_input_dim_order[i + 1] = dnums.spatial_dimensions(i); + for (int64 i = 0; i < num_spatial_dims; ++i) { + new_input_dim_order[i + 1] = dnums.input_spatial_dimensions(i); new_input_dims[i + 1] = - input->shape().dimensions(dnums.spatial_dimensions(i)); + input->shape().dimensions(dnums.input_spatial_dimensions(i)); } new_input_dim_order[num_dims - 1] = input_feature_dim; new_input_dims[num_dims - 1] = @@ -80,7 +78,7 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { std::vector new_kernel_dim_order(num_dims); std::vector new_kernel_dims(num_dims); - for (int i = 0; i < num_spatial_dims; ++i) { + for (int64 i = 0; i < num_spatial_dims; ++i) { new_kernel_dim_order[i] = dnums.kernel_spatial_dimensions(i); new_kernel_dims[i] = kernel->shape().dimensions(dnums.kernel_spatial_dimensions(i)); @@ -98,14 +96,18 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { HloInstruction::CreateTranspose(new_kernel_shape, kernel, new_kernel_dim_order)); + std::vector new_output_dim_order(num_dims); std::vector new_conv_dims(num_dims); auto output_batch_dim = dnums.output_batch_dimension(); auto output_feature_dim = dnums.output_feature_dimension(); + new_output_dim_order[0] = output_batch_dim; new_conv_dims[0] = hlo->shape().dimensions(output_batch_dim); - for (int i = 0; i < num_spatial_dims; ++i) { + for (int64 i = 0; i < num_spatial_dims; ++i) { + new_output_dim_order[i + 1] = dnums.output_spatial_dimensions(i); new_conv_dims[i + 1] = - hlo->shape().dimensions(dnums.spatial_dimensions(i)); + hlo->shape().dimensions(dnums.output_spatial_dimensions(i)); } + new_output_dim_order[num_dims - 1] = output_feature_dim; new_conv_dims[num_dims - 1] = 
hlo->shape().dimensions(output_feature_dim); Shape new_conv_shape = ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims); @@ -113,9 +115,10 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { ConvolutionDimensionNumbers new_dnums; new_dnums.set_input_batch_dimension(0); new_dnums.set_output_batch_dimension(0); - for (int i = 0; i < num_spatial_dims; ++i) { - new_dnums.add_spatial_dimensions(i + 1); + for (int64 i = 0; i < num_spatial_dims; ++i) { + new_dnums.add_input_spatial_dimensions(i + 1); new_dnums.add_kernel_spatial_dimensions(i); + new_dnums.add_output_spatial_dimensions(i + 1); } new_dnums.set_input_feature_dimension(num_dims - 1); new_dnums.set_output_feature_dimension(num_dims - 1); @@ -129,14 +132,11 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel, hlo->window(), new_dnums)); - // kConvolution inherits the dimension mapping of its input, so we need to - // reshape the output back to the shape of the original convolution. This - // is done by apply the inverse permutation of the collapsing order of the - // input reshape. + // Reshape the output back to the shape of the original convolution. 
TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( hlo, HloInstruction::CreateTranspose( hlo->shape(), new_conv, - InversePermutation(new_input_dim_order)))); + InversePermutation(new_output_dim_order)))); changed = true; } } diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index d593ba26b655d00a0f0f0b9a94c9e62fa1835080..968f53d5c706651d2a470a853e0e9b601c0ed2df 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -69,8 +69,10 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(1); dnums.set_output_batch_dimension(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); dnums.set_input_feature_dimension(0); dnums.set_output_feature_dimension(0); dnums.add_kernel_spatial_dimensions(2); @@ -125,8 +127,10 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index f46764cba0ad6ef174a89951c251613c69b4b083..55e7c7bc2ca05991ac6dd53bf48bc9fd30f52601 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ 
b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -46,27 +46,30 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -196,28 +199,35 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { class CollectProfileCandidates : public DfsHloVisitorWithDefault { public: static StatusOr> - GetCandidatesForComputation(HloComputation* computation) { + GetCandidatesForComputation( + HloComputation* computation, + const std::unordered_map& + assigned_indices) { std::unordered_map hlo_to_profile_idx; CollectProfileCandidates profile_candidates_for_computation( - &hlo_to_profile_idx); + &hlo_to_profile_idx, assigned_indices); TF_RETURN_IF_ERROR( computation->Accept(&profile_candidates_for_computation)); return hlo_to_profile_idx; } private: - explicit CollectProfileCandidates( - std::unordered_map* hlo_to_profile_idx) - : hlo_to_profile_idx_(hlo_to_profile_idx) {} + CollectProfileCandidates( + std::unordered_map* hlo_to_profile_idx, + const std::unordered_map& assigned_indices) + : hlo_to_profile_idx_(hlo_to_profile_idx), + assigned_indices_(assigned_indices) {} Status DefaultAction(HloInstruction* hlo_instruction) override { - hlo_to_profile_idx_->insert({hlo_instruction, hlo_to_profile_idx_->size()}); + hlo_to_profile_idx_->insert( + {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)}); return Status::OK(); } Status HandleCall(HloInstruction* call) override { TF_RETURN_IF_ERROR(DefaultAction(call)); - CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_); + CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_, + assigned_indices_); TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call)); return Status::OK(); } @@ -231,17 +241,20 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { Status HandleWhile(HloInstruction* xla_while) override { TF_RETURN_IF_ERROR(DefaultAction(xla_while)); - CollectProfileCandidates 
candidates_for_condition(hlo_to_profile_idx_); + CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_, + assigned_indices_); TF_RETURN_IF_ERROR( xla_while->while_condition()->Accept(&candidates_for_condition)); - CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_); + CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_, + assigned_indices_); TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body)); return Status::OK(); } std::unordered_map* hlo_to_profile_idx_; + const std::unordered_map& assigned_indices_; }; } // namespace @@ -261,7 +274,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. pipeline.AddPass(); - + pipeline.AddPass(); pipeline.AddPass(); { auto& pass = @@ -276,7 +289,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, - /*enable_dot_simplification=*/false); + /*enable_dot_strength_reduction=*/false); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -305,8 +318,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pipeline.AddPass>( /*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return true; }, - /*enable_dot_simplification=*/false); + /*enable_dot_strength_reduction=*/false); pipeline.AddPass(/*is_layout_sensitive=*/true); + pipeline.AddPass(BF16, F32); // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = module->config().intra_op_parallelism_threads() > 0 @@ -331,15 +345,16 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // (and sometime after) copy insertion, to avoid dead code from interfering // with the rewrites. 
pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); if (options::CpuParallelBackendRequested(module->config())) { // Re-run the outlining, in case any copies were inserted into the entry // computation. pipeline.AddPass(max_parallelism, ShapeSizeBytesFunction()); + pipeline.AddPass(); } pipeline.AddPass(); - pipeline.AddPass(); return pipeline.Run(module).status(); } @@ -425,11 +440,25 @@ Status InitializeModuleHooks( } // namespace -StatusOr> CpuCompiler::Compile( - std::unique_ptr module, se::StreamExecutor* stream_exec) { +StatusOr> CpuCompiler::RunHloPasses( + std::unique_ptr module, + perftools::gputools::StreamExecutor* /*stream_exec*/) { + VLOG(2) << "Before optimization:"; + XLA_VLOG_LINES(2, module->ToString()); + + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); + + VLOG(2) << "After optimization:"; + XLA_VLOG_LINES(2, module->ToString()); + return std::move(module); +} + +StatusOr> CpuCompiler::RunBackend( + std::unique_ptr module, + perftools::gputools::StreamExecutor* stream_exec) { const string timer_message = "Compiling [" + module->name() + "] for CPU using JIT"; - ScopedLoggingTimer compiling_timer(timer_message, 1); + XLA_SCOPED_LOGGING_TIMER(timer_message); VLOG(1) << "Compiling: " << module->name(); TF_RET_CHECK(stream_exec != nullptr); @@ -443,11 +472,11 @@ StatusOr> CpuCompiler::Compile( &pre_optimization_ir_hook, &post_optimization_ir_hook)); // Compile must be thread-safe so create a new LLVM context for the module. 
- auto llvm_context = MakeUnique(); + auto llvm_context = xla::MakeUnique(); auto llvm_module = - MakeUnique("__compute_module", *llvm_context); + xla::MakeUnique("__compute_module", *llvm_context); - auto jit = MakeUnique( + auto jit = xla::MakeUnique( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), @@ -457,20 +486,29 @@ StatusOr> CpuCompiler::Compile( llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); - VLOG(2) << "Before optimization:"; - XLA_VLOG_LINES(2, module->ToString()); - - TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); - - VLOG(2) << "After optimization:"; - XLA_VLOG_LINES(2, module->ToString()); - HloComputation* computation = module->entry_computation(); std::unordered_map hlo_to_profile_idx; + std::unique_ptr hlo_profile_index_map; + std::unique_ptr hlo_profile_printer; if (module->config().hlo_profiling_enabled()) { + hlo_profile_index_map = MakeUnique(*module); + TF_ASSIGN_OR_RETURN( hlo_to_profile_idx, - CollectProfileCandidates::GetCandidatesForComputation(computation)); + CollectProfileCandidates::GetCandidatesForComputation( + computation, hlo_profile_index_map->instruction_to_profile_idx())); + + auto shape_size_bytes = [](const Shape& shape) { + // On the cpu, opaques are pointers. + if (ShapeUtil::IsOpaque(shape)) { + return static_cast(sizeof(void*)); + } + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + }; + + HloCostAnalysis cost_analysis(shape_size_bytes); + hlo_profile_printer = + CreateHloProfilePrinter(*hlo_profile_index_map, cost_analysis); } std::unique_ptr cpu_executable; @@ -493,9 +531,9 @@ StatusOr> CpuCompiler::Compile( // uses data dependencies for determining order. 
TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run(module.get(), - MakeUnique(module.get()), - BufferSizeBytesFunction(), memory_alignment)); + BufferAssigner::Run( + module.get(), xla::MakeUnique(module.get()), + BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -522,7 +560,7 @@ StatusOr> CpuCompiler::Compile( const void* data = instruction->literal().InternalData(); int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); auto iter = aligned_constants.emplace( - instruction, MakeUnique(size)); + instruction, xla::MakeUnique(size)); CHECK_EQ(iter.second, true); unsigned char* aligned_data = iter.first->second.get(); memcpy(aligned_data, data, size); @@ -536,9 +574,17 @@ StatusOr> CpuCompiler::Compile( parallel_computations.emplace(to_apply, instruction); } + // We always profile the entire computation as a whole, even if hlo + // profiling is disabled. When hlo profiling is diabled, we pass in a + // profile counter array of just one element, which corresponds to the whole + // computation. + size_t entry_computation_profile_idx = + hlo_profile_index_map ? 
hlo_profile_index_map->GetProfileIndexFor( + *module->entry_computation()) + : 0; IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx, jit->target_machine(), - jit->external_constant_pool()); + hlo_to_profile_idx, entry_computation_profile_idx, + jit->target_machine(), jit->external_constant_pool()); std::unique_ptr> function_names( new HloInstructionMap()); @@ -557,7 +603,7 @@ StatusOr> CpuCompiler::Compile( llvm::Function * ir_function, ir_emitter.EmitComputation( embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/computation_is_parallel, + /*is_top_level_computation=*/computation_is_parallel, /*instruction_order=*/nullptr)); // If this computation is parallel, remember it in the function name map. // This way we know what function to execute when we try to run code for @@ -578,8 +624,8 @@ StatusOr> CpuCompiler::Compile( jit->AddModule(std::move(llvm_module)); cpu_executable.reset(new ParallelCpuExecutable( std::move(jit), std::move(assignment), std::move(module), - std::move(function_names), std::move(hlo_to_profile_idx), - std::move(aligned_constants))); + std::move(function_names), std::move(aligned_constants), + std::move(hlo_profile_printer), std::move(hlo_profile_index_map))); if (embed_ir_in_executable) { static_cast(*cpu_executable) @@ -599,10 +645,10 @@ StatusOr> CpuCompiler::Compile( // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module.get(), - MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + BufferAssigner::Run(module.get(), + xla::MakeUnique( + module.get(), module_sequence), + BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. 
XLA_VLOG_LINES(2, assignment->ToString()); @@ -612,13 +658,23 @@ StatusOr> CpuCompiler::Compile( TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( proto, xla_dump_hlo_proto_to, module->name())); } + // We always profile the entire computation as a whole, even if hlo + // profiling is disabled. When hlo profiling is diabled, we pass in a + // profile counter array of just one element, which corresponds to the whole + // computation. + size_t entry_computation_profile_idx = + hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor( + *module->entry_computation()) + : 0; + // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. + IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx, jit->target_machine(), - jit->external_constant_pool()); + hlo_to_profile_idx, entry_computation_profile_idx, + jit->target_machine(), jit->external_constant_pool()); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -629,7 +685,7 @@ StatusOr> CpuCompiler::Compile( ir_emitter .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/false, + /*is_top_level_computation=*/false, &module_sequence.at(embedded_computation)) .status()); } @@ -638,7 +694,7 @@ StatusOr> CpuCompiler::Compile( TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, ir_emitter.EmitComputation(computation, function_name_prefix, - /*is_entry_computation=*/true, + /*is_top_level_computation=*/true, &module_sequence.at(computation))); string function_name = llvm_ir::AsString(entry_function->getName()); @@ -651,7 +707,7 @@ StatusOr> CpuCompiler::Compile( jit->AddModule(std::move(llvm_module)); cpu_executable.reset(new CpuExecutable( std::move(jit), std::move(assignment), std::move(module), function_name, - 
std::move(hlo_to_profile_idx))); + std::move(hlo_profile_printer), std::move(hlo_profile_index_map))); if (embed_ir_in_executable) { static_cast(*cpu_executable) @@ -663,13 +719,6 @@ StatusOr> CpuCompiler::Compile( return std::move(cpu_executable); } -StatusOr>> CpuCompiler::Compile( - std::vector> modules, - std::vector> stream_execs) { - return Unimplemented( - "Compilation of multiple HLO modules is not yet supported on CPU."); -} - StatusOr>> CpuCompiler::CompileAheadOfTime(std::vector> modules, const AotCompilationOptions& aot_options) { @@ -778,7 +827,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, BufferAssigner::Run( - module, MakeUnique(module, module_sequence), + module, + xla::MakeUnique(module, module_sequence), BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. @@ -792,9 +842,13 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, proto, xla_dump_hlo_proto_to, module->name())); } - IrEmitter ir_emitter(*module, *assignment, &llvm_module, - /*hlo_to_profile_idx=*/nullptr, target_machine.get(), - /*external_constant_pool=*/nullptr); + IrEmitter ir_emitter( + *module, *assignment, &llvm_module, + /*hlo_to_profile_idx=*/ + std::unordered_map{}, + /*entry_computation_profile_idx=*/tensorflow::gtl::nullopt, + target_machine.get(), + /*external_constant_pool=*/nullptr); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -805,7 +859,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, ir_emitter .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/false, + /*is_top_level_computation=*/false, &module_sequence.at(embedded_computation)) .status()); } @@ -813,7 +867,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( llvm::Function * 
entry_function, ir_emitter.EmitComputation(computation, entry_point_name, - /*is_entry_computation=*/true, + /*is_top_level_computation=*/true, &module_sequence.at(computation))); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index d09130247421b11d6d4879466f39b89167eb9564..ebed7058d8f7968c6e03ef90d0da6b2325037eb0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -109,14 +109,20 @@ class CpuCompiler : public LLVMCompiler { CpuCompiler(); ~CpuCompiler() override {} - StatusOr> Compile( + // Bring in + // StatusOr>> Compile( + // std::vector> modules, + // std::vector> + // stream_execs) + using LLVMCompiler::Compile; + + StatusOr> RunHloPasses( std::unique_ptr module, perftools::gputools::StreamExecutor* stream_exec) override; - StatusOr>> Compile( - std::vector> modules, - std::vector> - stream_execs) override; + StatusOr> RunBackend( + std::unique_ptr module, + perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> CompileAheadOfTime(std::vector> modules, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.cc new file mode 100644 index 0000000000000000000000000000000000000000..baaacd2ecc9611946678f71ac36ef787ecb57b4e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.cc @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StatusOr CpuCopyInsertion::Run(HloModule* module) { + CopyInsertion generic_copy_insertion; + + TF_ASSIGN_OR_RETURN(bool generic_changed, generic_copy_insertion.Run(module)); + + // The CPU backend needs additional copies added due to deficiencies in + // buffer assignment. + TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed, + CopyInsertion::AddCopiesForBufferAssignment(module)); + + return generic_changed || buffer_assignment_changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h new file mode 100644 index 0000000000000000000000000000000000000000..3313d1e6eb71bff39f509c3d24858568df786422 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Besides the modifications made by the generic xla::CopyInsertion, this +// CPU-specific copy insertion pass also adds copies to values live out of +// computations satisfying certain conditions (defined by constant or parameter, +// etc). This is necessary because of deficiencies of buffer +// assignment. Specifically, buffer assignment is computation-scoped and does +// not recognized aliasing between arguments and outputs of computations. +// +// TODO(b/62548313): Remove this when buffer assignment is smarter +// (module-scoped). +class CpuCopyInsertion : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { return "copy-insertion"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a05a26941786cbf404c4685abb098c9ac8caaa09 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -0,0 +1,139 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +int64 CountCopies(const HloComputation& computation) { + int64 count = 0; + for (const auto& instruction : computation.instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + count++; + } + } + return count; +} + +int64 CountCopies(const HloModule& module) { + int64 count = 0; + for (const auto& computation : module.computations()) { + count += CountCopies(*computation); + } + return count; +} + +class CpuCopyInsertionTest : public HloTestBase { + protected: + void 
InsertCopies(HloModule* module) { + CpuCopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module).status()); + } + + const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); +}; + +TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { + // Test a while body and condition which are each simply a constant (root of + // computation is a constant). Each constant should be copied. + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param_0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); + + auto body_builder = HloComputation::Builder("body"); + body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0)); + + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 3); + + EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter())); + EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant())); + EXPECT_THAT(condition->root_instruction(), op::Copy(op::Constant())); +} + +TEST_F(CpuCopyInsertionTest, TupleCall) { + // Test a kCall instruction which calls a computation which produces a three + // element tuple: one is a constant, one is a parameter, and one is produced + // in the computation. The constant and parameter should be copied. 
+ auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_, scalar_shape_}); + + auto sub_builder = HloComputation::Builder("subcomputation"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto constant = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + auto add = sub_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, sub_param, constant)); + sub_builder.AddInstruction( + HloInstruction::CreateTuple({sub_param, constant, add})); + HloComputation* subcomputation = + module->AddEmbeddedComputation(sub_builder.Build()); + + builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape, {param}, subcomputation)); + + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*subcomputation), 2); + EXPECT_THAT(subcomputation->root_instruction(), + op::Tuple(op::Copy(op::Parameter()), op::Copy(op::Constant()), + op::Add())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index f62353bee7b1058dc237169b70341c33ab19fc52..e956f478b86d9816615e2902f5bbeae6d6384162 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -43,6 +43,7 @@ limitations under the License. 
#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/host/host_stream.h" namespace se = ::perftools::gputools; @@ -54,11 +55,12 @@ CpuExecutable::CpuExecutable( std::unique_ptr assignment, std::unique_ptr hlo_module, const string& entry_function_name, - std::unordered_map hlo_to_profile_idx) - : Executable(std::move(hlo_module)), + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map) + : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + std::move(hlo_profile_index_map)), jit_(std::move(jit)), - assignment_(std::move(assignment)), - hlo_to_profile_idx_(std::move(hlo_to_profile_idx)) { + assignment_(std::move(assignment)) { // Resolve symbols in the constructor rather than at execution time to avoid // races because FindSymbol is not thread safe. llvm::JITSymbol sym = jit_->FindSymbol(entry_function_name); @@ -147,8 +149,9 @@ Status CpuExecutable::ExecuteComputeFunction( tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { std::vector argument_buffers; - for (int i = 0; i < arguments.size(); ++i) { - argument_buffers.push_back(arguments[i]->buffer(/*index=*/{})); + argument_buffers.reserve(arguments.size()); + for (const auto* argument : arguments) { + argument_buffers.push_back(argument->buffer(/*index=*/{})); } return ExecuteComputeFunction(run_options, argument_buffers, buffers, hlo_execution_profile); @@ -181,9 +184,16 @@ Status CpuExecutable::ExecuteComputeFunction( uint64 start_micros = tensorflow::Env::Default()->NowMicros(); // Allocate profiling counters for each hlo instruction that we would like to - // profile. Allocate an additional profile counter for the entire - // computation. - std::vector profile_counters(hlo_to_profile_idx_.size() + 1); + // profile. 
Even when not Hlo profiling, we allocate a counter for the entire + // computation, which we use to update ExecutionProfile below. + std::vector* profile_counters = nullptr; + std::vector profile_counter_for_entry_computation; + if (hlo_execution_profile) { + profile_counters = hlo_execution_profile->mutable_profile_counters(); + } else { + profile_counters = &profile_counter_for_entry_computation; + profile_counter_for_entry_computation.push_back(0); + } // Call the computation function following the calling convention. std::vector buffer_pointers; @@ -198,7 +208,7 @@ Status CpuExecutable::ExecuteComputeFunction( VLOG(3) << tensorflow::strings::Printf( " func(void* result, void* params[%zu], void* temps[%zu], " "uint64 profile_counters[%zu])", - args_array.size(), buffer_pointers.size(), profile_counters.size()); + args_array.size(), buffer_pointers.size(), profile_counters->size()); VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); @@ -210,11 +220,11 @@ Status CpuExecutable::ExecuteComputeFunction( " temps = [%s]", tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p", - profile_counters.data()); + profile_counters->data()); } compute_function_(result_buffer, run_options, args_array.data(), - buffer_pointers.data(), profile_counters.data()); + buffer_pointers.data(), profile_counters->data()); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -223,20 +233,46 @@ Status CpuExecutable::ExecuteComputeFunction( const double nanoseconds = (end_micros - start_micros) * 1000.0; execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); - // The last profile counter is used for the computation as a whole. 
- execution_profile_.set_compute_cycle_count(profile_counters.back()); + if (hlo_execution_profile) { + execution_profile_.set_compute_cycle_count( + hlo_execution_profile->total_cycles_executed( + *module().entry_computation())); + } else { + execution_profile_.set_compute_cycle_count(profile_counters->back()); + } } - if (hlo_execution_profile != nullptr) { - hlo_execution_profile->set_total_cycles_executed( - *module().entry_computation(), profile_counters.back()); + return Status::OK(); +} + +static void LogLiveAddresses( + const std::unordered_set& marked_addresses) { + VLOG(3) << "Live addresses in output marking found " + << marked_addresses.size() << " addresses:\n" + << tensorflow::str_util::Join( + marked_addresses, ", ", [](string* out, const void* address) { + tensorflow::strings::StrAppend( + out, tensorflow::strings::Printf("%p", address)); + }); +} - for (auto hlo_prof_idx : hlo_to_profile_idx_) { - const HloInstruction* hlo = hlo_prof_idx.first; - uint64 cycles_taken = profile_counters[hlo_prof_idx.second]; - hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken); +static Status DeallocateTempBuffers( + DeviceMemoryAllocator* allocator, se::Stream* stream, + tensorflow::gtl::ArraySlice buffers, + const std::unordered_set& marked_addresses) { + // Keep those marked live because they are referenced by the output of the + // computation and are needed by the service. They will be deallocated by the + // service. 
+ for (size_t i = 0; i < buffers.size(); ++i) { + se::DeviceMemoryBase alloc = buffers[i]; + if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) { + VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" + << alloc.opaque() << "]"; + TF_RETURN_IF_ERROR( + allocator->Deallocate(stream->parent()->device_ordinal(), &alloc)); } } + return Status::OK(); } @@ -262,26 +298,9 @@ StatusOr CpuExecutable::ExecuteOnStream( MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(), &marked_addresses); - VLOG(3) << "Live addresses in output marking found " - << marked_addresses.size() << " addresses:\n" - << tensorflow::str_util::Join( - marked_addresses, ", ", [](string* out, const void* address) { - tensorflow::strings::StrAppend( - out, tensorflow::strings::Printf("%p", address)); - }); - - // Computation is done - deallocate temp buffers. Keep those marked live - // because they are referenced by the output of the computation and are needed - // by the service. They will be deallocated by the service. - for (size_t i = 0; i < buffers.size(); ++i) { - se::DeviceMemoryBase alloc = buffers[i]; - if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate( - stream->parent()->device_ordinal(), &alloc)); - } - } + LogLiveAddresses(marked_addresses); + TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, + marked_addresses)); return top_level_output; } @@ -359,9 +378,44 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) { - // TODO(b/30671675): Implement asynchronous execution mode. 
- return Unimplemented( - "Asynchronous execution on stream is not yet supported on CPU."); + if (hlo_profiling_enabled()) { + return Unimplemented( + "Asynchronous execution on stream with hlo profiling is not yet " + "supported on CPU."); + } + + auto* host_stream = dynamic_cast( + run_options->stream()->implementation()); + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + std::vector buffers(assignment_->Allocations().size()); + + TF_RETURN_IF_ERROR(AllocateBuffers( + memory_allocator, stream->parent()->device_ordinal(), &buffers)); + + // Mark the buffers that are actually live (used in the output) when the + // computation finishes executing. + std::unordered_set marked_addresses; + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelOutputSlice()); + se::DeviceMemoryBase top_level_output = buffers[result_slice.index()]; + MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(), + &marked_addresses); + + LogLiveAddresses(marked_addresses); + + host_stream->EnqueueTask([this, run_options, arguments, buffers, + marked_addresses, memory_allocator, stream]() { + // Failing a CHECK here is not great, but I don't see an obvious way to + // return a failed Status asynchronously. 
+ TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments, + buffers, + /*hlo_execution_profile=*/nullptr)); + TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers, + marked_addresses)); + }); + + return top_level_output; } /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { @@ -377,9 +431,5 @@ const PointsToSet& CpuExecutable::GetRootPointsToSet() const { module().entry_computation()->root_instruction()); } -std::unique_ptr CpuExecutable::CreateCostAnalysis() const { - return MakeUnique(ShapeSizeBytes); -} - } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 238bc9b46ae2bf1b519eaf137d9ae063e769bd2e..17ee2d673ee7cde1847bf29e2399e6033cb7e30e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -47,12 +47,12 @@ namespace cpu { // architecture, so JIT-ed code and host code share the same ABI. class CpuExecutable : public Executable { public: - CpuExecutable( - std::unique_ptr jit, - std::unique_ptr assignment, - std::unique_ptr hlo_module, - const string& entry_function_name, - std::unordered_map hlo_to_profile_idx); + CpuExecutable(std::unique_ptr jit, + std::unique_ptr assignment, + std::unique_ptr hlo_module, + const string& entry_function_name, + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map); ~CpuExecutable() override {} StatusOr ExecuteOnStream( @@ -85,12 +85,10 @@ class CpuExecutable : public Executable { static int64 ShapeSizeBytes(const Shape& shape); - std::unique_ptr CreateCostAnalysis() const override; - // Type of the computation function we expect in the JIT. 
using ComputeFunctionType = void (*)( void* /*result*/, const ExecutableRunOptions* /*run_options*/, - const void** /*args*/, void** /*temps*/, uint64* /*profile_counters*/); + const void** /*args*/, void** /*temps*/, int64* /*profile_counters*/); const ComputeFunctionType& compute_function() const { return compute_function_; @@ -145,9 +143,6 @@ class CpuExecutable : public Executable { // Entry function name for the computation. const string entry_function_name_; - // Maps HLOs to their index into the profile counter array. - const std::unordered_map hlo_to_profile_idx_; - TF_DISALLOW_COPY_AND_ASSIGN(CpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index f87ee3cecd932faac140636a3db7cd4aa0371b85..482e04052d5a914eab0e5bff2c7a83f3b698052f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -26,7 +26,7 @@ int64 BytesInDimension(const Shape& shape, int64 dimension) { shape.dimensions(dimension); } -bool IsFusile(const HloInstruction& hlo) { +bool CanBeLoopFused(const HloInstruction& hlo) { // These are the only ones we fuse since we rely on effective elemental IR // generation. 
return hlo.IsElementwise() || // @@ -42,6 +42,23 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kTranspose; } +bool IsMatrixVectorDot(const HloInstruction* hlo) { + const Shape& hlo_shape = hlo->shape(); + return hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() == 2 && + (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1); +} + +bool CanBeOutputFused(const HloInstruction* producer, + const HloInstruction* consumer) { + return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) && + producer->user_count() == 1; +} + +bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) { + return consumer->opcode() == HloOpcode::kAdd && + (CanBeOutputFused(consumer->operand(0), consumer) || + CanBeOutputFused(consumer->operand(1), consumer)); +} } // namespace bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, @@ -52,7 +69,15 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, constexpr int kFusionThresholdBytes = 16 * 1024; - if (!IsFusile(*producer)) { + if (CanBeOutputFused(producer, consumer)) { + return true; + } + + if (CanBeOutputFusedIntoSomeOperand(producer)) { + return false; + } + + if (!CanBeLoopFused(*producer)) { VLOG(2) << "Producer is not fusile."; return false; } @@ -108,16 +133,13 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } } - if (consumer->opcode() == HloOpcode::kFusion) { - // InstructionFusion::ShouldFuse above only allows kLoop and kInput fusions. - // The CPU backend does not create kInput fusions, so we only expect to see - // kLoop here. 
- CHECK(consumer->fusion_kind() == HloInstruction::FusionKind::kLoop); + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == HloInstruction::FusionKind::kLoop) { VLOG(2) << "Fusing: consumer is a fusion node."; return true; } - if (IsFusile(*consumer)) { + if (CanBeLoopFused(*consumer)) { VLOG(2) << "Fusing: consumer is elementwise or fusile."; return true; } @@ -126,5 +148,11 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } +HloInstruction::FusionKind CpuInstructionFusion::ChooseKind( + const HloInstruction* producer, const HloInstruction* consumer) { + return CanBeOutputFused(producer, consumer) + ? HloInstruction::FusionKind::kOutput + : HloInstruction::FusionKind::kLoop; +} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h index 0eca4c3473e1454fe5dbd8bf855b4418cf553a94..07aff34974e0cfa6c7a129f82017b280fb1ccd59 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h @@ -30,6 +30,8 @@ class CpuInstructionFusion : public InstructionFusion { protected: bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; + HloInstruction::FusionKind ChooseKind( + const HloInstruction* producer, const HloInstruction* consumer) override; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index b9e4d006d77ae76e33ac51440349400ea4eff118..595c3f55b321f47e2312b93e0c238c7637495d77 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -31,6 +31,14 @@ namespace { using InstructionFusionTest = HloTestBase; +std::unique_ptr MakeDot(const Shape& shape, HloInstruction* lhs, + HloInstruction* 
rhs) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums); +} + TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloComputation::Builder builder(TestName()); HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -40,8 +48,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, exp0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -59,8 +67,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -80,8 +88,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0)); HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, bitcast0, arg1)); + 
HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -102,8 +110,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { HloInstruction* reshape0 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1024, 256}), exp0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, reshape0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -121,8 +129,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 32 * 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -140,8 +148,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -162,8 +170,8 @@ 
TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) { HloInstruction* transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0})); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, transpose1)); + builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -188,7 +196,9 @@ class OpcodeFusionTest : public InstructionFusionTest { // Runs CPU instruction fusion on the given module, and tests that the result // contains a fused op at the root with exactly the given multiset of opcodes. void RunFusionAndCheckOpcodesWereFused( - HloModule* module, const std::multiset& expected_opcodes) { + HloModule* module, const std::multiset& expected_opcodes, + HloInstruction::FusionKind fusion_kind = + HloInstruction::FusionKind::kLoop) { auto computation = module->entry_computation(); auto did_fusion = CpuInstructionFusion().Run(module); ASSERT_TRUE(did_fusion.ok()); @@ -196,7 +206,7 @@ class OpcodeFusionTest : public InstructionFusionTest { HloInstruction* root = computation->root_instruction(); ASSERT_THAT(root, op::Fusion()); - EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); + EXPECT_EQ(root->fusion_kind(), fusion_kind); std::vector fused_opcodes(root->fused_instruction_count()); std::transform(root->fused_instructions().begin(), @@ -608,6 +618,88 @@ TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastBinary) { Not(op::Fusion())); } +void CreateComputationForDotAddOutputFusionTest(const string& test_name, + HloModule* module, int m, int k, + int n, + bool add_extra_use_for_dot) { + HloComputation::Builder builder(test_name); + + Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); + Shape dot_shape = ShapeUtil::MakeShape(F32, 
{m, n}); + + auto* dot_lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, dot_lhs_shape, "param0")); + auto* dot_rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, dot_rhs_shape, "param1")); + auto* addend = builder.AddInstruction( + HloInstruction::CreateParameter(2, dot_shape, "param2")); + + auto* dot = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + builder.AddInstruction( + HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend)); + + if (add_extra_use_for_dot) { + builder.AddInstruction( + HloInstruction::CreateOutfeed(dot_shape, dot, "no_config")); + } + + module->AddEntryComputation(builder.Build()); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1, + /*k=*/50, /*n=*/19, + /*add_extra_use_for_dot=*/false); + + RunFusionAndCheckOpcodesWereFused( + module.get(), + {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter}, + HloInstruction::FusionKind::kOutput); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, + /*k=*/50, /*n=*/1, + /*add_extra_use_for_dot=*/false); + + RunFusionAndCheckOpcodesWereFused( + module.get(), + {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter}, + HloInstruction::FusionKind::kOutput); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, + /*k=*/50, /*n=*/19, + /*add_extra_use_for_dot=*/false); + + TF_ASSERT_OK_AND_ASSIGN(bool fused_something, + CpuInstructionFusion().Run(module.get())); + EXPECT_FALSE(fused_something); + 
EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, + /*k=*/50, /*n=*/1, + /*add_extra_use_for_dot=*/true); + + TF_ASSERT_OK_AND_ASSIGN(bool fused_something, + CpuInstructionFusion().Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc similarity index 54% rename from tensorflow/compiler/xla/service/cpu/layout_assignment.cc rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index c446b6b792a042da2500ea6a175fdca4c70bcab6..e8117377e61a4e21b8c45b929c518a18878fcb60 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -13,69 +13,89 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { namespace cpu { -Status CpuLayoutAssignment::AddBackendConstraints( - LayoutConstraints* constraints) { - auto row_major_shape = [](const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; - }; - auto col_major_shape = [](const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.begin(), dimension_order.end(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; - }; - - // We want to change the layout of constant arrays to be column major when all - // of their users are dot operations that can be made faster with the flipped - // layout. To avoid going quadriatic over the # of instructions, we cache - // this property in should_make_rhs_col_major -- it maps a constant to true if - // all of the users of said constant are dot operations that can be sped up. - // This cache is populated lazily as we encounter dot operations traversing - // the instruction stream. 
- tensorflow::gtl::FlatMap - should_make_rhs_col_major_cache; - auto should_make_rhs_col_major = [&](const HloInstruction& instruction) { - if (ProfitableToImplementDotInLlvmIr(instruction) != - DotInLlvmIrProfitable::kWithColumnMajorRhs) { +// We want to change the layout of constant arrays to be column major when all +// of their users are dot operations that can be made faster with the flipped +// layout. To avoid going quadriatic over the # of instructions, we cache this +// property in should_make_rhs_col_major -- it maps a constant to true if all of +// the users of said constant are dot operations that can be sped up. This +// cache is populated lazily as we encounter dot operations traversing the +// instruction stream. + +namespace { +using ::tensorflow::gtl::nullopt; +using ::tensorflow::gtl::optional; + +using ShouldMakeOperandColMajorCache = + tensorflow::gtl::FlatMap; +} // namespace + +static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) { + for (auto* user : instruction->users()) { + optional operand_idx = ProfitableToMakeDotOperandColumnMajor(*user); + if (!operand_idx || user->operand(*operand_idx) != instruction || + std::count(user->operands().begin(), user->operands().end(), + instruction) != 1) { return false; } + } + return true; +} - const auto* rhs = instruction.operand(1); - if (rhs->opcode() != HloOpcode::kConstant) { - return false; - } +static optional ShouldMakeOperandColumnMajor( + ShouldMakeOperandColMajorCache* cache, const HloInstruction& instruction) { + optional operand_idx = + ProfitableToMakeDotOperandColumnMajor(instruction); + if (!operand_idx) { + return nullopt; + } - auto it = should_make_rhs_col_major_cache.find(rhs); - if (it != should_make_rhs_col_major_cache.end()) { - return it->second; - } + const HloInstruction* operand = instruction.operand(*operand_idx); + if (operand->opcode() != HloOpcode::kConstant) { + return nullopt; + } - bool result = std::all_of( - rhs->users().begin(), 
rhs->users().end(), [&](HloInstruction* user) { - return ProfitableToImplementDotInLlvmIr(*user) == - DotInLlvmIrProfitable::kWithColumnMajorRhs && - user->operand(0) != rhs; - }); + auto it = cache->find(operand); + if (it == cache->end()) { + auto insert_result = + cache->insert({operand, ShouldMakeAllUsersColMajor(operand)}); + CHECK(insert_result.second); + it = insert_result.first; + } - InsertOrDie(&should_make_rhs_col_major_cache, rhs, result); - return result; - }; + return it->second ? operand_idx : nullopt; +} + +static Shape RowMajorShape(const Shape& old_shape) { + Shape new_shape(old_shape); + std::vector dimension_order(new_shape.dimensions_size()); + std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); + *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + return new_shape; +} + +static Shape ColMajorShape(const Shape& old_shape) { + Shape new_shape(old_shape); + std::vector dimension_order(new_shape.dimensions_size()); + std::iota(dimension_order.begin(), dimension_order.end(), 0); + *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + return new_shape; +} + +Status CpuLayoutAssignment::AddBackendConstraints( + LayoutConstraints* constraints) { + ShouldMakeOperandColMajorCache cache; const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { @@ -90,9 +110,9 @@ Status CpuLayoutAssignment::AddBackendConstraints( // // These constraints are not hard constraints. Ideally, we should decide // which layouts to choose according to some cost model. 
- Shape output_shape(row_major_shape(convolution->shape())); - Shape input_shape(row_major_shape(lhs_instruction->shape())); - Shape filter_shape(row_major_shape(rhs_instruction->shape())); + Shape output_shape(RowMajorShape(convolution->shape())); + Shape input_shape(RowMajorShape(lhs_instruction->shape())); + Shape filter_shape(RowMajorShape(rhs_instruction->shape())); // Set layouts of the instructions' shapes. TF_RETURN_IF_ERROR( @@ -101,11 +121,11 @@ Status CpuLayoutAssignment::AddBackendConstraints( constraints->SetOperandLayout(filter_shape, convolution, 1)); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, convolution)); - } else if (should_make_rhs_col_major(*instruction)) { - auto* dot = instruction; - const auto& rhs_shape = dot->operand(1)->shape(); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(col_major_shape(rhs_shape), dot, 1)); + } else if (optional op_idx = + ShouldMakeOperandColumnMajor(&cache, *instruction)) { + const HloInstruction* op = instruction->operand(*op_idx); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + ColMajorShape(op->shape()), instruction, *op_idx)); } else if (PotentiallyImplementedAsEigenDot(*instruction)) { const HloInstruction* dot = instruction; // In order to implement `dot` with Eigen dot, the layouts of the lhs, @@ -113,17 +133,17 @@ Status CpuLayoutAssignment::AddBackendConstraints( // // These constraints are not hard constraints. Ideally, we should decide // which layouts to choose according to some cost model. - Shape output_shape(row_major_shape(dot->shape())); + Shape output_shape(RowMajorShape(dot->shape())); const HloInstruction* lhs_instruction = dot->operand(0); - Shape lhs_shape(row_major_shape(lhs_instruction->shape())); + Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); // dot is a kDot or a kTransposeDot fusion node. In the latter case, if // it represents X @ X, it may have just one operand. 
if (dot->operand_count() > 1) { const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(row_major_shape(rhs_instruction->shape())); + Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); } @@ -140,8 +160,12 @@ Status CpuLayoutAssignment::AddBackendConstraints( if (constraints->OperandBufferForwarded(instruction, operand_no)) { continue; } + // Skip operands with non-array shapes. + if (!ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + continue; + } Shape operand_shape( - row_major_shape(instruction->operand(operand_no)->shape())); + RowMajorShape(instruction->operand(operand_no)->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( operand_shape, instruction, operand_no)); } diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h similarity index 86% rename from tensorflow/compiler/xla/service/cpu/layout_assignment.h rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index 4fd8d68dd6b4f2a8b16f6c048743a996ea76a560..c8edbb9e15a5b6f9c574f5fe9d130d149499ebd2 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -38,4 +38,4 @@ class CpuLayoutAssignment : public LayoutAssignment { } // namespace cpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc similarity index 54% rename from tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 1ea5e8c7fc4896512e62396d0a756cda44785f11..6ba030fff3bbc5f413bfb133114ceb5309b77672 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include #include @@ -40,6 +40,8 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -61,8 +63,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { HloInstruction::CreateParameter(0, lhs_shape, "param0")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); - auto result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_lhs, dot_rhs)); + auto result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -98,10 +100,10 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { HloInstruction::CreateParameter(1, lhs_shape, "param1")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); - auto dot_a_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_a_lhs, dot_rhs)); - auto dot_b_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_b_lhs, dot_rhs)); + auto dot_a_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); + auto dot_b_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs)); builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); @@ -142,10 +144,10 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { HloInstruction::CreateParameter(1, lhs_b_shape, "param1")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); - auto dot_a_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_a_shape, 
HloOpcode::kDot, dot_a_lhs, dot_rhs)); - auto dot_b_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_b_shape, HloOpcode::kDot, dot_b_lhs, dot_rhs)); + auto dot_a_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); + auto dot_b_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); @@ -180,8 +182,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { HloInstruction::CreateConstant(Literal::CreateFromShape(lhs_shape))); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateParameter(0, rhs_shape, "param0")); - auto dot_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_lhs, dot_rhs)); + auto dot_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -220,8 +222,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { HloInstruction::CreateParameter(0, lhs_shape, "param0")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1)); - auto dot_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_lhs, dot_rhs)); + auto dot_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -241,5 +243,172 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { EXPECT_NE(instruction->opcode(), HloOpcode::kCopy); } } + +struct DotOutputFusionLayoutAssignmentResult { + bool layout_assignment_changed_something; + const HloInstruction* 
dot_lhs_fusion_param; + const HloInstruction* dot_rhs_fusion_param; + const HloInstruction* addend_fusion_param; +}; + +static StatusOr RunDotOutputFusion( + HloModule* module, const string& test_name, int m, int k, int n, + const int64 dot_operand_idx_in_add) { + DotOutputFusionLayoutAssignmentResult result; + + CHECK(dot_operand_idx_in_add == 0 || dot_operand_idx_in_add == 1); + + auto builder = HloComputation::Builder(test_name); + + Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); + Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); + + HloInstruction* dot_lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, dot_lhs_shape, "param0")); + HloInstruction* addend = builder.AddInstruction( + HloInstruction::CreateParameter(1, dot_shape, "param1")); + HloInstruction* dot_rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape))); + HloInstruction* dot_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + HloInstruction* add_result; + if (dot_operand_idx_in_add == 0) { + add_result = builder.AddInstruction(HloInstruction::CreateBinary( + dot_shape, HloOpcode::kAdd, dot_result, addend)); + } else { + add_result = builder.AddInstruction(HloInstruction::CreateBinary( + dot_shape, HloOpcode::kAdd, addend, dot_result)); + } + + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloInstruction* fusion_instruction = + module->entry_computation()->AddInstruction(HloInstruction::CreateFusion( + dot_shape, HloInstruction::FusionKind::kOutput, add_result)); + TF_RETURN_IF_ERROR( + computation->ReplaceInstruction(add_result, fusion_instruction)); + + HloInstruction* fused_add = + fusion_instruction->fused_instructions_computation()->root_instruction(); + HloInstruction* fused_dot = fusion_instruction->FuseInstruction(dot_result); + + TF_RETURN_IF_ERROR( + 
computation->RemoveInstructionAndUnusedOperands(dot_result)); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_lhs_shape)); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape)); + *computation_layout.mutable_result_layout() = + ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape)); + + result.dot_lhs_fusion_param = + fusion_instruction->operand(fused_dot->operand(0)->parameter_number()); + result.dot_rhs_fusion_param = + fusion_instruction->operand(fused_dot->operand(1)->parameter_number()); + result.addend_fusion_param = fusion_instruction->operand( + fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number()); + + cpu::CpuLayoutAssignment layout_assignment(&computation_layout); + TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, + layout_assignment.Run(module)); + + return result; +} + +static void AssertCorrectLayoutForDotOutputFusion( + const HloComputation* computation, + const DotOutputFusionLayoutAssignmentResult& layout_assignment_result, + bool expect_col_major_dot_rhs) { + Layout expected_dot_rhs_layout = expect_col_major_dot_rhs + ? 
LayoutUtil::MakeLayout({0, 1}) + : LayoutUtil::MakeLayout({1, 0}); + EXPECT_TRUE(LayoutUtil::Equal( + expected_dot_rhs_layout, + layout_assignment_result.dot_rhs_fusion_param->shape().layout())); + + EXPECT_TRUE(LayoutUtil::Equal( + LayoutUtil::MakeLayout({1, 0}), + layout_assignment_result.dot_lhs_fusion_param->shape().layout())); + + EXPECT_TRUE(LayoutUtil::Equal( + LayoutUtil::MakeLayout({1, 0}), + layout_assignment_result.addend_fusion_param->shape().layout())); + EXPECT_THAT(computation->instructions(), Each(Not(op::Copy()))); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/0)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/true); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/1)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/true); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, + /*dot_operand_idx_in_add=*/0)); + 
ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, + /*dot_operand_idx_in_add=*/1)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/0)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/1)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index 
dba140d1120bc5502d2039e1663b9bf035d8d66a..09f028463af68bbc2841fecdb2ca6c6a42498798 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -15,11 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" +#include "tensorflow/core/lib/strings/numbers.h" + namespace { const char* const kXlaParallelCpuOption = "xla_cpu_parallel"; const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; +const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; } // namespace @@ -45,6 +48,19 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config) { return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0; } +tensorflow::gtl::optional LlvmIrGemvTilingFactor( + const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + auto it = extra_options_map.find(kLlvmIrDotTilingFactor); + int64 tiling_factor; + if (it != extra_options_map.end() && + tensorflow::strings::safe_strto64(it->second, &tiling_factor)) { + return tiling_factor; + } + return tensorflow::gtl::nullopt; +} + } // namespace options } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 5dc24ebc7b8661092e3bc27c4f30fda1e497e41b..6ba0fd24538b63a3da81083482e6bee3b552dfea 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -27,6 +27,8 @@ namespace options { bool CpuParallelBackendRequested(const HloModuleConfig& config); bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); +tensorflow::gtl::optional LlvmIrGemvTilingFactor( + const HloModuleConfig& config); } // namespace options } // 
namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index f8e260dd90149405fff7beefba3f7fe83b75d4b6..f385829cdf5cafbd35e083f47106734cdd5dde88 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -12,15 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - +#define EIGEN_USE_THREADS #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include #include #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.h b/tensorflow/compiler/xla/service/cpu/disassembler.h index b6feaa7e45cee26eb7f850081bd1fad2cb63b15c..5e302f88990ee4a3c37758881ecec4d6f71dd8e6 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.h +++ b/tensorflow/compiler/xla/service/cpu/disassembler.h @@ -37,7 +37,7 @@ struct DisassemblerResult { DisassemblerResult(const string& text, size_t code_size_bytes) : text(text), code_size_bytes(code_size_bytes) {} - // The dissassembled text sections of the object file. + // The disassembled text sections of the object file. string text; // The total number of bytes of executable code in the object file. uint64_t code_size_bytes; @@ -53,7 +53,7 @@ class Disassembler { // Returns a DisassemblerResult for the given object file, containing the // disassembled code. // - // If we couldnt' retrieve a disassembler for this platform, an error status + // If we couldn't retrieve a disassembler for this platform, an error status // is returned. 
StatusOr DisassembleObjectFile( const llvm::object::ObjectFile& object_file) const; diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index e57d49172b18beb75cfbb482c5d732ef679ebe41..0631454d5c1fde15f32f19285094cac6d7ff298c 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,9 +23,10 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -38,20 +39,501 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { -DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, - bool transpose_rhs, - const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config) +namespace { +// Loads a tile of values from a 2D tensor. +class TileLoader { + public: + // Constructs a TileLoader that will load a tile consisting of + // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at + // `major_dim_offset` in the major dimension. The tile size along the minor + // dimension is the vector size, and that is implicitly determined by `vsl`. 
+ TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, + llvm::Value* matrix, int64 matrix_size_along_minor_dim, + llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) + : vsl_(vsl) { + pointers_.reserve(tile_size_along_major_dim); + for (int64 i = 0; i < tile_size_along_major_dim; i++) { + llvm::Value* total_offset = ir_builder->CreateMul( + ir_builder->getInt64(matrix_size_along_minor_dim), + ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset)); + pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); + } + } + + // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at + // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the + // minor dimension. + std::vector LoadTile(llvm::Value* minor_dim_offset) const { + std::vector result; + result.reserve(pointers_.size()); + for (const auto& pointer : pointers_) { + result.push_back(vsl_->LoadVector(pointer, minor_dim_offset)); + } + return result; + } + + private: + VectorSupportLibrary* vsl_; + std::vector pointers_; +}; + +// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the +// layout of the vector does not matter). This implementation uses a tiling +// scheme to improve performance. +// +// We logically separate the LHS matrix into four segments: +// +// +----------------------+---+ +// | | | +// | | | +// | A | B | +// | | | +// | | | +// | | | +// +----------------------+---+ +// | C | D | +// +----------------------+---+ +// +// where A is the largest submatrix of the LHS that can be evenly divided into +// tiles. 
For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: +// +// +---+---+---+---+ +--+--+--+--+ +// |M00|M10|M20|M30| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M01|M11|M21|M31| and |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M02|M12|M22|M32| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M03|M13|M23|M33| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// +// (Legend: rows are horizontal and columns are vertical; and each column is one +// llvm::Value of a vector type) +// +// where: +// +// a. The left tile is from the column major left matrix. +// b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3] +// vector loaded from the RHS vector. +// +// As we iterate through the column dimension, we compute the change to the +// result vector by an elementwise multiplication between the two tiles above +// followed by a reduction along the major dimension: +// +// +-----------------------------------+ +// | M00*V0 + M10*V1 + M20*V2 + M30*V3 | +// +-----------------------------------+ +// | M01*V0 + M11*V1 + M21*V2 + M31*V3 | +// Result[R:R+4] += +-----------------------------------+ +// | M02*V0 + M12*V1 + M22*V2 + M32*V3 | +// +-----------------------------------+ +// | M03*V0 + M13*V1 + M23*V2 + M33*V3 | +// +-----------------------------------+ +// +// Where R is the starting row for the tile. +// +// We have an inner epilogue loop to deal with the "C" submatrix and an outer +// epilogue loop to deal with the B,D submatrix. +// +// TODO(sanjoy): We should investigate if using gather loads and scatter stores +// can be used here to have the same inner loop for both column-major and +// row-major matrix-vector products. 
+class ColumnMajorMatrixVectorProductEmitter { + public: + ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, + int64 tile_rows, int64 tile_cols, + int64 m, int64 k, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, + llvm::IRBuilder<>* ir_builder) + : scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + lhs_(lhs), + rhs_(rhs), + addend_(addend), + result_(result), + ir_builder_(ir_builder), + ksl_(ir_builder_), + vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") { + CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast(tile_rows_))); + } + + void Emit(); + + private: + void EmitOuterLoopBody(llvm::Value* column, int64 column_count, + bool is_first_column); + + TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) { + return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/m_, + /*major_dim_offset=*/column_start, + /*tile_size_along_major_dim=*/column_count); + } + + // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous + // sequence of `count` values, each one broadcasted to the vector width. 
+ std::vector LoadRhsTile(llvm::Value* offset, int64 count) { + llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); + std::vector result; + result.reserve(count); + for (int64 i = 0; i < count; i++) { + result.push_back(vsl_.LoadBroadcast(base_pointer, i)); + } + return result; + } + + void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, + const std::vector& rhs_tile, + int64 columns, bool is_first_column); + + void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, + bool is_first_tiled_column); + + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* addend_; + llvm::Value* result_; + llvm::IRBuilder<>* ir_builder_; + KernelSupportLibrary ksl_; + VectorSupportLibrary vsl_; +}; + +void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( + llvm::Value* column, int64 column_count, bool is_first_column) { + TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column, + /*column_count=*/column_count); + + std::vector rhs_tile = + LoadRhsTile(column, /*count=*/column_count); + EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile, + /*columns=*/column_count, is_first_column); + EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); +} + +void ColumnMajorMatrixVectorProductEmitter::Emit() { + // See the comment on the class declaration for the algorithm used here. 
+ int64 column_remainder = k_ % tile_cols_; + int64 column_limit = k_ - column_remainder; + + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_, + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols_, is_first_column); + }); + + if (column_remainder != 0) { + EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder, + column_limit == 0); + } +} + +void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( + TileLoader* lhs_tile_loader, const std::vector& rhs_tile, + int64 columns, bool is_first_column) { + int64 row_limit = m_ - (m_ % tile_rows_); + + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, + /*step=*/tile_rows_, [&](llvm::Value* row) { + std::vector lhs_tile = + lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row); + llvm::Value* accumulator = + is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) + : vsl_.GetZeroVector()) + : vsl_.LoadVector(result_, row); + for (int i = 0; i < columns; i++) { + accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); + } + vsl_.StoreVector(accumulator, result_, row); + }); +} + +void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( + llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { + int64 row_start = m_ - (m_ % tile_rows_); + if (row_start == m_) { + return; + } + + llvm::Value* columns_llvm = ir_builder_->getInt64(columns); + + // for (col = current_tile_col; col < (columns + current_tile_col); col++) + // for (row = row_start, row < m_; row++) { + // result[row] += lhs[row, col] * rhs[col] + // // Also take into account that if col is 0 then result[row] is not + // // initialized. 
+ // } + + ksl_.For( + "dot.inner.epilg.outer", /*start=*/current_tile_col, + /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col), + /*step=*/1, /*peel_first_iteration=*/false, + [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { + llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); + llvm::Value* total_offset = + ir_builder_->CreateMul(col, ir_builder_->getInt64(m_)); + llvm::Value* lhs_base_pointer = + vsl_.ComputeOffsetPointer(lhs_, total_offset); + ksl_.For( + "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_, + /*step=*/1, [&](llvm::Value* scalar_row) { + llvm::Value* product = vsl_.Mul( + vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); + llvm::Value* setting_result_first_time = ir_builder_->CreateAnd( + is_first_scalar_col, + ir_builder_->getInt1(is_first_tiled_column)); + ksl_.If( + setting_result_first_time, + /*true_block_generator=*/ + [&]() { + if (addend_) { + vsl_.StoreScalar( + vsl_.Add(vsl_.LoadScalar(addend_, scalar_row), + product), + result_, scalar_row); + } else { + vsl_.StoreScalar(product, result_, scalar_row); + } + }, + /*false_block_generator=*/ + [&]() { + vsl_.StoreScalar( + vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product), + result_, scalar_row); + }); + }); + }); +} + +// Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the +// layout of the vector does not matter). This implementation uses a tiling +// scheme to improve performance. +// +// We logically separate the LHS matrix into four segments: +// +// +----------------------+---+ +// | | | +// | | | +// | A | B | +// | | | +// | | | +// | | | +// +----------------------+---+ +// | C | D | +// +----------------------+---+ +// +// where A is the largest submatrix of the LHS that can be evenly divided into +// tiles. 
For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: +// +// +---+---+---+---+ +// |M00|M10|M20|M30| +// +---+---+---+---+ +--+--+--+--+ +// |M01|M11|M21|M31| and |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M02|M12|M22|M32| +// +---+---+---+---+ +// |M03|M13|M23|M33| +// +---+---+---+---+ +// +// (Legend: rows are horizontal and columns are vertical; and each row is one +// llvm::Value of a vector type) +// +// where: +// +// a. The left tile is loaded from the row major left matrix. +// b. The right vector is loaded from the RHS vector. +// +// We keep 4 vector accumulators accumulating the following four vector +// expressions as we iterate over the row dimension: +// +// +------+------+------+------+ +// |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4) +// +------+------+------+------+ +// +// In the end we do a horizontal reduction over these 4 vector accumulators to +// get 4 values in the result vector. +// +// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer +// epilogue loop to deal with the C,D submatrix. 
+class RowMajorMatrixVectorProductEmitter { + public: + RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, + llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* addend, llvm::Value* result, + llvm::IRBuilder<>* ir_builder) + : scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + lhs_(lhs), + rhs_(rhs), + addend_(addend), + result_(result), + ir_builder_(ir_builder), + ksl_(ir_builder_), + vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") { + CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast(tile_cols_))); + } + + void Emit(); + + private: + TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) { + return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/k_, + /*major_dim_offset=*/row_start, + /*tile_size_along_major_dim=*/row_count); + } + + void EmitOuterLoopBody(llvm::Value* row, int64 row_count); + + void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows, + std::vector* vector_accumulators); + + void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, + std::vector* scalar_accumulators); + + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* addend_; + llvm::Value* result_; + llvm::IRBuilder<>* ir_builder_; + KernelSupportLibrary ksl_; + VectorSupportLibrary vsl_; +}; + +void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, + int64 row_count) { + TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row, + /*row_count=*/row_count); + std::vector vector_accumulators; + std::vector scalar_accumulators; + for (int i = 0; i < row_count; i++) { + vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); + scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); + } + EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count, + 
&vector_accumulators); + EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, + &scalar_accumulators); + + std::vector accumulator_values; + std::transform( + vector_accumulators.begin(), vector_accumulators.end(), + std::back_inserter(accumulator_values), + [](const VectorVariable& vector_var) { return vector_var.Get(); }); + + std::vector horizontal_sums; + if (row_count == vsl_.vector_size()) { + if (addend_) { + horizontal_sums = vsl_.ComputeHorizontalSums( + std::move(accumulator_values), vsl_.LoadVector(addend_, row)); + } else { + horizontal_sums = + vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + } else { + horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + + for (int i = 0; i < row_count; i++) { + llvm::Value* result_value = + vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); + llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row); + if (addend_ && row_count != vsl_.vector_size()) { + result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); + } + vsl_.StoreScalar(result_value, result_, offset); + } +} + +void RowMajorMatrixVectorProductEmitter::Emit() { + // See the comment on the class declaration for the algorithm used here. 
+ int64 row_remainder = m_ % tile_rows_; + int64 row_limit = m_ - row_remainder; + + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_, + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); }); + + if (row_remainder != 0) { + EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder); + } +} + +void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( + TileLoader* lhs_tile_loader, int64 rows, + std::vector* vector_accumulators) { + int64 column_limit = k_ - (k_ % tile_cols_); + + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols_, [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set( + vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); +} + +void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( + llvm::Value* current_tile_row, int64 rows, + std::vector* scalar_accumulators) { + int64 column_start = k_ - (k_ % tile_cols_); + if (column_start == k_) { + return; + } + + for (int r = 0; r < rows; r++) { + llvm::Value* total_offset = ir_builder_->CreateMul( + ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row), + ir_builder_->getInt64(k_)); + llvm::Value* lhs_base_pointer = + vsl_.ComputeOffsetPointer(lhs_, total_offset); + ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_, + /*step=*/1, [&](llvm::Value* scalar_col) { + llvm::Value* product = + vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), + vsl_.LoadScalar(rhs_, scalar_col)); + llvm::Value* old_value = (*scalar_accumulators)[r].Get(); + (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); + }); + } +} + +} // namespace + +DotOpEmitter::DotOpEmitter( + const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, + const 
llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config) : dot_(dot), transpose_lhs_(transpose_lhs), transpose_rhs_(transpose_rhs), target_array_(target_array), lhs_array_(lhs_array), rhs_array_(rhs_array), + addend_array_(addend_array), executable_run_options_value_(executable_run_options_value), ir_builder_(ir_builder), hlo_module_config_(hlo_module_config) {} @@ -59,19 +541,140 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, /* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(F32 == type || F64 == type || C64 == type); DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, - lhs_array, rhs_array, executable_run_options_value, - ir_builder, hlo_module_config); + lhs_array, rhs_array, addend_array, + executable_run_options_value, ir_builder, + hlo_module_config); return dot_emitter.Emit(); } bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } +bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { + if (dot_.shape().dimensions_size() != 2) { + return false; + } + + PrimitiveType primitive_type = dot_.shape().element_type(); + + if (!primitive_util::IsFloatingPointType(primitive_type) && + !primitive_util::IsIntegralType(primitive_type)) { + return false; + } + + MatMultDims mat_mult_dims = GetMatMultDims(); + bool is_column_major_matrix_vector = false; + 
bool is_row_major_matrix_vector = false; + + int64 m, k; + bool swap_operands; + + if (mat_mult_dims.m == 1) { + bool rhs_effectively_row_major = + transpose_rhs_ ^ !mat_mult_dims.rhs_column_major; + if (rhs_effectively_row_major) { + k = mat_mult_dims.k; + m = mat_mult_dims.n; + is_column_major_matrix_vector = true; + swap_operands = true; + } else { + k = mat_mult_dims.k; + m = mat_mult_dims.n; + is_row_major_matrix_vector = true; + swap_operands = true; + } + } + + if (mat_mult_dims.n == 1) { + bool lhs_effectively_column_major = + transpose_lhs_ ^ mat_mult_dims.lhs_column_major; + if (lhs_effectively_column_major) { + m = mat_mult_dims.m; + k = mat_mult_dims.k; + is_column_major_matrix_vector = true; + swap_operands = false; + } else { + m = mat_mult_dims.m; + k = mat_mult_dims.k; + is_row_major_matrix_vector = true; + swap_operands = false; + } + } + + if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { + return false; + } + + int64 tiling_factor = GetGemvTilingFactor(); + CHECK_GT(tiling_factor, 0); + + llvm::Value* result_op = target_array_.GetBasePointer(); + llvm::Value* lhs_op = + swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer(); + llvm::Value* rhs_op = + swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer(); + + const bool enable_fast_math = + hlo_module_config_.debug_options().xla_enable_fast_math(); + const bool optimize_for_size = + options::OptimizeForSizeRequested(hlo_module_config_); + + if (is_column_major_matrix_vector) { + VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m + << " and k = " << k; + int64 tile_rows = 8; + int64 tile_cols = tiling_factor; + + string kernel_name = tensorflow::strings::StrCat( + "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, + "_", tile_cols, "_", m, "_", k, addend_array_ ? 
"_with_addend" : ""); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, + lhs_op, rhs_op, + addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, + [this, tile_rows, tile_cols, m, k, primitive_type]( + llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, + llvm::Value* result_op) { + ColumnMajorMatrixVectorProductEmitter emitter( + primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, + addend_op, result_op, ir_builder_); + emitter.Emit(); + }); + } else { + VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m + << " and k = " << k; + int64 tile_rows = tiling_factor; + int64 tile_cols = 8; + + string kernel_name = tensorflow::strings::StrCat( + "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, + "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, + lhs_op, rhs_op, + addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, + [this, tile_rows, tile_cols, m, k, primitive_type]( + llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, + llvm::Value* result_op) { + RowMajorMatrixVectorProductEmitter emitter( + primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, + addend_op, result_op, ir_builder_); + emitter.Emit(); + }); + } + + return true; +} + tensorflow::Status DotOpEmitter::Emit() { // The dot operation performs a sum of products over dimension 0 of the left // hand side operand and dimension 1 of the right hand side operand. 
@@ -105,6 +708,12 @@ tensorflow::Status DotOpEmitter::Emit() { return EmitScalarDot(); } + if (EmitLlvmIrDotIfProfitable()) { + return Status::OK(); + } + + CHECK_EQ(addend_array_, nullptr); + if (PotentiallyImplementedAsEigenDot(dot_)) { return EmitCallToRuntime(); } @@ -340,22 +949,17 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { // // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'. - const Shape& lhs_shape = lhs_array_.GetShape(); - const Shape& rhs_shape = rhs_array_.GetShape(); + MatMultDims mat_mult_dims = GetMatMultDims(); - CHECK(LayoutUtil::Equal(lhs_shape.layout(), rhs_shape.layout())); + CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major); - int64 m = lhs_shape.dimensions(transpose_lhs_ ? 1 : 0); - int64 k = lhs_shape.dimensions(transpose_lhs_ ? 0 : 1); - int64 n = rhs_shape.dimensions(transpose_rhs_ ? 0 : 1); const llvm_ir::IrArray* lhs = &lhs_array_; const llvm_ir::IrArray* rhs = &rhs_array_; bool transpose_lhs = transpose_lhs_; bool transpose_rhs = transpose_rhs_; - bool is_column_major = lhs_shape.layout().minor_to_major(0) == 0; - if (!is_column_major) { - std::swap(m, n); + if (!mat_mult_dims.lhs_column_major) { + std::swap(mat_mult_dims.m, mat_mult_dims.n); std::swap(lhs, rhs); std::swap(transpose_lhs, transpose_rhs); } @@ -367,12 +971,27 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { float_ptr_type), ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type), ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type), - ir_builder_->getInt64(m), ir_builder_->getInt64(n), - ir_builder_->getInt64(k), ir_builder_->getInt32(transpose_lhs), + ir_builder_->getInt64(mat_mult_dims.m), + ir_builder_->getInt64(mat_mult_dims.n), + ir_builder_->getInt64(mat_mult_dims.k), + ir_builder_->getInt32(transpose_lhs), ir_builder_->getInt32(transpose_rhs)}); return tensorflow::Status::OK(); } +DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { + 
CHECK_EQ(dot_.shape().dimensions_size(), 2); + + const Shape& lhs_shape = lhs_array_.GetShape(); + const Shape& rhs_shape = rhs_array_.GetShape(); + + return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0), + lhs_shape.dimensions(transpose_lhs_ ? 0 : 1), + rhs_shape.dimensions(transpose_rhs_ ? 0 : 1), + lhs_shape.layout().minor_to_major(0) == 0, + rhs_shape.layout().minor_to_major(0) == 0}; +} + llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix) { @@ -403,5 +1022,113 @@ llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( return index; } +// Return whether the given shape is a matrix with no padding. +static bool IsRank2WithNoPadding(const Shape& shape) { + return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); +} + +// In a gemm operation where output = lhs * rhs, check whether the given shapes +// are valid for the operation. +static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, + const Shape& output_shape) { + // The inputs and the output must + // 1) be matrices with no padding, and + // 2) have an allowed element type. + return output_shape.element_type() == F32 && + IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && + IsRank2WithNoPadding(output_shape); +} + +bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { + // For certain types of Dot, we can call Eigen + if (hlo.opcode() == HloOpcode::kDot) { + const Shape& lhs_shape = hlo.operand(0)->shape(); + const Shape& rhs_shape = hlo.operand(1)->shape(); + + if (ShapeUtil::HasZeroElements(lhs_shape) || + ShapeUtil::HasZeroElements(rhs_shape)) { + return false; + } + + if (ProfitableToImplementDotInTiledLlvmIr(hlo)) { + return false; + } + + // If gemm can accept the operand shapes, use it rather than a custom + // kernel. 
+ if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { + // The size of the reduction dimension should match. The shape inference + // guarantees this invariant, so the check here is for programming + // errors. + CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); + return true; + } + } + + if (hlo.opcode() == HloOpcode::kFusion && + hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && + hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { + auto* dot = hlo.fused_expression_root(); + const Shape& lhs_shape = dot->operand(0)->shape(); + const Shape& rhs_shape = dot->operand(1)->shape(); + if (ShapeUtil::HasZeroElements(lhs_shape) || + ShapeUtil::HasZeroElements(rhs_shape)) { + return false; + } + return true; + } + + return false; +} + +// For vector-matrix dot products, it is always profitable to make the Rhs +// column major. +tensorflow::gtl::optional ProfitableToMakeDotOperandColumnMajor( + const HloInstruction& hlo) { + if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 && + hlo.shape().dimensions(0) == 1) { + if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) { + return 1; + } + return {}; + } + + if (hlo.opcode() == HloOpcode::kFusion && + hlo.fusion_kind() == HloInstruction::FusionKind::kOutput) { + auto* fusion_root = + hlo.fused_instructions_computation()->root_instruction(); + if (fusion_root->opcode() != HloOpcode::kAdd) { + return {}; + } + + for (auto* fusion_root_op : fusion_root->operands()) { + if (fusion_root_op->opcode() != HloOpcode::kDot) { + continue; + } + if (auto operand_num = + ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) { + auto* operand = fusion_root_op->operand(*operand_num); + if (operand->opcode() == HloOpcode::kParameter && + operand->user_count() == 1) { + return operand->parameter_number(); + } + } + } + } + + return {}; +} + +bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) { + // Any Matrix-Vector product of floating point 
or integral type, or + // a transpose-dot fusion of the same can be lowered to a tiled LLVM + // IR implementation. + const Shape& shape = dot.shape(); + return shape.dimensions_size() == 2 && + (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) && + (primitive_util::IsFloatingPointType(shape.element_type()) || + primitive_util::IsIntegralType(shape.element_type())); +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index cfc10660453c822635d68270c053977fca779ee1..2118965a70872846204974e25555340baca718cf 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ #include "llvm/IR/IRBuilder.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -29,6 +30,18 @@ limitations under the License. namespace xla { namespace cpu { +bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo); + +// Returns the index for an operand to `hlo` that should ideally be column +// major. Returns nullopt if there is no such operand or if `hlo` is not a dot +// or a fusion containing a dot. +tensorflow::gtl::optional ProfitableToMakeDotOperandColumnMajor( + const HloInstruction& hlo); + +// Returns true to indicate that we can generate a tiled LLVM IR implementation +// for |dot|. +bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot); + // Helper class for emitting LLVM IR to perform the dot operation. class DotOpEmitter { public: @@ -36,10 +49,15 @@ class DotOpEmitter { // place the result in target_array. IR is emitted at current insert point of // the builder. 
Upon completion of the method, the insert point is set to the // end of all instructions emitted for this operation. + // + // If `addend_array` is not nullptr then it must be an array of the same + // dimensions as the result, and the result is computed as `addend_array` + + // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported + // for Matrix-vector products. static tensorflow::Status EmitDotOperation( const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config); @@ -48,6 +66,7 @@ class DotOpEmitter { bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config); @@ -59,6 +78,10 @@ class DotOpEmitter { // LHS and RHS) and store the results in the target. tensorflow::Status EmitScalarDot(); + // Emit an LLVM IR implementation of the dot operation if we can. Returns + // true if an LLVM IR implementation was emitted. + bool EmitLlvmIrDotIfProfitable(); + // Emits a call to the CPU runtime to perform the matrix multiply. tensorflow::Status EmitCallToRuntime(); @@ -77,12 +100,45 @@ class DotOpEmitter { // no padding, and a rank of two. bool ShapesAreLegalForRuntimeDot() const; + // Represents the dimensions of a matrix-matrix multiply operation. + struct MatMultDims { + // The number of rows in the LHS. + int64 m; + + // The number of columns in the LHS, which must also be equal to the + // number of rows in the RHS. + int64 k; + + // The number of columns on the RHS. + int64 n; + + // True if the LHS matrix is column major.
+ bool lhs_column_major; + + // True if the RHS matrix is column major. + bool rhs_column_major; + }; + + // Get the MatMultDims instance for the dot product this DotOpEmitter + // represents. Precondition: the dot is of rank 2 (and thus its operands are + // of rank 2 as well). + MatMultDims GetMatMultDims() const; + + // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector + // registers. + int64 GetGemvTilingFactor() const { + const int64 kDefaultTilingFactor = 8; + return options::LlvmIrGemvTilingFactor(hlo_module_config_) + .value_or(kDefaultTilingFactor); + } + const HloInstruction& dot_; const bool transpose_lhs_; const bool transpose_rhs_; const llvm_ir::IrArray& target_array_; const llvm_ir::IrArray& lhs_array_; const llvm_ir::IrArray& rhs_array_; + const llvm_ir::IrArray* addend_array_; llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* ir_builder_; const HloModuleConfig& hlo_module_config_; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index b99b36a55eee40bc66dcb1b7b1a464bf764ef0ea..3993779da636e519f8d8fded468c3271d27ee093 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -29,10 +29,8 @@ bool PotentiallyImplementedAsEigenConvolution( // The following conditions are necessary (but not sufficient) for // implementing `convolution` with Eigen convolution: // - the input and kernel have a non-zero number of elements. - // - the input is in NHWC or NWHC order. - // - the kernel is in HWIO or WHIO order. - // - the spatial dimensions are in the same relative order in the input, - // kernel and output. + // - the input is in NHWC order. + // - the kernel is in HWIO order. // // To be sufficient, certain layout constraints need to be satisfied as well.
const Shape& input_shape = convolution.operand(0)->shape(); @@ -51,15 +49,22 @@ bool PotentiallyImplementedAsEigenConvolution( convolution.convolution_dimension_numbers(); // Only 1D and 2D convolutions are supported at the moment. // TODO(b/32897908): add an optimized implementation for 3D convolution. - if (dnums.spatial_dimensions_size() > 2) { + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + if (num_spatial_dims > 2) { return false; } - bool input_spatial_dims_ascending = std::is_sorted( - dnums.spatial_dimensions().begin(), dnums.spatial_dimensions().end()); - bool kernel_spatial_dims_ascending = - std::is_sorted(dnums.kernel_spatial_dimensions().begin(), - dnums.kernel_spatial_dimensions().end()); + for (int64 i = 0; i < num_spatial_dims; ++i) { + if (dnums.input_spatial_dimensions(i) != i + 1) { + return false; + } + if (dnums.kernel_spatial_dimensions(i) != i) { + return false; + } + if (dnums.output_spatial_dimensions(i) != i + 1) { + return false; + } + } const Shape& output_shape = convolution.shape(); return dnums.input_batch_dimension() == 0 && @@ -67,116 +72,11 @@ bool PotentiallyImplementedAsEigenConvolution( dnums.output_batch_dimension() == 0 && dnums.output_feature_dimension() == output_shape.dimensions_size() - 1 && - input_spatial_dims_ascending == kernel_spatial_dims_ascending && dnums.kernel_input_feature_dimension() == kernel_shape.dimensions_size() - 2 && dnums.kernel_output_feature_dimension() == kernel_shape.dimensions_size() - 1; } -namespace { - -// Return whether the given shape is a matrix with no padding. -bool IsRank2WithNoPadding(const Shape& shape) { - return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); -} - -// In a gemm operation where output = lhs * rhs, check whether the given shapes -// are valid for the operation. 
-bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape) { - // The inputs and the output must - // 1) be matrices with no padding, and - // 2) have an allowed element type. - return output_shape.element_type() == F32 && - IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape); -} -} // namespace - -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { - // For certain types of Dot, we can call Eigen - if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { - return false; - } - - if (ProfitableToImplementDotInLlvmIr(hlo) == DotInLlvmIrProfitable::kYes) { - return false; - } - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. 
- CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); - return true; - } - } - - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - auto* dot = hlo.fused_expression_root(); - const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { - return false; - } - return true; - } - - return false; -} - -DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr( - const HloInstruction& dot) { - if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) { - const Shape& result_shape = dot.shape(); - // kReductionDimensionThresholdBytes was chosen to be 1/4 of a typical L1 - // cache line size, so that we can have the reduction dimension of both the - // LHS and RHS matrices and still have some space "left over". This needs - // to be tuned further. - const int64 kReductionDimensionThresholdBytes = 8 * 1024; - const bool single_threaded_eigen = - !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen(); - - // This is the point at which it is better to call into Eigen and shard the - // dot across multiple worker threads. This is a rough estimate by running - // a matmult benchmark on my local machine, and it can be tuned further. 
- const int64 kMaxSingleThreadedFlops = 16 * 1024; - - const int64 M = result_shape.dimensions(0); - const int64 N = result_shape.dimensions(1); - const int64 K = dot.operand(1)->shape().dimensions(0); - const int64 primitive_type_size = - ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type()); - if (M == 1 && - K * primitive_type_size <= kReductionDimensionThresholdBytes && - (single_threaded_eigen || M * K * N <= kMaxSingleThreadedFlops)) { - // Heuristics: - // - // - Look for a configuration where we will likely be able to keep LHS in - // L1 and do a cache-optimal traversal of RHS. - // - // - Bail out on matrices that are large enough that Eigen can profitably - // shard the computation across multiple cores. This only applies when - // multi-threading is enabled. - return LayoutUtil::IsMonotonicWithDim0Major( - dot.operand(1)->shape().layout()) - ? DotInLlvmIrProfitable::kWithColumnMajorRhs - : DotInLlvmIrProfitable::kYes; - } - } - return DotInLlvmIrProfitable::kNo; -} - } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h index 66656ed99765806ec4463f3781644853886cf303..34b2003916933f5ec0a15d9e219063c0a912fa40 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ +#include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -24,20 +25,17 @@ namespace cpu { bool PotentiallyImplementedAsEigenConvolution( const HloInstruction& convolution); -bool PotentiallyImplementedAsEigenDot(const HloInstruction& dot); - -enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs }; - -// Returns a value to indicate if (and under what conditions) will lowering -// |dot| as a pure LLVM IR dot operation be profitable over calling into Eigen. -// Possible return values are: +// Dynamic loop bounds are specified as an array of dimension index +// [start, limit) pairs of ir values (one for each partitioned outer dimension). +// +// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major +// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the +// following ir values for the two most-major dimensions: +// [dim0_index_start_ir_value, dim0_index_limit_ir_value] +// [dim1_index_start_ir_value, dim1_index_limit_ir_value] // -// * DotInLlvmIrProfitable::kYes - always profitable. -// * DotInLlvmIrProfitable::kNo - never profitable. -// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make -// the Rhs layout column major. -DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr( - const HloInstruction& dot); +// See IrFunction and ParallelLoopEmitter for details. +using DynamicLoopBounds = std::vector>; } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index a20ce6826ca0a86f8c0d441c1e89f091cfb434f1..a15baf7a4b1ca63841a696f81c581a90028fecc2 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,16 +24,17 @@ limitations under the License. 
#include #include +#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/Target/TargetRegisterInfo.h" -#include "llvm/Target/TargetSubtargetInfo.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -42,6 +43,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" @@ -76,14 +79,16 @@ namespace cpu { IrEmitter::IrEmitter( const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, - const std::unordered_map* hlo_to_profile_idx, + std::unordered_map hlo_to_profile_idx, + tensorflow::gtl::optional entry_computation_profile_idx, llvm::TargetMachine* target_machine, ExternalConstantPool* external_constant_pool) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), ir_builder_(llvm_module->getContext()), - hlo_to_profile_idx_(hlo_to_profile_idx), + hlo_to_profile_idx_(std::move(hlo_to_profile_idx)), + entry_computation_profile_idx_(std::move(entry_computation_profile_idx)), 
alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), parallel_cpu_backend_( @@ -122,133 +127,27 @@ StatusOr IrEmitter::EmitComputation( } else { TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); } - InsertOrDie(&emitted_functions_, computation, compute_function_); - - return compute_function_; -} - -static llvm::Argument* GetArg(llvm::Function* f, int idx) { - llvm::Function::arg_iterator arg_iter = f->arg_begin(); - std::advance(arg_iter, idx); - return &*arg_iter; + llvm::Function* ir_function = compute_function_->function(); + InsertOrDie(&emitted_functions_, computation, ir_function); + // Delete 'compute_function', finalizing 'ir_function' and restoring caller + // IR insert point. + compute_function_.reset(); + return ir_function; } void IrEmitter::InitializeIrFunction(const string& function_name) { - // The function signature is: - // void function(i8* retval, i8* run_options, i8** params, i8** temps, - // i64* dynamic_loop_bounds, i64* prof_counters) - // - // retval: points to the returned value. - // params: address of an array with pointers to parameters. - // temps: address of an array with pointers to temporary buffers. - // - // Therefore, the generated function's signature (FunctionType) is statically - // determined - parameter unpacking is done in code generated into the - // function, rather than by a prologue dictated by the platform ABI. - // - // /--------------\ - // retval ----------> | return value | - // \--------------/ - // - // /-------------------------------\ - // run_options -----> | xla::ExecutableRunOptions | - // \-------------------------------/ - // - // /---------------------------------------------\ - // params --------> | param 0 | param 1 | ..... 
| param N-1 | - // | addr | addr | | addr | - // \---------------------------------------------/ - // | | | - // | | | - // V V V - // /---------\ /---------\ /-----------\ - // | param 0 | | param 1 | | param N-1 | - // \---------/ \---------/ \-----------/ - // - // /---------------------------------------------\ - // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | - // | addr | addr | | addr | - // \---------------------------------------------/ - // | | | - // | | | - // V V V - // /---------\ /---------\ /-----------\ - // | temp 0 | | temp 1 | | temp N-1 | - // \---------/ \---------/ \-----------/ - // - // /--------------------------------------------\ - // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| - // (elided for aot) \--------------------------------------------/ - // - // /---------------------------------------------\ - // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | - // (elided for aot) \---------------------------------------------/ - - // Even though the type of params and temps is void** in the host's view, in - // LLVM IR this is represented by i8*, similarly to void*. It's up to the code - // to use GEPs to unravel the indirection layers. - llvm::FunctionType* compute_function_type = llvm::FunctionType::get( - /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/GetComputeFunctionParams(), - /*isVarArg=*/false); - // Functions with local linkage get an inlining bonus. Because we know // a-priori that embedded functions (non-entry functions) will not have its // name resolved, give it local linkage. llvm::Function::LinkageTypes linkage = is_top_level_computation_ ? 
llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; - compute_function_ = - llvm::Function::Create(/*Ty=*/compute_function_type, - /*Linkage=*/linkage, - /*Name=*/AsStringRef(function_name), - /*Module=*/module_); - compute_function_->setCallingConv(llvm::CallingConv::C); - - // Set meaningful names for the function's arguments: useful for debugging. - llvm::Function::arg_iterator arg_iter = compute_function_->arg_begin(); - arg_iter->setName("retval"); - (++arg_iter)->setName("run_options"); - (++arg_iter)->setName("params"); - (++arg_iter)->setName("temps"); - if (num_dynamic_loop_bounds_ > 0) { - (++arg_iter)->setName("dynamic_loop_bounds"); - } - if (hlo_to_profile_idx_) { - (++arg_iter)->setName("prof_counters"); - } - - // We know a-priori that the function arguments are guaranteed to point to - // disjoint objects. - llvm::Argument* retval = GetResultArgument(); - for (llvm::Argument& argument : compute_function_->args()) { - // However, the return buffer aliases the temporaries and thus cannot be - // marked noalias. - if (&argument == retval) { - continue; - } - compute_function_->addAttribute(argument.getArgNo() + 1, - llvm::Attribute::NoAlias); - } - - // Add the optize attribute to the function if optimizing for size. This - // controls internal behavior of some optimization passes (e.g. loop - // unrolling). 
- if (options::OptimizeForSizeRequested(hlo_module_config_)) { - compute_function_->addFnAttr(llvm::Attribute::OptimizeForSize); - } - - if (hlo_module_config_.debug_options().xla_enable_fast_math()) { - compute_function_->addFnAttr("unsafe-fp-math", "true"); - compute_function_->addFnAttr("no-infs-fp-math", "true"); - compute_function_->addFnAttr("no-nans-fp-math", "true"); - compute_function_->addFnAttr("no-signed-zeros-fp-math", "true"); - } - - ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( - /*Context=*/module_->getContext(), - /*Name=*/"entry", - /*Parent=*/compute_function_)); + // Create and initialize new IrFunction. + compute_function_.reset( + new IrFunction(function_name, linkage, + options::OptimizeForSizeRequested(hlo_module_config_), + hlo_module_config_.debug_options().xla_enable_fast_math(), + module_, &ir_builder_, num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -344,11 +243,12 @@ int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { // Calculate the alignment of a buffer allocated for a given primitive type. int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { - int64 buffer_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - DCHECK_GE(buffer_size, 0); - DCHECK_LE(buffer_size, SIZE_MAX); - - return MinimumAlignmentForBufferSize(buffer_size); + int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); + DCHECK_GE(byte_size, 0); + // Largest scalar is a complex64 so we don't need to worry about the + // int64->int truncation here. + DCHECK_LE(byte_size, 8); + return byte_size; } int64 IrEmitter::ByteSizeOf(const Shape& shape) const { @@ -357,6 +257,10 @@ int64 IrEmitter::ByteSizeOf(const Shape& shape) const { // Calculate the alignment of a buffer allocated for a given shape. 
int IrEmitter::MinimumAlignmentForShape(const Shape& shape) { + if (ShapeUtil::IsScalar(shape)) { + return MinimumAlignmentForPrimitiveType(shape.element_type()); + } + int64 buffer_size = ByteSizeOf(shape); DCHECK_GE(buffer_size, 0); DCHECK_LE(buffer_size, SIZE_MAX); @@ -612,7 +516,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{operand}, - /*supported_types=*/{F32})); + /*supported_types=*/{F32, BF16})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(window)) { @@ -795,7 +699,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. llvm_ir::IrArray::Index operand_index(source_index.size()); - llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); + llvm::Value* in_bounds_condition = ir_builder_.getTrue(); for (int64 i = 0; i < rank; ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( source_index[i], ir_builder_.getInt64(window.dimensions(i).stride())); @@ -822,14 +726,16 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. 
SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_); - const auto save_operand_index = [&]( - const llvm_ir::IrArray::Index& operand_index) { - for (int64 i = 0; i < rank; ++i) { - llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( - selected_index_address, {ir_builder_.getInt32(i)}); - ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); - } - }; + const auto save_operand_index = + [&](const llvm_ir::IrArray::Index& operand_index) { + for (int64 i = 0; i < rank; ++i) { + llvm::Value* selected_index_address_slot = + ir_builder_.CreateInBoundsGEP(selected_index_address, + {ir_builder_.getInt32(i)}); + ir_builder_.CreateStore(operand_index[i], + selected_index_address_slot); + } + }; llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &ir_builder_); @@ -896,6 +802,24 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, /*supported_types=*/{F32, F64, C64})); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 || + dnums.rhs_batch_dimensions_size() > 0) { + return Unimplemented("Dot with batch dimensions not implemented."); + } + + if (dnums.lhs_contracting_dimensions_size() != 1) { + // This is disallowed by ShapeInference today. 
+ return Unimplemented( + "Dot with multiple contracting dimensions not implemented."); + } + + if (dnums.lhs_contracting_dimensions(0) != + std::min(lhs->shape().dimensions_size() - 1, 1) || + dnums.rhs_contracting_dimensions(0) != 0) { + return Unimplemented( + "Dot with non-standard contracting dimensions not implemented."); + } llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); @@ -914,8 +838,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. return DotOpEmitter::EmitDotOperation( *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, - hlo_module_config_); + lhs_array, rhs_array, /*addend_array=*/nullptr, + GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_); } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { @@ -952,11 +876,12 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // Input tensor. const Shape& input_shape = convolution->operand(0)->shape(); int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension()); - int64 input_rows = input_shape.dimensions(dnums.spatial_dimensions(0)); + int64 input_rows = + input_shape.dimensions(dnums.input_spatial_dimensions(0)); int64 input_cols = one_dim_convolution ? 1 - : input_shape.dimensions(dnums.spatial_dimensions(1)); + : input_shape.dimensions(dnums.input_spatial_dimensions(1)); int64 input_channels = input_shape.dimensions(dnums.input_feature_dimension()); @@ -976,11 +901,11 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // Output tensor. const Shape& convolution_shape = convolution->shape(); int64 output_rows = - convolution_shape.dimensions(dnums.spatial_dimensions(0)); - int64 output_cols = - one_dim_convolution - ? 
1 - : convolution_shape.dimensions(dnums.spatial_dimensions(1)); + convolution_shape.dimensions(dnums.output_spatial_dimensions(0)); + int64 output_cols = one_dim_convolution + ? 1 + : convolution_shape.dimensions( + dnums.output_spatial_dimensions(1)); // Extract the window stride for the convolution. const Window& window = convolution->window(); @@ -1068,10 +993,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { return EmitTargetElementLoop( convolution, [this, convolution, lhs, rhs, window, dnums](const llvm_ir::IrArray::Index& index) { - int num_spatial_dims = dnums.spatial_dimensions_size(); + int num_spatial_dims = dnums.output_spatial_dimensions_size(); std::vector output_spatial(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { - output_spatial[i] = index[dnums.spatial_dimensions(i)]; + output_spatial[i] = index[dnums.output_spatial_dimensions(i)]; } llvm::Value* output_feature = index[dnums.output_feature_dimension()]; llvm::Value* batch = index[dnums.output_batch_dimension()]; @@ -1091,8 +1016,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { for (int i = 0; i < num_spatial_dims; ++i) { kernel_spatial[i] = loops - .AddLoop(0, rhs->shape().dimensions( - dnums.kernel_spatial_dimensions(i)), + .AddLoop(0, + rhs->shape().dimensions( + dnums.kernel_spatial_dimensions(i)), tensorflow::strings::StrCat("k", i)) ->GetIndVarValue(); } @@ -1108,17 +1034,18 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // Calculate the spatial index in the input array, taking striding, // dilation and padding into account. An index in the padding will be // out of the bounds of the array. 
- const auto calculate_input_index = [this]( - llvm::Value* output_index, llvm::Value* kernel_index, - const WindowDimension& window_dim) { - llvm::Value* strided_index = ir_builder_.CreateNSWMul( - output_index, ir_builder_.getInt64(window_dim.stride())); - llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul( - kernel_index, ir_builder_.getInt64(window_dim.window_dilation())); - return ir_builder_.CreateNSWSub( - ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index), - ir_builder_.getInt64(window_dim.padding_low())); - }; + const auto calculate_input_index = + [this](llvm::Value* output_index, llvm::Value* kernel_index, + const WindowDimension& window_dim) { + llvm::Value* strided_index = ir_builder_.CreateNSWMul( + output_index, ir_builder_.getInt64(window_dim.stride())); + llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul( + kernel_index, + ir_builder_.getInt64(window_dim.window_dilation())); + return ir_builder_.CreateNSWSub( + ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index), + ir_builder_.getInt64(window_dim.padding_low())); + }; std::vector input_spatial(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { input_spatial[i] = calculate_input_index( @@ -1140,11 +1067,11 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0)); }; - llvm::Value* in_bounds_condition = nullptr; + llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); for (int i = 0; i < num_spatial_dims; ++i) { llvm::ConstantInt* input_bound = ir_builder_.getInt64(window_util::DilatedBound( - lhs->shape().dimensions(dnums.spatial_dimensions(i)), + lhs->shape().dimensions(dnums.input_spatial_dimensions(i)), window.dimensions(i).base_dilation())); llvm::Value* dim_in_bound = ir_builder_.CreateICmpULT(input_spatial[i], input_bound); @@ -1153,9 +1080,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { llvm::Value* dim_ok = 
ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole); in_bounds_condition = - in_bounds_condition - ? ir_builder_.CreateAnd(in_bounds_condition, dim_ok) - : dim_ok; + ir_builder_.CreateAnd(in_bounds_condition, dim_ok); } // Now we need to map the dilated base coordinates back to the actual @@ -1178,7 +1103,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { int num_dims = num_spatial_dims + 2; llvm_ir::IrArray::Index input_index(num_dims); for (int i = 0; i < num_spatial_dims; ++i) { - input_index[dnums.spatial_dimensions(i)] = input_spatial[i]; + input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; } input_index[dnums.input_feature_dimension()] = input_feature; input_index[dnums.input_batch_dimension()] = batch; @@ -1449,7 +1374,7 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { // // Where Param is the actual element type of the underlying buffer (for // example, float for an XLA F32 element type). - llvm::Argument* params = GetArg(compute_function_, 2); + llvm::Value* params = compute_function_->parameters_arg(); llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); llvm::LoadInst* param_address_untyped = @@ -1587,7 +1512,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( // Here we assume that the largest register is a vector register. int max_vector_register_size_in_bytes = target_machine_features_.largest_register_size_in_bytes( - compute_function_); + compute_function_->function()); int vector_register_size_in_elements = max_vector_register_size_in_bytes / @@ -1745,19 +1670,6 @@ void IrEmitter::EmitShardedVectorStore( } } -namespace { -// TODO(sanjoy): This is duplicated in tensorflow/core/lib/core/arena.cc. 
-// Extract out a common implementation to tensorflow/core/lib/math/math_util.h -uint32 GCD(uint32 x, uint32 y) { - while (y != 0) { - uint32 r = x % y; - x = y; - y = r; - } - return x; -} -} // namespace - StatusOr IrEmitter::EmitVectorizedReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function, @@ -1780,9 +1692,9 @@ StatusOr IrEmitter::EmitVectorizedReduce( std::find(dimensions.begin(), dimensions.end(), arg->shape().layout().minor_to_major(0)) != dimensions.end(); - unsigned element_alignment = - GCD(ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), - MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); + unsigned element_alignment = tensorflow::MathUtil::GCD( + ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), + MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); if (is_reduction_over_minor_dimension) { // TODO(sanjoy): Implement vectorized reduction over the minor dimension. @@ -1983,11 +1895,16 @@ Status IrEmitter::HandleSend(HloInstruction* send) { return Unimplemented("Send is not implemented on CPU. See b/33942983."); } +Status IrEmitter::HandleSendDone(HloInstruction* send_done) { + // TODO(b/33942983): Support Send/Recv on CPU. + return Unimplemented("Send-done is not implemented on CPU. See b/33942983."); +} + Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); auto operand = slice->operand(0); // The code below emits a sequential loop nest. For the parallel backend, use - // EmitParallelTargetElementLoop() which respects dynamic loop bounds. + // ParallelLoopEmitter which respects dynamic loop bounds. if (ShouldEmitParallelLoopFor(*slice)) { return DefaultAction(slice); } @@ -2148,6 +2065,11 @@ Status IrEmitter::HandleRecv(HloInstruction* recv) { return Unimplemented("Recv is not implemented on CPU. 
See b/33942983."); } +Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) { + // TODO(b/33942983): Support Send/Recv on CPU. + return Unimplemented("Recv-done is not implemented on CPU. See b/33942983."); +} + Status IrEmitter::HandlePad(HloInstruction* pad) { // CPU backend does not properly handle negative padding but this is ok // because negative padding should be removed by the algebraic simplifier. @@ -2250,8 +2172,8 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( *root, root->operand(0)->IsRank2Transpose(), root->operand(1)->IsRank2Transpose(), target_array, lhs_array, - rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, - hlo_module_config_)); + rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), + &ir_builder_, hlo_module_config_)); return Status::OK(); } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { @@ -2272,6 +2194,35 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); + } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kOutput) { + VLOG(3) << "HandleFusion kOutput"; + int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 
0 : 1; + const HloInstruction* dot = root->operand(dot_op_index); + CHECK_EQ(dot->opcode(), HloOpcode::kDot) + << dot->ToString() << " " + << fusion->fused_instructions_computation()->ToString(); + + int64 dot_lhs_param_number = dot->operand(0)->parameter_number(); + int64 dot_rhs_param_number = dot->operand(1)->parameter_number(); + int64 addend_param_number = + root->operand(1 - dot_op_index)->parameter_number(); + + Shape target_shape = fusion->shape(); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); + llvm_ir::IrArray target_array = GetIrArrayFor(fusion); + + llvm_ir::IrArray lhs_array( + GetIrArrayFor(fusion->operand(dot_lhs_param_number))); + llvm_ir::IrArray rhs_array( + GetIrArrayFor(fusion->operand(dot_rhs_param_number))); + llvm_ir::IrArray addend_array( + GetIrArrayFor(fusion->operand(addend_param_number))); + + TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( + *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, + lhs_array, rhs_array, &addend_array, GetExecutableRunOptionsArgument(), + &ir_builder_, hlo_module_config_)); + return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); } @@ -2292,9 +2243,17 @@ Status IrEmitter::HandleCall(HloInstruction* call) { !parallel_cpu_backend_) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. 
- TF_RETURN_IF_ERROR(EmitParallelForkJoin(parameter_addresses, - emitted_value_[call], computation, - call_ir_function)); + std::vector call_args = GetArrayFunctionCallArguments( + parameter_addresses, &ir_builder_, computation->name(), + /*return_value_buffer=*/emitted_value_[call], + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument()); + + HloInstruction* root = computation->root_instruction(); + TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin( + call_args, root->shape(), root->outer_dimension_partitions(), + &ir_builder_, call_ir_function, computation->name())); } else { EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, emitted_value_[call], computation->name()); @@ -2397,7 +2356,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Terminates the current block with a branch to a while header. llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "header")), - compute_function_); + compute_function_->function()); ir_builder_.CreateBr(header_bb); ir_builder_.SetInsertPoint(header_bb); @@ -2414,7 +2373,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Branches to the body or to the while exit depending on the condition. llvm::BasicBlock* body_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "body")), - compute_function_); + compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb); @@ -2429,7 +2388,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { ir_builder_.CreateBr(header_bb); // Adds the exit block to the function and sets the insert point there. 
- compute_function_->getBasicBlockList().push_back(exit_bb); + compute_function_->function()->getBasicBlockList().push_back(exit_bb); ir_builder_.SetInsertPoint(exit_bb); return Status::OK(); @@ -2547,7 +2506,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, const llvm_ir::IrArray& source_array) { unsigned primitive_type_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - unsigned element_alignment = GCD( + unsigned element_alignment = tensorflow::MathUtil::GCD( primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)); llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual( llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); @@ -2594,6 +2553,65 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { return DefaultAction(concatenate); } +Status IrEmitter::HandleConditional(HloInstruction* conditional) { + auto pred = conditional->operand(0); + auto true_arg = conditional->operand(1); + auto false_arg = conditional->operand(2); + TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && + pred->shape().element_type() == PRED) + << "Predicate on a Conditional must be bool; got: " + << ShapeUtil::HumanString(pred->shape()); + + HloComputation* true_computation = conditional->true_computation(); + HloComputation* false_computation = conditional->false_computation(); + TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), + true_computation->root_instruction()->shape())) + << "Shape of conditional should be same as the shape of the true " + << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) + << " and " + << ShapeUtil::HumanString(true_computation->root_instruction()->shape()); + + TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), + false_computation->root_instruction()->shape())) + << "Shape of conditional should be same as the shape of the false " + << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) + << " and " + << 
ShapeUtil::HumanString(false_computation->root_instruction()->shape()); + + llvm::Function* true_function = + FindOrDie(emitted_functions_, true_computation); + llvm::Function* false_function = + FindOrDie(emitted_functions_, false_computation); + + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); + llvm::Value* conditional_result = GetEmittedValueFor(conditional); + + // Generating: + // if (pred) + // cond_result = true_computation(true_operand) + // else + // cond_result = false_computation(false_operand) + llvm::LoadInst* pred_value = ir_builder_.CreateLoad( + GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ir_builder_.CreateICmpNE( + pred_value, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), + "boolean_predicate"); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(pred_cond, "conditional", &ir_builder_); + + SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)}, + conditional_result, IrName(conditional, "_true")); + + SetToFirstInsertPoint(if_data.false_block, &ir_builder_); + EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)}, + conditional_result, IrName(conditional, "_false")); + + SetToFirstInsertPoint(if_data.after_block, &ir_builder_); + return Status::OK(); +} + Status IrEmitter::FinishVisit(HloInstruction* root) { // When this method is called, we should have already emitted an IR value for // the root (return) op. The IR value holds the address of the buffer holding @@ -2605,53 +2623,56 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { llvm::Value* root_value = GetEmittedValueFor(root); VLOG(2) << " value: " << llvm_ir::DumpToString(*root_value); - // For the parallel cpu backend, we record the total for each embedded - // computation callee with its caller kCall HLO. 
- HloInstruction* hlo_to_lookup = nullptr; - if (parallel_cpu_backend_ && is_top_level_computation_) { - auto* computation = root->parent(); - auto* entry_computation = computation->parent()->entry_computation(); - if (computation != entry_computation) { - for (HloInstruction* instruction : entry_computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCall && - instruction->to_apply()->root_instruction() == root) { - hlo_to_lookup = instruction; - break; + llvm::Value* prof_counter = [&]() { + // For the parallel cpu backend, we record the total for each embedded + // computation callee with its caller kCall HLO. + if (parallel_cpu_backend_ && is_top_level_computation_) { + auto* computation = root->parent(); + auto* entry_computation = computation->parent()->entry_computation(); + if (computation != entry_computation) { + for (HloInstruction* instruction : entry_computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCall && + instruction->to_apply()->root_instruction() == root) { + return GetProfileCounterFor(*instruction); + } } } } - } - if (auto* prof_counter = GetProfileCounterFor(hlo_to_lookup)) { + + // Otherwise we record the total computation cycles in a dedicated slot for + // the entry computation. 
+ return GetProfileCounterForEntryComputation(); + }(); + + if (prof_counter) { profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter); } - - ir_builder_.CreateRetVoid(); return Status::OK(); } -llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction* hlo) { - string counter_name; - size_t prof_counter_idx; - if (!hlo_to_profile_idx_) { +llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction& hlo) { + auto it = hlo_to_profile_idx_.find(&hlo); + if (it == hlo_to_profile_idx_.end()) { return nullptr; } - if (hlo) { - auto it = hlo_to_profile_idx_->find(hlo); - if (it == hlo_to_profile_idx_->end()) { - return nullptr; - } - prof_counter_idx = it->second; - counter_name = IrName("prof_counter", hlo->name()); - } else { - prof_counter_idx = hlo_to_profile_idx_->size(); - counter_name = "prof_counter.computation"; - } + size_t prof_counter_idx = it->second; + string counter_name = IrName("prof_counter", hlo.name()); return ir_builder_.CreateGEP(GetProfileCountersArgument(), ir_builder_.getInt64(prof_counter_idx), AsStringRef(counter_name)); } +llvm::Value* IrEmitter::GetProfileCounterForEntryComputation() { + if (entry_computation_profile_idx_) { + return ir_builder_.CreateGEP( + GetProfileCountersArgument(), + ir_builder_.getInt64(*entry_computation_profile_idx_), + "prof_counter.computation"); + } + return nullptr; +} + void IrEmitter::ProfilingState::UpdateProfileCounter( llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter, llvm::Value* cycle_end, llvm::Value* cycle_start) { @@ -2723,14 +2744,14 @@ void IrEmitter::ProfilingState::RecordCompleteComputation( Status IrEmitter::Preprocess(HloInstruction* hlo) { VLOG(3) << "Visiting: " << hlo->ToString(); - if (hlo_to_profile_idx_ && hlo_to_profile_idx_->count(hlo)) { + if (hlo_to_profile_idx_.count(hlo)) { profiling_state_.RecordCycleStart(&ir_builder_, hlo); } return Status::OK(); } Status IrEmitter::Postprocess(HloInstruction* hlo) { - if (auto* prof_counter = 
GetProfileCounterFor(hlo)) { + if (auto* prof_counter = GetProfileCounterFor(*hlo)) { profiling_state_.RecordCycleDelta(&ir_builder_, hlo, prof_counter); } return Status::OK(); @@ -2766,45 +2787,16 @@ llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { return llvm_ir::ShapeToIrType(shape, module_); } -std::vector IrEmitter::GetComputeFunctionParams() { - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); - llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); - std::vector compute_function_params( - {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); - if (num_dynamic_loop_bounds_ > 0) { - compute_function_params.push_back(i64_ptr_type); - } - if (hlo_to_profile_idx_) { - compute_function_params.push_back(i64_ptr_type); - } - return compute_function_params; -} - -llvm::Argument* IrEmitter::GetResultArgument() { - return GetArg(compute_function_, 0); -} - -llvm::Argument* IrEmitter::GetProfileCountersArgument() { - const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4; - return hlo_to_profile_idx_ ? 
GetArg(compute_function_, arg_index) : nullptr; +llvm::Value* IrEmitter::GetProfileCountersArgument() { + return compute_function_->profile_counters_arg(); } llvm::Value* IrEmitter::GetTempBuffersArgument() { - return GetArg(compute_function_, 3); -} - -llvm::Value* IrEmitter::GetDynamicLoopBound(const int64 offset) { - CHECK_GT(num_dynamic_loop_bounds_, 0); - CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); - llvm::Argument* loop_bounds_arg = GetArg(compute_function_, 4); - string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); - return ir_builder_.CreateLoad(ir_builder_.CreateGEP( - loop_bounds_arg, ir_builder_.getInt64(offset), AsStringRef(name))); + return compute_function_->temp_buffers_arg(); } llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { - return GetArg(compute_function_, 1); + return compute_function_->exec_run_options_arg(); } llvm::Value* IrEmitter::EmitTempBufferPointer( @@ -2869,42 +2861,6 @@ llvm::Value* IrEmitter::EmitElementFunctionCall( AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); } -// Emits code to allocate an array of parameter address pointers, and store -// each address from 'parameter_addresses'. -// Returns an array of compute function call arguments (including parameter -// address buffer). 
-std::vector IrEmitter::GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name) { - llvm::Value* parameter_addresses_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder_.getInt8PtrTy(), - ir_builder_.getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), - &ir_builder_); - for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast( - parameter_addresses[i], ir_builder_.getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, - "_address_as_i8ptr"))); - llvm::Value* slot_in_param_adresses = ir_builder_.CreateInBoundsGEP( - parameter_addresses_buffer, {ir_builder_.getInt64(i)}); - ir_builder_.CreateStore(parameter_as_i8ptr, slot_in_param_adresses); - } - - const auto to_int8_ptr = [this](llvm::Value* ptr) { - return ir_builder_.CreatePointerCast(ptr, ir_builder_.getInt8PtrTy()); - }; - std::vector arguments{ - to_int8_ptr(return_value_buffer), - to_int8_ptr(GetExecutableRunOptionsArgument()), - parameter_addresses_buffer, GetTempBuffersArgument()}; - if (auto* profile_counters = GetProfileCountersArgument()) { - arguments.push_back(profile_counters); - } - return arguments; -} - // Emits a core function call based on the following pseudo-code. 
// // char** parameter_addresses_buffer = @@ -2920,8 +2876,12 @@ void IrEmitter::EmitArrayFunctionCallInto( tensorflow::gtl::ArraySlice parameter_addresses, llvm::Value* return_value_buffer, tensorflow::StringPiece name) { ir_builder_.CreateCall( - function, GetArrayFunctionCallArguments(parameter_addresses, - return_value_buffer, name)); + function, GetArrayFunctionCallArguments( + parameter_addresses, &ir_builder_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); } llvm::Value* IrEmitter::EmitArrayFunctionCall( @@ -2941,117 +2901,13 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( return return_value_buffer; } -// Emits a call to a runtime fork/join function which dispatches parallel -// calls to 'parallel_function' (and joins threads before returning). -Status IrEmitter::EmitParallelForkJoin( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* output_address, HloComputation* computation, - llvm::Function* parallel_function) { - HloInstruction* root = computation->root_instruction(); - - // Build ParallelForkJoin function type. - std::vector compute_function_params = GetComputeFunctionParams(); - // Number of parallel compute functions. - compute_function_params.push_back(ir_builder_.getInt32Ty()); - // Array of partitions. There is an array element for each - // partition x partition_dim x 2 (for dimension start and limit). - compute_function_params.push_back( - llvm::Type::getInt64PtrTy(module_->getContext())); - // Number of partitioned most-major dimensions in 'root.shape'. - compute_function_params.push_back(ir_builder_.getInt32Ty()); - // Function pointer for compute function to be dispatched in parallel. 
- compute_function_params.push_back( - llvm::Type::getInt8PtrTy(module_->getContext())); - - llvm::FunctionType* fork_join_type = llvm::FunctionType::get( - /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/compute_function_params, - /*isVarArg=*/false); - - llvm::Function* fork_join_func = - llvm::cast(module_->getOrInsertFunction( - runtime::kParallelForkJoinSymbolName, fork_join_type)); - fork_join_func->setCallingConv(llvm::CallingConv::C); - fork_join_func->setDoesNotThrow(); - - // Add common compute function arguments. - const string name = computation->name(); - std::vector arguments = - GetArrayFunctionCallArguments(parameter_addresses, output_address, name); - - // Create ShapePartitionIterator to generate all partitions of 'root.shape'. - ShapePartitionIterator partition_iterator(root->shape(), - root->outer_dimension_partitions()); - const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); - // Add argument specifying the number of parallel partitions. - arguments.push_back(ir_builder_.getInt32(num_partitions)); - - // The number of partitioned most-major dimensions in 'root.shape'. - const int32 num_partitioned_dims = root->outer_dimension_partitions().size(); - // A dimension partition consists of two elements: [start_index, limit_index). - const int32 dim_partition_size = 2; - // Calculate array partition stride. - const int32 array_partition_stride = - num_partitioned_dims * dim_partition_size; - // Calculate the total number of elements in the partition array. - const int32 partition_array_size = - dim_partition_size * num_partitioned_dims * num_partitions; - - // Store dimension partition values as llvm constants in 'partitions'. - // See comments in runtime_fork_join.cc for array layout description. 
- std::vector partitions(partition_array_size); - for (int32 i = 0; i < num_partitions; ++i) { - std::vector> dim_partitions = - partition_iterator.GetPartition(i); - CHECK_EQ(num_partitioned_dims, dim_partitions.size()); - const int32 partition_index = i * array_partition_stride; - for (int32 j = 0; j < num_partitioned_dims; ++j) { - const std::pair& dim_partition = dim_partitions[j]; - const int32 index = partition_index + j * dim_partition_size; - // Store partition [dim_start, dim_limit) intervals for each dimension. - partitions[index] = ir_builder_.getInt64(dim_partition.first); - partitions[index + 1] = - ir_builder_.getInt64(dim_partition.first + dim_partition.second); - } - } - - // Create global variable out of dimension partitions in 'partitions'. - llvm::ArrayType* partitions_array_type = - llvm::ArrayType::get(ir_builder_.getInt64Ty(), partition_array_size); - llvm::Constant* partitions_array = - llvm::ConstantArray::get(partitions_array_type, partitions); - llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/partitions_array_type, - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/partitions_array, - /*Name=*/ - AsStringRef( - tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); - - // Add argument specifying parallel dimension partitions. - arguments.push_back(ir_builder_.CreateBitCast( - global_partitions_array, - llvm::Type::getInt64PtrTy(module_->getContext()))); - // Add argument specifying the number of partitioned most-major dimensions. - arguments.push_back(ir_builder_.getInt32(num_partitioned_dims)); - // Add argument for parallel compute function pointer. - arguments.push_back( - ir_builder_.CreateBitCast(parallel_function, ir_builder_.getInt8PtrTy())); - // Emit call to parallel fork/join. 
- ir_builder_.CreateCall(fork_join_func, arguments); - - return Status::OK(); -} - Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { llvm::Value* addr; const Shape& target_shape = op->shape(); if (op == op->parent()->root_instruction()) { // For the root node, we write directly to the output buffer of the // function. - llvm::Argument* retval = GetResultArgument(); + llvm::Argument* retval = compute_function_->result_arg(); if (!ShapeUtil::IsNil(target_shape)) { llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); @@ -3112,8 +2968,13 @@ Status IrEmitter::EmitTargetElementLoop( } else { if (ShouldEmitParallelLoopFor(*target_op)) { - TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop( - target_shape, element_generator, IrName(target_op), &target_array)); + // Emit code to read dynamic loop bounds from compute function argument. + std::vector> dynamic_loop_bounds = + compute_function_->GetDynamicLoopBounds(); + // Emit parallel loop with dynamic loop bounds for most-major dimensions. + TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array, + &dynamic_loop_bounds, &ir_builder_) + .EmitLoop(IrName(target_op))); } else { TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) @@ -3123,60 +2984,6 @@ Status IrEmitter::EmitTargetElementLoop( return Status::OK(); } -Status IrEmitter::EmitParallelTargetElementLoop( - const Shape& target_shape, - const llvm_ir::ElementGenerator& element_generator, - tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array) { - CHECK(!ShapeUtil::IsTuple(target_shape)); - CHECK(!ShapeUtil::IsScalar(target_shape)); - - // Emit code to read dynamic loop bounds from function argument 4. 
- std::vector dynamic_loop_bounds(2 * num_dynamic_loop_bounds_); - for (int i = 0; i < 2 * num_dynamic_loop_bounds_; ++i) { - dynamic_loop_bounds[i] = GetDynamicLoopBound(i); - } - - llvm_ir::ForLoopNest loop_nest(loop_name, &ir_builder_); - const int64 num_dims = target_shape.dimensions_size(); - llvm_ir::IrArray::Index array_index(num_dims); - - // Add loops from outer-most to inner-most dimensions. - for (int i = target_shape.layout().minor_to_major_size() - 1; i >= 0; --i) { - const int64 dimension = target_shape.layout().minor_to_major(i); - const int bounds_index = num_dims - 1 - i; - if (bounds_index < num_dynamic_loop_bounds_) { - // Emit dynamic loop bounds for this dimension. Dynamic loop bounds - // are read from ir function dynamic loop bounds argument. - llvm::Value* start_index = dynamic_loop_bounds[bounds_index * 2 + 0]; - llvm::Value* end_index = dynamic_loop_bounds[bounds_index * 2 + 1]; - - std::unique_ptr loop = loop_nest.AddLoop( - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), - start_index, end_index); - array_index[dimension] = loop->GetIndVarValue(); - } else { - // Emit static loop bounds for this dimension. - std::unique_ptr loop = loop_nest.AddLoop( - /*start_index=*/0, - /*end_index=*/target_shape.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); - array_index[dimension] = loop->GetIndVarValue(); - } - } - // Point IR builder at inner loop BB. - SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), &ir_builder_); - - // Emit loop body. - TF_ASSIGN_OR_RETURN(llvm::Value * target_element, - element_generator(array_index)); - target_array->EmitWriteArrayElement(array_index, target_element, - &ir_builder_); - // Point IR builder at outer loop exit BB. 
- SetToFirstInsertPoint(loop_nest.GetOuterLoopExitBasicBlock(), &ir_builder_); - - return Status::OK(); -} - Status IrEmitter::EmitMemcpy(const HloInstruction& source, const HloInstruction& destination) { llvm::Value* source_value = GetEmittedValueFor(&source); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 5d061e11e3c9e07bdcfdc749711e4369ec2bea2a..9bc2d9739757168562b8dc7b482eff203f303766 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -30,6 +31,7 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -105,15 +107,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { // llvm_module: the LLVM module to emit IR into. // hlo_to_profile_idx: the mapping from HLO to its index in the profiling // array. + // entry_computation_profile_idx: the index in the profiling array + // for the entry computation. // external_constant_pool: if non-null, points to an ExternalConstantPool // instance into which the Ir emitter can spill // constants. 
- IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, - llvm::Module* llvm_module, - const std::unordered_map* - hlo_to_profile_idx, - llvm::TargetMachine* target_machine, - ExternalConstantPool* external_constant_pool); + IrEmitter( + const HloModule& hlo_module, const BufferAssignment& assignment, + llvm::Module* llvm_module, + std::unordered_map hlo_to_profile_idx, + tensorflow::gtl::optional entry_computation_profile_idx, + llvm::TargetMachine* target_machine, + ExternalConstantPool* external_constant_pool); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -171,11 +176,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleReduceWindow(HloInstruction* reduce_window) override; Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override; Status HandleSend(HloInstruction* send) override; + Status HandleSendDone(HloInstruction* send_done) override; Status HandleSlice(HloInstruction* slice) override; Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; Status HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) override; Status HandleRecv(HloInstruction* recv) override; + Status HandleRecvDone(HloInstruction* recv_done) override; Status HandlePad(HloInstruction* pad) override; Status HandleTuple(HloInstruction* tuple) override; Status HandleMap(HloInstruction* map) override; @@ -184,6 +191,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; + Status HandleConditional(HloInstruction* conditional) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -195,7 +203,12 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Convenience function to generate a GEP into the profile counter parameter // which would 
correspond to the index for a given HLO. - llvm::Value* GetProfileCounterFor(const HloInstruction* hlo); + llvm::Value* GetProfileCounterFor(const HloInstruction& hlo); + + // Convenience function to generate a GEP into the profile counter parameter + // corresponding to the index for the entry computation. Returns nullptr if + // profiling the entry computation is disabled. + llvm::Value* GetProfileCounterForEntryComputation(); // Gets the IR Value emitted previously for the given hlo. // @@ -223,16 +236,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Convenience function to get the IR type matching the given shape. llvm::Type* IrShapeType(const Shape& shape); - // Returns an array of compute function parameter types. - std::vector GetComputeFunctionParams(); - - // Get the llvm::Value* that represents the "retval" argument of the - // computation function being emitted by this emitter. - llvm::Argument* GetResultArgument(); - // Get the llvm::Value* that represents the "prof_counters" argument of the // computation function being emitted by this emitter. - llvm::Argument* GetProfileCountersArgument(); + llvm::Value* GetProfileCountersArgument(); // Get the xla::ExecutableRunOptions that represents the "run_options" // argument of the computation function being emitted by this emitter. @@ -242,11 +248,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); - // Emit ir to read and return the ir value for the dynamic loop bound at - // 'offset' from the "dynamic_loop_bounds" argument of the computation - // function being emitted by this emitter. - llvm::Value* GetDynamicLoopBound(const int64 offset); - // Emits code that computes the address of the given temporary buffer to the // function. target_shape is the shape of this temporary buffer. // The returned Value's type is a pointer to element_type. 
@@ -300,18 +301,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice parameter_addresses, tensorflow::StringPiece name); - // Returns an array of compute function call arguments. - std::vector GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name); - - // Emits a call to a runtime fork/join function which dispatches parallel - // calls to 'parallel_function' (and joins threads before returning). - Status EmitParallelForkJoin( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* output_address, HloComputation* computation, - llvm::Function* parallel_function); - // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. Status ElementTypesSameAndSupported( @@ -336,15 +325,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, tensorflow::StringPiece desc, const llvm_ir::ElementGenerator& element_generator); - // Emit IR to perform a computation for every element in a partition/slice of - // 'target_shape'. The loop bounds for the outer-dimension partitions are - // passed into the compute function as a runtime argument (accessible from - // GetDynamicLoopBound). - Status EmitParallelTargetElementLoop( - const Shape& target_shape, - const llvm_ir::ElementGenerator& element_generator, - tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array); - // Emits a memcpy from the source instruction's result value to the // destination's. Both source and destination must have an entry in the // emitted_value_ table. @@ -466,12 +446,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { thread_local_buffers_; // The following fields track the IR emission state. According to LLVM memory - // management rules, their memory is owned by the module. 
- llvm::Function* compute_function_; + // management rules, their memory is owned by the module (Note that IrFunction + // creates the encapsulated llvm::Function s.t. it is added to the llvm + // module's function list). + std::unique_ptr compute_function_; llvm::IRBuilder<> ir_builder_; // Maps HLOs to their index into the profile counter array. - const std::unordered_map* hlo_to_profile_idx_; + std::unordered_map hlo_to_profile_idx_; + const tensorflow::gtl::optional entry_computation_profile_idx_; // Maps HLOs to Values emitted for them. std::unordered_map emitted_value_; @@ -479,7 +462,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm_ir::AliasAnalysis alias_analysis_; // The number of root instruction outer dimensions used in parallel loop - // emission (EmitParallelTargetElementLoop). + // emission (ParallelLoopEmitter). int64 num_dynamic_loop_bounds_ = 0; // Returns whether the given instruction should be emitted as a parallel loop. @@ -499,7 +482,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { use_rdtscp_(false), prof_counters_(nullptr) {} ProfilingState(bool is_top_level_computation, bool use_rdtscp, - llvm::Argument* prof_counters) + llvm::Value* prof_counters) : is_top_level_computation_(is_top_level_computation), use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {} @@ -532,7 +515,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { bool use_rdtscp_; // The argument which corresponds to the profile counter buffer. - llvm::Argument* prof_counters_; + llvm::Value* prof_counters_; // The first read cycle counter in the program. llvm::Value* first_read_cycle_start_ = nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca8c290dd1c4959e42026c3917d37f8fc95a1011 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -0,0 +1,333 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" + +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +namespace { +using llvm_ir::AsStringRef; +} // namespace + +namespace cpu { + +static std::vector GetComputeFunctionParams( + llvm::Module* llvm_module, const int64 num_dynamic_loop_bounds) { + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext()); + llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); + llvm::Type* i64_ptr_type = + llvm::Type::getInt64PtrTy(llvm_module->getContext()); + std::vector compute_function_params( + {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds > 0) { + compute_function_params.push_back(i64_ptr_type); + } + compute_function_params.push_back(i64_ptr_type); + return compute_function_params; +} + +IrFunction::IrFunction(const string& function_name, + llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math, llvm::Module* llvm_module, + llvm::IRBuilder<>* ir_builder, + int64 num_dynamic_loop_bounds) + : ir_builder_(ir_builder), + 
llvm_module_(llvm_module), + caller_insert_point_guard_(*ir_builder), + num_dynamic_loop_bounds_(num_dynamic_loop_bounds) { + Initialize(function_name, linkage, optimize_for_size_requested, + enable_fast_math); +} + +IrFunction::~IrFunction() { + // Emit function return value. + ir_builder_->CreateRetVoid(); +} + +DynamicLoopBounds IrFunction::GetDynamicLoopBounds() { + DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_); + for (int i = 0; i < num_dynamic_loop_bounds_; ++i) { + dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0); + dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1); + } + return dynamic_loop_bounds; +} + +void IrFunction::Initialize(const string& function_name, + llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math) { + // The function signature is: + // void function(i8* retval, i8* run_options, i8** params, i8** temps, + // i64* dynamic_loop_bounds, i64* prof_counters) + // + // retval: points to the returned value. + // params: address of an array with pointers to parameters. + // temps: address of an array with pointers to temporary buffers. + // + // Therefore, the generated function's signature (FunctionType) is statically + // determined - parameter unpacking is done in code generated into the + // function, rather than by a prologue dictated by the platform ABI. + // + // /--------------\ + // retval ----------> | return value | + // \--------------/ + // + // /-------------------------------\ + // run_options -----> | xla::ExecutableRunOptions | + // \-------------------------------/ + // + // /---------------------------------------------\ + // params --------> | param 0 | param 1 | ..... 
| param N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | param 0 | | param 1 | | param N-1 | + // \---------/ \---------/ \-----------/ + // + // /---------------------------------------------\ + // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | temp 0 | | temp 1 | | temp N-1 | + // \---------/ \---------/ \-----------/ + // + // /--------------------------------------------\ + // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| + // (elided for aot) \--------------------------------------------/ + // + // /---------------------------------------------\ + // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | + // \---------------------------------------------/ + + // Even though the type of params and temps is void** in the host's view, in + // LLVM IR this is represented by i8*, similarly to void*. It's up to the code + // to use GEPs to unravel the indirection layers. + llvm::FunctionType* function_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()), + /*Params=*/ + GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_), + /*isVarArg=*/false); + + // Functions with local linkage get an inlining bonus. Because we know + // a-priori that embedded functions (non-entry functions) will not have its + // name resolved, give it local linkage. + function_ = + llvm_ir::CreateFunction(function_type, linkage, + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size_requested, + function_name, llvm_module_); + + // Set meaningful names for the function's arguments: useful for debugging. 
+ llvm::Function::arg_iterator arg_iter = function_->arg_begin(); + arg_iter->setName("retval"); + result_arg_ = &*arg_iter; + (++arg_iter)->setName("run_options"); + exec_run_options_arg_ = &*arg_iter; + (++arg_iter)->setName("params"); + parameters_arg_ = &*arg_iter; + (++arg_iter)->setName("temps"); + temp_buffers_arg_ = &*arg_iter; + if (num_dynamic_loop_bounds_ > 0) { + (++arg_iter)->setName("dynamic_loop_bounds"); + dynamic_loop_bounds_arg_ = &*arg_iter; + } + (++arg_iter)->setName("prof_counters"); + profile_counters_arg_ = &*arg_iter; + + // We know a-priori that the function arguments are guaranteed to point to + // disjoint objects. + llvm::Argument* retval = result_arg(); + for (llvm::Argument& argument : function_->args()) { + // However, the return buffer aliases the temporaries and thus cannot be + // marked noalias. + if (&argument == retval) { + continue; + } + function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias); + } + + ir_builder_->SetInsertPoint(llvm::BasicBlock::Create( + /*Context=*/llvm_module_->getContext(), + /*Name=*/"entry", + /*Parent=*/function_)); +} + +llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { + CHECK_GT(num_dynamic_loop_bounds_, 0); + CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); + string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + return ir_builder_->CreateLoad( + ir_builder_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), + ir_builder_->getInt64(offset), AsStringRef(name))); +} + +// Emits code to allocate an array of parameter address pointers, and store +// each address from 'parameter_addresses'. +// Returns an array of compute function call arguments (including parameter +// address buffer). 
+std::vector GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, + llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { + llvm::Value* parameter_addresses_buffer = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + ir_builder->getInt8PtrTy(), + ir_builder->getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), + ir_builder); + for (size_t i = 0; i < parameter_addresses.size(); ++i) { + llvm::Value* parameter_as_i8ptr = ir_builder->CreateBitCast( + parameter_addresses[i], ir_builder->getInt8PtrTy(), + AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, + "_address_as_i8ptr"))); + llvm::Value* slot_in_param_adresses = ir_builder->CreateInBoundsGEP( + parameter_addresses_buffer, {ir_builder->getInt64(i)}); + ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_adresses); + } + + const auto to_int8_ptr = [=](llvm::Value* ptr) { + return ir_builder->CreatePointerCast(ptr, ir_builder->getInt8PtrTy()); + }; + std::vector arguments{ + to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), + parameter_addresses_buffer, temp_buffers_arg}; + if (profile_counters_arg != nullptr) { + arguments.push_back(profile_counters_arg); + } + return arguments; +} + +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status EmitCallToParallelForkJoin( + const std::vector& arguments, const Shape& shape, + const std::vector& dimension_partition_counts, + llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, + const string& name) { + llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + + // Build ParallelForkJoin function type. 
+ std::vector compute_function_params = + GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0); + // Number of parallel compute functions. + compute_function_params.push_back(ir_builder->getInt32Ty()); + // Array of partitions. There is an array element for each + // partition x partition_dim x 2 (for dimension start and limit). + compute_function_params.push_back( + llvm::Type::getInt64PtrTy(module->getContext())); + // Number of partitioned most-major dimensions in 'shape'. + compute_function_params.push_back(ir_builder->getInt32Ty()); + // Function pointer for compute function to be dispatched in parallel. + compute_function_params.push_back( + llvm::Type::getInt8PtrTy(module->getContext())); + + llvm::FunctionType* fork_join_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(module->getContext()), + /*Params=*/compute_function_params, + /*isVarArg=*/false); + + llvm::Function* fork_join_func = + llvm::cast(module->getOrInsertFunction( + runtime::kParallelForkJoinSymbolName, fork_join_type)); + fork_join_func->setCallingConv(llvm::CallingConv::C); + fork_join_func->setDoesNotThrow(); + + // Add common compute function arguments. + std::vector fork_join_arguments(arguments); + + // Create ShapePartitionIterator to generate all partitions of 'shape'. + ShapePartitionIterator partition_iterator(shape, dimension_partition_counts); + const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); + // Add argument specifying the number of parallel partitions. + fork_join_arguments.push_back(ir_builder->getInt32(num_partitions)); + + // The number of partitioned most-major dimensions in 'shape'. + const int32 num_partitioned_dims = dimension_partition_counts.size(); + // A dimension partition consists of two elements: [start_index, limit_index). + const int32 dim_partition_size = 2; + // Calculate array partition stride. 
+ const int32 array_partition_stride = + num_partitioned_dims * dim_partition_size; + // Calculate the total number of elements in the partition array. + const int32 partition_array_size = + dim_partition_size * num_partitioned_dims * num_partitions; + + // Store dimension partition values as llvm constants in 'partitions'. + // See comments in runtime_fork_join.cc for array layout description. + std::vector partitions(partition_array_size); + for (int32 i = 0; i < num_partitions; ++i) { + std::vector> dim_partitions = + partition_iterator.GetPartition(i); + CHECK_EQ(num_partitioned_dims, dim_partitions.size()); + const int32 partition_index = i * array_partition_stride; + for (int32 j = 0; j < num_partitioned_dims; ++j) { + const std::pair& dim_partition = dim_partitions[j]; + const int32 index = partition_index + j * dim_partition_size; + // Store partition [dim_start, dim_limit) intervals for each dimension. + partitions[index] = ir_builder->getInt64(dim_partition.first); + partitions[index + 1] = + ir_builder->getInt64(dim_partition.first + dim_partition.second); + } + } + + // Create global variable out of dimension partitions in 'partitions'. + llvm::ArrayType* partitions_array_type = + llvm::ArrayType::get(ir_builder->getInt64Ty(), partition_array_size); + llvm::Constant* partitions_array = + llvm::ConstantArray::get(partitions_array_type, partitions); + llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( + /*M=*/*module, + /*Ty=*/partitions_array_type, + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/partitions_array, + /*Name=*/ + AsStringRef( + tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + + // Add argument specifying parallel dimension partitions. + fork_join_arguments.push_back(ir_builder->CreateBitCast( + global_partitions_array, + llvm::Type::getInt64PtrTy(module->getContext()))); + // Add argument specifying the number of partitioned most-major dimensions. 
+ fork_join_arguments.push_back(ir_builder->getInt32(num_partitioned_dims)); + // Add argument for parallel compute function pointer. + fork_join_arguments.push_back( + ir_builder->CreateBitCast(parallel_function, ir_builder->getInt8PtrTy())); + // Emit call to parallel fork/join. + ir_builder->CreateCall(fork_join_func, fork_join_arguments); + + return Status::OK(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h new file mode 100644 index 0000000000000000000000000000000000000000..1fd2da4dce23982ed030f3aa8ec604182d0ebab8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -0,0 +1,134 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ + +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { +namespace cpu { + +// IrFunction creates and encapsulates an llvm::Function, exposing methods to +// emitters for function and function argument access. +// The llvm::Function is created with the standard function signature +// used in the XLA CPU backend (see ir_function.cc for argument details). +// In addition IrFunction saves the caller's IR insert point during construction, +// and restores it after destruction. +// +// Example usage: +// +// // Create and initialize new IrFunction. +// std::unique_ptr compute_function(new IrFunction(...)); +// // Emit IR for function body using IrFunction helper methods. +// ... +// // Store reference to llvm::Function for future invocation. +// ir_functions.push_back(compute_function->function()); +// // Delete IrFunction (finalizes IR function and restores caller insertion +// // point). +// compute_function.reset(); +// + +class IrFunction { + public: + IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math, llvm::Module* llvm_module, + llvm::IRBuilder<>* ir_builder, int64 num_dynamic_loop_bounds); + ~IrFunction(); + + // Emit ir to read and return the set of ir values representing the dynamic + // loop bounds argument of this function. 
+ // Each element in returned vector is a pair of ir values representing + // the loop bounds for a specific dimension, where the first element of the + // pair is the dimension start index, and the second element of the pair + // is the dimension limit. + // EX: [dimension_i_index_start_ir_value, dimension_i_index_limit_ir_value] + // + DynamicLoopBounds GetDynamicLoopBounds(); + + // Returns the encapsulated llvm::Function. + llvm::Function* function() { return function_; } + + // Get the llvm::Argument* that represents this function's "retval" argument. + llvm::Argument* result_arg() { return result_arg_; } + + // Get the xla::ExecutableRunOptions that represents this function's + // "run_options" argument. + llvm::Value* exec_run_options_arg() { return exec_run_options_arg_; } + + // Get the llvm::Value* that represents this function's parameters argument. + llvm::Value* parameters_arg() { return parameters_arg_; } + + // Get the llvm::Value* that represents this function's "temps" argument. + llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; } + + // Get the llvm::Value* that represents this function's "prof_counters" + // argument. + llvm::Value* profile_counters_arg() { return profile_counters_arg_; } + + private: + // Initialize an llvm::Function with standard signature based on arguments. + void Initialize(const string& function_name, + llvm::Function::LinkageTypes linkage, + bool optimize_for_size_requested, bool enable_fast_math); + + // Emit ir to read and return the ir value for the dynamic loop bound at + // 'offset' from the "dynamic_loop_bounds" argument of this function. + llvm::Value* GetDynamicLoopBound(int64 offset); + + llvm::IRBuilder<>* ir_builder_; + llvm::Module* llvm_module_; + llvm::IRBuilder<>::InsertPointGuard caller_insert_point_guard_; + + int64 num_dynamic_loop_bounds_ = 0; + // Encapsulated llvm::Function. + llvm::Function* function_; + // Function argument IR values. 
+ llvm::Argument* result_arg_; + llvm::Value* exec_run_options_arg_; + llvm::Value* parameters_arg_; + llvm::Value* temp_buffers_arg_; + llvm::Value* dynamic_loop_bounds_arg_ = nullptr; + llvm::Value* profile_counters_arg_; +}; + +// Returns an array of compute function call argument ir values. +std::vector GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, + llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); + +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status EmitCallToParallelForkJoin( + const std::vector& arguments, const Shape& shape, + const std::vector& dimension_partition_counts, + llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, + const string& name); + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc new file mode 100644 index 0000000000000000000000000000000000000000..e624e5cc7ebdbb79a8a3b3c73633ec697a71d172 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace xla { +namespace cpu { +namespace orc_jit_memory_mapper { + +static tensorflow::mutex mapper_instance_mutex(tensorflow::LINKER_INITIALIZED); +static llvm::SectionMemoryManager::MemoryMapper* mapper_instance + GUARDED_BY(mapper_instance_mutex) = nullptr; + +llvm::SectionMemoryManager::MemoryMapper* GetInstance() { + tensorflow::mutex_lock lock(mapper_instance_mutex); + return mapper_instance; +} + +Registrar::Registrar( + std::unique_ptr mapper) { + tensorflow::mutex_lock lock(mapper_instance_mutex); + mapper_instance = mapper.release(); +} +} // namespace orc_jit_memory_mapper +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h new file mode 100644 index 0000000000000000000000000000000000000000..2d29550fd5bd659770cc6300e56b57bf1763e671 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ + +#include + +#include "llvm/ExecutionEngine/SectionMemoryManager.h" + +namespace xla { +namespace cpu { + +namespace orc_jit_memory_mapper { +// Returns the registered memory mapper if there is one. Returns nullptr if no +// memory mapper is registered. +llvm::SectionMemoryManager::MemoryMapper* GetInstance(); + +class Registrar { + public: + // Registers the `mapper` as a memory mapper. This is a no-op if `mapper` is + // null. Precondition: no other memory mapper has been registered yet. + explicit Registrar( + std::unique_ptr mapper); +}; +} // namespace orc_jit_memory_mapper + +#define XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER(mapper_instance, ctr) \ + static ::xla::cpu::orc_jit_memory_mapper::Registrar \ + XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER_NAME(ctr)(mapper_instance) + +// __COUNTER__ must go through another macro to be properly expanded +#define XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER_NAME(ctr) \ + __orc_jit_memory_mapper_registrar_##ctr + +// Registers the std::unique_ptr +// returned by the `factory` expression. `factory` is allowed to evaluate to +// a null unique_ptr in which case this macro does nothing. 
+#define XLA_REGISTER_ORC_JIT_MEMORY_MAPPER(factory) \ + XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER(factory, __COUNTER__) +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index aff61296ced47a911ded207f611747564b5ac7eb..0077e344e2bd34aa598ee076220fee678f31b4ad 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -59,19 +59,20 @@ ParallelCpuExecutable::ParallelCpuExecutable( std::unique_ptr assignment, std::unique_ptr hlo_module, std::unique_ptr> function_names, - std::unordered_map hlo_to_profile_idx, std::unordered_map> - aligned_constants) - : Executable(std::move(hlo_module)), + aligned_constants, + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map) + : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + std::move(hlo_profile_index_map)), jit_(std::move(jit)), assignment_(std::move(assignment)), function_names_(std::move(function_names)), - hlo_to_profile_idx_(std::move(hlo_to_profile_idx)), aligned_constants_(std::move(aligned_constants)) {} // Type of the computation function we expect in the JIT. using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, - int64*, uint64*); + int64*, int64*); // Given a pointer to an output buffer (following the CPU JIT calling // conventions), mark addresses that are "live". 
The initial pointer itself is @@ -106,7 +107,7 @@ class Executor { const ServiceExecutableRunOptions* run_options, std::list* pending, HloInstructionMap* results, void** temps_array, - uint64* profile_counters_array, const BufferAssignment* assignment) + int64* profile_counters_array, const BufferAssignment* assignment) : functions_(functions), run_options_(run_options), pending_(pending), @@ -147,7 +148,7 @@ class Executor { std::list* pending_; HloInstructionMap* results_; void** temps_array_; - uint64* profile_counters_array_; + int64* profile_counters_array_; tensorflow::thread::ThreadPool* thread_pool_; const BufferAssignment* assignment_; @@ -389,9 +390,11 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { // Allocate profiling counters for each hlo instruction that we would like to - // profile. Allocate an additional profile counter for the entire - // computation. - std::vector profile_counters(hlo_to_profile_idx_.size() + 1); + // profile. + std::vector* profile_counters = nullptr; + if (hlo_execution_profile) { + profile_counters = hlo_execution_profile->mutable_profile_counters(); + } std::vector buffer_pointers; buffer_pointers.reserve(buffers.size()); @@ -441,9 +444,9 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( // For example, if we expect a library conv/matmul call to run at max // concurrency, we should not dispatch runnable instructions until the // library call is finished (to avoid expensive cache invalidation). - Executor executor(functions, run_options, &pending, &results, - buffer_pointers.data(), profile_counters.data(), - assignment_.get()); + Executor executor( + functions, run_options, &pending, &results, buffer_pointers.data(), + profile_counters ? 
profile_counters->data() : nullptr, assignment_.get()); TF_RETURN_IF_ERROR(executor.Run()); @@ -453,18 +456,6 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( tensorflow::mutex_lock lock(mutex_); double nanoseconds = (end_micros - start_micros) * 1000.0; execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); - // The last profile counter is used for the computation as a whole. - execution_profile_.set_compute_cycle_count(profile_counters.back()); - } - if (hlo_execution_profile != nullptr) { - hlo_execution_profile->set_total_cycles_executed(entry_computation, - profile_counters.back()); - - for (auto hlo_prof_idx : hlo_to_profile_idx_) { - const HloInstruction* hlo = hlo_prof_idx.first; - uint64 cycles_taken = profile_counters[hlo_prof_idx.second]; - hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken); - } } return Status::OK(); @@ -618,10 +609,5 @@ const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const { module().entry_computation()->root_instruction()); } -std::unique_ptr ParallelCpuExecutable::CreateCostAnalysis() - const { - return MakeUnique(ShapeSizeBytes); -} - } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index db16aaf48b0ef2aaa727c1bd0106bc51d1a65095..d65e3f42f3cb34eff005f34b51b81fd5c42974a3 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -52,10 +52,11 @@ class ParallelCpuExecutable : public Executable { std::unique_ptr assignment, std::unique_ptr hlo_module, std::unique_ptr> function_names, - std::unordered_map hlo_to_profile_idx, std::unordered_map> - aligned_constants); + aligned_constants, + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map); ~ParallelCpuExecutable() override {} StatusOr ExecuteOnStream( @@ -95,8 +96,6 @@ class ParallelCpuExecutable : public 
Executable { "Equality test on CPU parallel executable is not implemented."); } - std::unique_ptr CreateCostAnalysis() const override; - private: // Allocate buffers required for execution and assign them to the elements of // "buffers". "buffers" should be sized to the number of buffers in buffer @@ -143,9 +142,6 @@ class ParallelCpuExecutable : public Executable { // Map containing the JITted function names for each HLO instruction. const std::unique_ptr> function_names_; - // Maps HLOs to their index into the profile counter array. - const std::unordered_map hlo_to_profile_idx_; - // Map from HLO Constant instructions to a pointer to their literal data. // The data stored in the protocol buffer might be insufficiently aligned, // we create a sufficiently aligned copy and store it in this map. diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..a3c3c1e5efc91af6b924a3712689f3d7ccf5d6f6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { +namespace cpu { + +ParallelLoopEmitter::ParallelLoopEmitter( + const llvm_ir::ElementGenerator& target_element_generator, + const llvm_ir::IrArray& target_array, + const DynamicLoopBounds* dynamic_loop_bounds, llvm::IRBuilder<>* ir_builder) + : LoopEmitter(target_element_generator, target_array, ir_builder), + dynamic_loop_bounds_(dynamic_loop_bounds) {} + +llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( + tensorflow::StringPiece loop_name) { + CHECK(!ShapeUtil::IsTuple(shape_)); + CHECK(!ShapeUtil::IsScalar(shape_)); + + llvm_ir::ForLoopNest loop_nest(loop_name, ir_builder_); + const int64 num_dims = shape_.dimensions_size(); + llvm_ir::IrArray::Index array_index(num_dims); + + // Add loops from outer-most to inner-most dimensions. + for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) { + const int64 dimension = shape_.layout().minor_to_major(i); + const int bounds_index = num_dims - 1 - i; + if (bounds_index < dynamic_loop_bounds_->size()) { + // Emit dynamic loop bounds for this dimension. Dynamic loop bounds + // are read from ir function dynamic loop bounds argument. + llvm::Value* start_index = (*dynamic_loop_bounds_)[bounds_index].first; + llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second; + + std::unique_ptr loop = loop_nest.AddLoop( + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), + start_index, end_index); + array_index[dimension] = loop->GetIndVarValue(); + } else { + // Emit static loop bounds for this dimension. 
+ std::unique_ptr loop = loop_nest.AddLoop( + /*start_index=*/0, + /*end_index=*/shape_.dimensions(dimension), + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + array_index[dimension] = loop->GetIndVarValue(); + } + } + // Point IR builder at inner loop BB. + llvm_ir::SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), + ir_builder_); + + // Set exit_bb_ to the exit block of the loop nest. + exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); + CHECK(exit_bb_ != nullptr); + + return array_index; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..9335d2818e99eb3588537d80dabddda08c1c020e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" + +namespace xla { +namespace cpu { + +// ParallelLoopEmitter emits a loop nest for the target array shape. +// The outer loop bounds of the loop nest are passed as ir values at runtime +// (specified in 'dynamic_loop_bounds'), and the inner loop bounds are static. +// Dynamic loop bounds are specified as an array of dimension index +// [start, limit) pairs of ir values (one for each partitioned outer dimension). +// +// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major +// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the +// following ir values for the two most-major dimensions: +// [dim0_index_start_ir_value, dim0_index_limit_ir_value] +// [dim1_index_start_ir_value, dim1_index_limit_ir_value] +// +// Code emitted by ParallelLoopEmitter will be called in a multi-threaded +// context where each thread will be assigned a different set of outer dimension +// partitions, and where all threads will collectively iterate over the +// entire target array shape. +// +// Outer dimension partitions can be generated using the ShapePartitionAssigner +// and ShapePartitionIterator utility classes from shape_partition.cc. +// +class ParallelLoopEmitter : public llvm_ir::LoopEmitter { + public: + // Constructs a ParallelLoopEmitter which uses 'target_element_generator' to + // generate elements, 'dynamic_loop_bounds' to set the loop bounds of the + // most-major dimensions, and 'target_array.' 
shape to set the static loop + // bounds for the most-minor dimensions. + ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, + const llvm_ir::IrArray& target_array, + const DynamicLoopBounds* dynamic_loop_bounds, + llvm::IRBuilder<>* ir_builder); + + ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; + ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; + ~ParallelLoopEmitter() override = default; + + llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock( + tensorflow::StringPiece loop_name) override; + + private: + const DynamicLoopBounds* dynamic_loop_bounds_; +}; + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 4a62a80fac0c89d8e1cf66f16f07fca0ffbaa2d3..4b44ac8941e222d5954121bbb9654062e41f55d6 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index fdf02e5b422f75e256feec77470bb0d079e8ef1f..cda2783307925b77ac6d8cfe679c5b325db2befc 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" @@ -125,8 +126,10 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, /*MAttrs=*/DetectMachineAttributes()))), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), - object_layer_( - [] { return std::make_shared(); }), + object_layer_([] { + return std::make_shared( + orc_jit_memory_mapper::GetInstance()); + }), compile_layer_( object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, @@ -210,71 +213,75 @@ bool RegisterKnownJITSymbols() { #undef REGISTER_CPU_RUNTIME_SYMBOL -#define REGISTER_LIBM_SYMBOL(name) \ - do { \ - /* Register both the F32 and F64 variants of the libm symbol. */ \ - registry->Register(#name "f", reinterpret_cast(name##f)); \ - registry->Register(#name, reinterpret_cast(name)); \ +// Register both the f32 (float) and f64 (double) versions of a libm symbol. +// Unfortunately the double versions are overloaded on some systems, e.g. +// Mac so we need an explicit cast. This requires passing the function signature +// for that case. 
+#define REGISTER_LIBM_SYMBOL(name, double_sig) \ + do { \ + registry->Register(#name "f", reinterpret_cast(name##f)); \ + registry->Register( \ + #name, reinterpret_cast(static_cast(name))); \ } while (false) - REGISTER_LIBM_SYMBOL(acos); - REGISTER_LIBM_SYMBOL(acosh); - REGISTER_LIBM_SYMBOL(asin); - REGISTER_LIBM_SYMBOL(asinh); - REGISTER_LIBM_SYMBOL(atan); - REGISTER_LIBM_SYMBOL(atan2); - REGISTER_LIBM_SYMBOL(atanh); - REGISTER_LIBM_SYMBOL(cbrt); - REGISTER_LIBM_SYMBOL(ceil); - REGISTER_LIBM_SYMBOL(copysign); - REGISTER_LIBM_SYMBOL(cos); - REGISTER_LIBM_SYMBOL(cosh); - REGISTER_LIBM_SYMBOL(erf); - REGISTER_LIBM_SYMBOL(erfc); - REGISTER_LIBM_SYMBOL(exp); - REGISTER_LIBM_SYMBOL(exp2); - REGISTER_LIBM_SYMBOL(expm1); - REGISTER_LIBM_SYMBOL(fabs); - REGISTER_LIBM_SYMBOL(fdim); - REGISTER_LIBM_SYMBOL(floor); - REGISTER_LIBM_SYMBOL(fma); - REGISTER_LIBM_SYMBOL(fmax); - REGISTER_LIBM_SYMBOL(fmin); - REGISTER_LIBM_SYMBOL(fmod); - REGISTER_LIBM_SYMBOL(frexp); - REGISTER_LIBM_SYMBOL(hypot); - REGISTER_LIBM_SYMBOL(ilogb); - REGISTER_LIBM_SYMBOL(ldexp); - REGISTER_LIBM_SYMBOL(lgamma); - REGISTER_LIBM_SYMBOL(llrint); - REGISTER_LIBM_SYMBOL(llround); - REGISTER_LIBM_SYMBOL(log); - REGISTER_LIBM_SYMBOL(log10); - REGISTER_LIBM_SYMBOL(log1p); - REGISTER_LIBM_SYMBOL(log2); - REGISTER_LIBM_SYMBOL(logb); - REGISTER_LIBM_SYMBOL(lrint); - REGISTER_LIBM_SYMBOL(lround); - REGISTER_LIBM_SYMBOL(modf); - REGISTER_LIBM_SYMBOL(nan); - REGISTER_LIBM_SYMBOL(nearbyint); - REGISTER_LIBM_SYMBOL(nextafter); - REGISTER_LIBM_SYMBOL(nexttoward); - REGISTER_LIBM_SYMBOL(pow); - REGISTER_LIBM_SYMBOL(remainder); - REGISTER_LIBM_SYMBOL(remquo); - REGISTER_LIBM_SYMBOL(rint); - REGISTER_LIBM_SYMBOL(round); - REGISTER_LIBM_SYMBOL(scalbln); - REGISTER_LIBM_SYMBOL(scalbn); - REGISTER_LIBM_SYMBOL(sin); - REGISTER_LIBM_SYMBOL(sincos); - REGISTER_LIBM_SYMBOL(sinh); - REGISTER_LIBM_SYMBOL(sqrt); - REGISTER_LIBM_SYMBOL(tan); - REGISTER_LIBM_SYMBOL(tanh); - REGISTER_LIBM_SYMBOL(tgamma); - 
REGISTER_LIBM_SYMBOL(trunc); + REGISTER_LIBM_SYMBOL(acos, double (*)(double)); + REGISTER_LIBM_SYMBOL(acosh, double (*)(double)); + REGISTER_LIBM_SYMBOL(asin, double (*)(double)); + REGISTER_LIBM_SYMBOL(asinh, double (*)(double)); + REGISTER_LIBM_SYMBOL(atan, double (*)(double)); + REGISTER_LIBM_SYMBOL(atan2, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(atanh, double (*)(double)); + REGISTER_LIBM_SYMBOL(cbrt, double (*)(double)); + REGISTER_LIBM_SYMBOL(ceil, double (*)(double)); + REGISTER_LIBM_SYMBOL(copysign, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(cos, double (*)(double)); + REGISTER_LIBM_SYMBOL(cosh, double (*)(double)); + REGISTER_LIBM_SYMBOL(erf, double (*)(double)); + REGISTER_LIBM_SYMBOL(erfc, double (*)(double)); + REGISTER_LIBM_SYMBOL(exp, double (*)(double)); + REGISTER_LIBM_SYMBOL(exp2, double (*)(double)); + REGISTER_LIBM_SYMBOL(expm1, double (*)(double)); + REGISTER_LIBM_SYMBOL(fabs, double (*)(double)); + REGISTER_LIBM_SYMBOL(fdim, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(floor, double (*)(double)); + REGISTER_LIBM_SYMBOL(fma, double (*)(double, double, double)); + REGISTER_LIBM_SYMBOL(fmax, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(fmin, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(fmod, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(frexp, double (*)(double, int*)); + REGISTER_LIBM_SYMBOL(hypot, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(ilogb, int (*)(double)); + REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, int)); + REGISTER_LIBM_SYMBOL(lgamma, double (*)(double)); + REGISTER_LIBM_SYMBOL(llrint, long long (*)(double)); + REGISTER_LIBM_SYMBOL(llround, long long (*)(double)); + REGISTER_LIBM_SYMBOL(log, double (*)(double)); + REGISTER_LIBM_SYMBOL(log10, double (*)(double)); + REGISTER_LIBM_SYMBOL(log1p, double (*)(double)); + REGISTER_LIBM_SYMBOL(log2, double (*)(double)); + REGISTER_LIBM_SYMBOL(logb, double (*)(double)); + REGISTER_LIBM_SYMBOL(lrint, long (*)(double)); + 
REGISTER_LIBM_SYMBOL(lround, long (*)(double)); + REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*)); + REGISTER_LIBM_SYMBOL(nan, double (*)(const char*)); + REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double)); + REGISTER_LIBM_SYMBOL(nextafter, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(nexttoward, double (*)(double, long double)); + REGISTER_LIBM_SYMBOL(pow, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(remainder, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*)); + REGISTER_LIBM_SYMBOL(rint, double (*)(double)); + REGISTER_LIBM_SYMBOL(round, double (*)(double)); + REGISTER_LIBM_SYMBOL(scalbln, double (*)(double, long)); + REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int)); + REGISTER_LIBM_SYMBOL(sin, double (*)(double)); + REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); + REGISTER_LIBM_SYMBOL(sinh, double (*)(double)); + REGISTER_LIBM_SYMBOL(sqrt, double (*)(double)); + REGISTER_LIBM_SYMBOL(tan, double (*)(double)); + REGISTER_LIBM_SYMBOL(tanh, double (*)(double)); + REGISTER_LIBM_SYMBOL(tgamma, double (*)(double)); + REGISTER_LIBM_SYMBOL(trunc, double (*)(double)); #undef REGISTER_LIBM_SYMBOL diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index de3cd1544087686fa884fc22382aa4dff5256938..91086fd4a5f68211ef56c2417bb0ef4a38de2cff 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -86,6 +86,9 @@ class DfsHloVisitorBase { virtual Status HandleConvert(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleBitcastConvert(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleCopy(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } @@ -208,12 +211,15 @@ class DfsHloVisitorBase { virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0; virtual Status 
HandleSelectAndScatter(HloInstructionPtr hlo) = 0; virtual Status HandleWhile(HloInstructionPtr hlo) = 0; + virtual Status HandleConditional(HloInstructionPtr hlo) = 0; virtual Status HandlePad(HloInstructionPtr hlo) = 0; - virtual Status HandleSend(HloInstructionPtr hlo) = 0; + virtual Status HandleSend(HloInstructionPtr send) = 0; + virtual Status HandleSendDone(HloInstructionPtr send_done) = 0; - virtual Status HandleRecv(HloInstructionPtr hlo) = 0; + virtual Status HandleRecv(HloInstructionPtr recv) = 0; + virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0; virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 7ce88be89dfe0746d9d05ca3d5c788f72ca74cd8..133aa2509405738de8388708b0c61a82023e2738 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -167,12 +167,21 @@ class DfsHloVisitorWithDefaultBase Status HandleWhile(HloInstructionPtr xla_while) override { return DefaultAction(xla_while); } - Status HandleSend(HloInstructionPtr send) override { - return DefaultAction(send); + Status HandleConditional(HloInstructionPtr conditional) override { + return DefaultAction(conditional); } Status HandleRecv(HloInstructionPtr recv) override { return DefaultAction(recv); } + Status HandleRecvDone(HloInstructionPtr recv_done) override { + return DefaultAction(recv_done); + } + Status HandleSend(HloInstructionPtr send) override { + return DefaultAction(send); + } + Status HandleSendDone(HloInstructionPtr send_done) override { + return DefaultAction(send_done); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". 
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc new file mode 100644 index 0000000000000000000000000000000000000000..12faed69677cd99c6ed82c8d13dad3138d9461b7 --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -0,0 +1,185 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dot_decomposer.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +// TODO(b/69062148) Remove this code when all backends support BatchDot +// natively. 
+Status DecomposeBatchDot(HloInstruction* dot) { + auto computation = dot->parent(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + const Shape& dot_shape = dot->shape(); + + // ShapeInference should guarantee that lhs/rhs batch dimensions match. + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); + const int64 num_batch_dims = dnums.lhs_batch_dimensions_size(); + // Calculate total batch size (note that ShapeInference requires that + // the batch dimensions are most-major). + int64 batch_size = 1; + for (int i = 0; i < num_batch_dims; ++i) { + CHECK_EQ(lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)), + rhs_shape.dimensions(dnums.rhs_batch_dimensions(i))); + batch_size *= lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)); + } + + // Set lhs/rhs_transpose. + CHECK_EQ(1, dnums.lhs_contracting_dimensions_size()); + const int64 lhs_contracting_dim_number = dnums.lhs_contracting_dimensions(0); + const bool lhs_transpose = (lhs_contracting_dim_number - num_batch_dims) == 0; + + CHECK_EQ(1, dnums.rhs_contracting_dimensions_size()); + const int64 rhs_contracting_dim_number = dnums.rhs_contracting_dimensions(0); + const bool rhs_transpose = (rhs_contracting_dim_number - num_batch_dims) == 1; + + // Compute R3 and R3 shapes for lhs. + PrimitiveType lhs_type = lhs_shape.element_type(); + const int64 lhs_rows = lhs_shape.dimensions(num_batch_dims + 0); + const int64 lhs_cols = lhs_shape.dimensions(num_batch_dims + 1); + Shape lhs_shape_r3 = + ShapeUtil::MakeShape(lhs_type, {batch_size, lhs_rows, lhs_cols}); + Shape lhs_slice_shape_r3 = + ShapeUtil::MakeShape(lhs_type, {1, lhs_rows, lhs_cols}); + Shape lhs_slice_shape_r2 = + ShapeUtil::MakeShape(lhs_type, {lhs_rows, lhs_cols}); + + // Compute R3 and R3 shapes for rhs. 
+ PrimitiveType rhs_type = rhs_shape.element_type(); + const int64 rhs_rows = rhs_shape.dimensions(num_batch_dims + 0); + const int64 rhs_cols = rhs_shape.dimensions(num_batch_dims + 1); + Shape rhs_shape_r3 = + ShapeUtil::MakeShape(rhs_type, {batch_size, rhs_rows, rhs_cols}); + Shape rhs_slice_shape_r3 = + ShapeUtil::MakeShape(rhs_type, {1, rhs_rows, rhs_cols}); + Shape rhs_slice_shape_r2 = + ShapeUtil::MakeShape(rhs_type, {rhs_rows, rhs_cols}); + + // Compute R3 and R3 shapes for dot output. + PrimitiveType dot_type = dot_shape.element_type(); + const int64 dot_rows = dot_shape.dimensions(num_batch_dims + 0); + const int64 dot_cols = dot_shape.dimensions(num_batch_dims + 1); + Shape dot_shape_r2 = ShapeUtil::MakeShape(dot_type, {dot_rows, dot_cols}); + Shape dot_shape_r3 = ShapeUtil::MakeShape(dot_type, {1, dot_rows, dot_cols}); + Shape concat_shape_r3 = + ShapeUtil::MakeShape(dot_type, {batch_size, dot_rows, dot_cols}); + + // Reshape lhs/rhs into R3. + auto lhs_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(lhs_shape_r3, lhs)); + auto rhs_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(rhs_shape_r3, rhs)); + + // Loop through batch size, slicing out required lhs/rhs to compute each Dot. + std::vector output_slices(batch_size); + for (int64 i = 0; i < batch_size; ++i) { + // Slice R3 shape from 'lhs' and reshape to R2. + auto lhs_slice_r3 = computation->AddInstruction( + HloInstruction::CreateSlice(lhs_slice_shape_r3, lhs_r3, {i, 0, 0}, + {i + 1, lhs_rows, lhs_cols}, {1, 1, 1})); + auto lhs_slice_r2 = computation->AddInstruction( + HloInstruction::CreateReshape(lhs_slice_shape_r2, lhs_slice_r3)); + + // Slice R3 shape from 'rhs' and reshape to R2. 
+ auto rhs_slice_r3 = computation->AddInstruction( + HloInstruction::CreateSlice(rhs_slice_shape_r3, rhs_r3, {i, 0, 0}, + {i + 1, rhs_rows, rhs_cols}, {1, 1, 1})); + auto rhs_slice_r2 = computation->AddInstruction( + HloInstruction::CreateReshape(rhs_slice_shape_r2, rhs_slice_r3)); + + // Transpose lhs/rhs (if needed). + if (lhs_transpose) { + Shape lhs_slice_shape_r2_transpose = + ShapeUtil::MakeShape(lhs_type, {lhs_cols, lhs_rows}); + lhs_slice_r2 = + computation->AddInstruction(HloInstruction::CreateTranspose( + lhs_slice_shape_r2_transpose, lhs_slice_r2, {1, 0})); + } + if (rhs_transpose) { + Shape rhs_slice_shape_r2_transpose = + ShapeUtil::MakeShape(rhs_type, {rhs_cols, rhs_rows}); + rhs_slice_r2 = + computation->AddInstruction(HloInstruction::CreateTranspose( + rhs_slice_shape_r2_transpose, rhs_slice_r2, {1, 0})); + } + + // Compute Dot of lhs/rhs R2 slices. + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot( + dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums)); + + // Reshape Dot to R3 so we can concat along batch dimension. + auto dot_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(dot_shape_r3, dot_r2)); + + output_slices[i] = dot_r3; + } + + // Concatenate slices from 'output_slices' along batch dimension. + auto concat = computation->AddInstruction( + HloInstruction::CreateConcatenate(concat_shape_r3, output_slices, 0)); + // Reshape output 'new_dot' to original dimensions. + auto new_dot = computation->AddInstruction( + HloInstruction::CreateReshape(dot_shape, concat)); + + // Replace all uses of 'dot' in 'computation' with 'new_dot'. + return computation->ReplaceInstruction(dot, new_dot); +} + +} // namespace + +StatusOr DotDecomposer::Run(HloModule* module) { + XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString()); + // Gather all batch Dot operations. 
+ std::vector batch_dots; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kDot) { + continue; + } + const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 && decompose_batch_dot_) { + batch_dots.push_back(instruction); + } + } + } + // Decompose each batch Dot in 'batch_dots'. + bool changed = false; + for (auto* dot : batch_dots) { + TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); + changed = true; + } + XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h new file mode 100644 index 0000000000000000000000000000000000000000..5ff0ab34eac0cd0fbc264b408c57653c944402a6 --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// DotDecomposer is a pass which decomposes batch Dot operations into a +// sequence of smaller (R2) Dot operations. +class DotDecomposer : public HloPassInterface { + public: + // Decomposes batch Dot operations when 'decompose_batch_dot' is true. + DotDecomposer(bool decompose_batch_dot = true) + : decompose_batch_dot_(decompose_batch_dot) {} + ~DotDecomposer() = default; + tensorflow::StringPiece name() const override { return "dot_decomposer"; } + + // Run DotDecomposer pass on computations in 'module'. + // Returns whether the 'module' was changed. + StatusOr Run(HloModule* module) override; + + private: + bool decompose_batch_dot_; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index a945657712aae46093cd016d23114f26b8a2d926..7e88bbd63123cd33682bb5ff67761ae5c5bdc98c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -50,11 +50,161 @@ using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; using tensorflow::strings::StrCat; +namespace { + +llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, + int64 mantissa_bits, + llvm::IRBuilder<>* ir_builder) { + // Integer and float types for casting and constant generation. + llvm::Type* float_type = x->getType(); + llvm::IntegerType* int_type = ir_builder->getInt32Ty(); + + // Cast the input value to an integer for bitwise manipulation. 
+ llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type); + + if (mantissa_bits < 23) { + // Last remaining mantissa bit. + const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); + + // Compute rounding bias for round-to-nearest with ties to even. This is + // equal to a base value of 0111... plus one bit if the last remaining + // mantissa bit is 1. + const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; + llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr( + ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), + (23 - mantissa_bits)); + llvm::Value* x_rounding_bias = ir_builder->CreateAdd( + x_last_mantissa_bit, + llvm::ConstantInt::get(int_type, base_rounding_bias)); + + // Add rounding bias, and mask out truncated bits. Note that the case + // where adding the rounding bias overflows into the exponent bits is + // correct; the non-masked mantissa bits will all be zero, and the + // exponent will be incremented by one. + const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias); + x_as_int = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); + } + + if (exponent_bits < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- + // significant bit -- is equal to 1.0f for all exponent sizes. Adding + // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- + // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' + // exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n is + // (2^7-1) - 2^(n-1)-1. 
+ // + // Note that we have already checked that exponents_bits >= 1. + const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? + llvm::Value* x_exponent = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + llvm::Value* x_overflows = ir_builder->CreateICmpUGT( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); + llvm::Value* x_underflows = ir_builder->CreateICmpULE( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); + + // Compute appropriately-signed values of zero and infinity. + llvm::Value* x_signed_zero = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); + llvm::Value* x_signed_inf = ir_builder->CreateOr( + x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + + // Force to zero or infinity if overflow or underflow. (Note that this + // truncates all denormal values to zero, rather than rounding them.) + x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int); + x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int); + } + + // Cast the result back to a floating-point type. + llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type); + + // Correct result for NaN inputs. + // + // The exponent handling will "normalize" NaN values to infinities, which is + // undesirable (except in the case with no mantissa bits, in which case it + // is mandatory). This logic also handles cases where mantissa-rounding + // causes a NaN's mantissa to overflow into the exponent bits, which would + // otherwise create an erroneous zero value. 
+ // + // If the fast-math flags are set to assume no NaNs, the comparison is likely + // to be optimized away, so there's no point in even emitting it. + if (!ir_builder->getFastMathFlags().noNaNs()) { + llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x); + + if (mantissa_bits > 0) { + result = ir_builder->CreateSelect(x_is_nan, x, result); + } else { + result = ir_builder->CreateSelect( + x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); + } + } + return result; +} + +llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, + llvm::IRBuilder<>* ir_builder) { + auto reduced_precision = EmitReducePrecisionFloat( + f32_value, + /*exponent_bits=*/primitive_util::kBFloat16ExponentBits, + /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder); + auto as_int32 = + ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty()); + auto shifted = ir_builder->CreateLShr(as_int32, 16); + auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty()); + return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty()); +} + +llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, + llvm::IRBuilder<>* ir_builder) { + auto as_int16 = + ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty()); + auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty()); + auto shifted = ir_builder->CreateShl(as_int32, 16); + return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy()); +} + +llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, + PrimitiveType from_type, + PrimitiveType to_type, llvm::Module* module, + llvm::IRBuilder<>* ir_builder) { + if (primitive_util::IsSignedIntegralType(from_type)) { + return ir_builder->CreateSIToFP( + integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + } else { + CHECK(primitive_util::IsUnsignedIntegralType(from_type) || + from_type == PRED); + return ir_builder->CreateUIToFP( + integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + } +} + 
+} // namespace + StatusOr ElementalIrEmitter::EmitUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { if (op->opcode() == HloOpcode::kCopy) { return operand_value; - } else if (operand_value->getType()->isIntegerTy()) { + } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + op->operand(0)->shape().element_type() == PRED) { return EmitIntegerUnaryOp(op, operand_value); } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) { return EmitComplexUnaryOp(op, operand_value); @@ -79,28 +229,27 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( primitive_util::IsSignedIntegralType(to_type)); } if (primitive_util::IsFloatingPointType(to_type)) { - if (primitive_util::IsSignedIntegralType(from_type)) { - return ir_builder_->CreateSIToFP( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); - } - if (primitive_util::IsUnsignedIntegralType(from_type) || - from_type == PRED) { - return ir_builder_->CreateUIToFP( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + if (to_type == BF16) { + return EmitF32ToBF16( + EmitIntegralToFloating(operand_value, from_type, F32, module_, + ir_builder_), + ir_builder_); } + return EmitIntegralToFloating(operand_value, from_type, to_type, + module_, ir_builder_); } if (primitive_util::IsComplexType(to_type)) { auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( primitive_util::ComplexComponentType(to_type), module_); if (primitive_util::IsSignedIntegralType(from_type)) { - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateSIToFP(operand_value, to_ir_component_type), nullptr); } if (primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED) { - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateUIToFP(operand_value, to_ir_component_type), nullptr); @@ -110,6 +259,26 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( PrimitiveType_Name(from_type).c_str(), PrimitiveType_Name(to_type).c_str()); } + 
case HloOpcode::kBitcastConvert: { + PrimitiveType from_type = op->operand(0)->shape().element_type(); + PrimitiveType to_type = op->shape().element_type(); + CHECK(primitive_util::IsIntegralType(from_type)); + if (from_type == to_type) { + return operand_value; + } + if (primitive_util::BitWidth(from_type) == + primitive_util::BitWidth(to_type)) { + return ir_builder_->CreateBitCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + } + return InvalidArgument( + "bitcast conversion from primitive type %s to %s with unequal " + "bit-widths (%u versus %u) ", + PrimitiveType_Name(from_type).c_str(), + PrimitiveType_Name(to_type).c_str(), + primitive_util::BitWidth(from_type), + primitive_util::BitWidth(to_type)); + } case HloOpcode::kAbs: { bool is_signed = primitive_util::IsSignedIntegralType(op->shape().element_type()); @@ -178,15 +347,26 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( PrimitiveType to_component_type = primitive_util::ComplexComponentType(to_type); if (from_type == to_component_type) { - return ComposeComplex(op, operand_value, nullptr); + return EmitComposeComplex(op, operand_value, nullptr); } - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFPCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), nullptr); } + if (from_type == BF16) { + TF_RET_CHECK(to_type != BF16); + operand_value = EmitBF16ToF32(operand_value, ir_builder_); + from_type = F32; + if (from_type == to_type) { + return operand_value; + } + } + if (from_type == F32 && to_type == BF16) { + return EmitF32ToBF16(operand_value, ir_builder_); + } if (primitive_util::IsFloatingPointType(to_type)) { return ir_builder_->CreateFPCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); @@ -203,6 +383,26 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( PrimitiveType_Name(from_type).c_str(), PrimitiveType_Name(to_type).c_str()); } + case HloOpcode::kBitcastConvert: { + PrimitiveType from_type = 
op->operand(0)->shape().element_type(); + PrimitiveType to_type = op->shape().element_type(); + CHECK(primitive_util::IsFloatingPointType(from_type)); + if (from_type == to_type) { + return operand_value; + } + if (primitive_util::BitWidth(from_type) == + primitive_util::BitWidth(to_type)) { + return ir_builder_->CreateBitCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + } + return InvalidArgument( + "bitcast conversion from primitive type %s to %s with unequal " + "bit-widths (%u versus %u) ", + PrimitiveType_Name(from_type).c_str(), + PrimitiveType_Name(to_type).c_str(), + primitive_util::BitWidth(from_type), + primitive_util::BitWidth(to_type)); + } case HloOpcode::kExp: return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {operand_value}, {operand_value->getType()}, @@ -269,15 +469,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( StatusOr ElementalIrEmitter::EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; switch (op->opcode()) { // TODO(b/65209142): Angle/Log require atan2. 
- // case HloOpcode::kAngle: // case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -291,24 +484,26 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( primitive_util::ComplexComponentType(to_type); auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); - return ComposeComplex( + return EmitComposeComplex( op, - ir_builder_->CreateFPCast(real(operand_value), to_ir_component_type), - ir_builder_->CreateFPCast(imag(operand_value), to_ir_component_type)); + ir_builder_->CreateFPCast(EmitExtractReal(operand_value), + to_ir_component_type), + ir_builder_->CreateFPCast(EmitExtractImag(operand_value), + to_ir_component_type)); } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) auto exp_a = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::exp, {real(operand_value)}, - {real(operand_value)->getType()}, ir_builder_); + llvm::Intrinsic::exp, {EmitExtractReal(operand_value)}, + {EmitExtractReal(operand_value)->getType()}, ir_builder_); auto cos_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::cos, {imag(operand_value)}, - {imag(operand_value)->getType()}, ir_builder_); + llvm::Intrinsic::cos, {EmitExtractImag(operand_value)}, + {EmitExtractImag(operand_value)->getType()}, ir_builder_); auto sin_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::sin, {imag(operand_value)}, - {imag(operand_value)->getType()}, ir_builder_); - return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); + llvm::Intrinsic::sin, {EmitExtractImag(operand_value)}, + {EmitExtractImag(operand_value)->getType()}, ir_builder_); + return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), + ir_builder_->CreateFMul(exp_a, sin_b)); } case HloOpcode::kCos: { // cos(z) = .5(e^(iz) + e^(-iz)) @@ -318,8 +513,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( // cos(-x) = cos(x) and sin(-x) = 
-sin(x), so // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i)) // = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) - auto a = real(operand_value); - auto b = imag(operand_value); + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); auto type = a->getType(); auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, {type}, ir_builder_); @@ -331,7 +526,7 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( {type}, ir_builder_); auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, {type}, ir_builder_); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), @@ -348,8 +543,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( // cos(-x) = cos(x) and sin(-x) = -sin(x), so // = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a))) // = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) - auto a = real(operand_value); - auto b = imag(operand_value); + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); auto type = a->getType(); auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, {type}, ir_builder_); @@ -361,7 +556,7 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( {type}, ir_builder_); auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, {type}, ir_builder_); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), @@ -370,33 +565,40 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( } case HloOpcode::kAbs: { auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(operand_value), real(operand_value)), - ir_builder_->CreateFMul(imag(operand_value), imag(operand_value))); + ir_builder_->CreateFMul(EmitExtractReal(operand_value), + EmitExtractReal(operand_value)), + ir_builder_->CreateFMul(EmitExtractImag(operand_value), + EmitExtractImag(operand_value))); return 
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_); } case HloOpcode::kSign: { // Sign(c) = c / |c| auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(operand_value), real(operand_value)), - ir_builder_->CreateFMul(imag(operand_value), imag(operand_value))); + ir_builder_->CreateFMul(EmitExtractReal(operand_value), + EmitExtractReal(operand_value)), + ir_builder_->CreateFMul(EmitExtractImag(operand_value), + EmitExtractImag(operand_value))); auto cplx_abs = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_); auto type = cplx_abs->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero); return ir_builder_->CreateSelect( - oeq, ComposeComplex(op, zero, zero), - ComposeComplex( - op, ir_builder_->CreateFDiv(real(operand_value), cplx_abs), - ir_builder_->CreateFDiv(imag(operand_value), cplx_abs))); + oeq, EmitComposeComplex(op, zero, zero), + EmitComposeComplex( + op, + ir_builder_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs), + ir_builder_->CreateFDiv(EmitExtractImag(operand_value), + cplx_abs))); } case HloOpcode::kNegate: - return ComposeComplex(op, ir_builder_->CreateFNeg(real(operand_value)), - ir_builder_->CreateFNeg(imag(operand_value))); + return EmitComposeComplex( + op, ir_builder_->CreateFNeg(EmitExtractReal(operand_value)), + ir_builder_->CreateFNeg(EmitExtractImag(operand_value))); case HloOpcode::kReal: - return real(operand_value); + return EmitExtractReal(operand_value); case HloOpcode::kImag: - return imag(operand_value); + return EmitExtractImag(operand_value); default: return Unimplemented("unary complex op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -407,7 +609,8 @@ StatusOr ElementalIrEmitter::EmitBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { PrimitiveType operand_type = op->operand(0)->shape().element_type(); - if 
(lhs_value->getType()->isIntegerTy()) { + if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + operand_type == PRED) { return EmitIntegerBinaryOp( op, lhs_value, rhs_value, primitive_util::IsSignedIntegralType(operand_type)); @@ -424,7 +627,7 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( switch (op->opcode()) { // case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support case HloOpcode::kComplex: - return ComposeComplex(op, lhs_value, rhs_value); + return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: return ir_builder_->CreateFAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: @@ -479,54 +682,66 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( StatusOr ElementalIrEmitter::EmitComplexBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; switch (op->opcode()) { case HloOpcode::kAdd: - return ComposeComplex( - op, ir_builder_->CreateFAdd(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFAdd(imag(lhs_value), imag(rhs_value))); + return EmitComposeComplex( + op, + ir_builder_->CreateFAdd(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFAdd(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); case HloOpcode::kSubtract: - return ComposeComplex( - op, ir_builder_->CreateFSub(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFSub(imag(lhs_value), imag(rhs_value))); + return EmitComposeComplex( + op, + ir_builder_->CreateFSub(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFSub(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); case HloOpcode::kMultiply: - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFSub( - ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), - 
ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))), + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)))); + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)))); case HloOpcode::kDivide: { // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di)) // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) auto rhs_sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(rhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(rhs_value), imag(rhs_value))); + ir_builder_->CreateFMul(EmitExtractReal(rhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(rhs_value), + EmitExtractImag(rhs_value))); auto type = rhs_sum_sq->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero); + auto real_inf_or_nan = + ir_builder_->CreateFDiv(EmitExtractReal(lhs_value), zero); + auto imag_inf_or_nan = + ir_builder_->CreateFDiv(EmitExtractImag(lhs_value), zero); return ir_builder_->CreateSelect( - oeq, ComposeComplex(op, llvm::ConstantFP::getInfinity(type), zero), - ComposeComplex( + oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan), + EmitComposeComplex( op, ir_builder_->CreateFDiv( ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), - imag(rhs_value))), + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), rhs_sum_sq), ir_builder_->CreateFDiv( ir_builder_->CreateFSub( - ir_builder_->CreateFMul(imag(lhs_value), 
real(rhs_value)), - ir_builder_->CreateFMul(real(lhs_value), - imag(rhs_value))), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value))), rhs_sum_sq))); } // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered @@ -538,16 +753,20 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( // matches C++'s semantics. case HloOpcode::kEq: return ir_builder_->CreateAnd( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, real(lhs_value), - real(rhs_value), ir_builder_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, imag(lhs_value), - imag(rhs_value), ir_builder_)); + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), ir_builder_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), ir_builder_)); case HloOpcode::kNe: return ir_builder_->CreateOr( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, real(lhs_value), - real(rhs_value), ir_builder_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, imag(lhs_value), - imag(rhs_value), ir_builder_)); + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), ir_builder_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), ir_builder_)); // TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic // case HloOpcode::kPower: @@ -659,111 +878,9 @@ StatusOr ElementalIrEmitter::EmitReducePrecision( if (hlo->operand(0)->shape().element_type() != F32) { return Unimplemented("reduce-precision only implemented for F32"); } - - // Integer and float types for casting and constant generation. - llvm::Type* float_type = x->getType(); - llvm::IntegerType* int_type = ir_builder_->getInt32Ty(); - - // Cast the input value to an integer for bitwise manipulation. 
- llvm::Value* x_as_int = ir_builder_->CreateBitCast(x, int_type); - - if (hlo->mantissa_bits() < 23) { - // Last remaining mantissa bit. - const uint32_t last_mantissa_bit_mask = 1u << (23 - hlo->mantissa_bits()); - - // Compute rounding bias for round-to-nearest with ties to even. This is - // equal to a base value of 0111... plus one bit if the last remaining - // mantissa bit is 1. - const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; - llvm::Value* x_last_mantissa_bit = ir_builder_->CreateLShr( - ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), - (23 - hlo->mantissa_bits())); - llvm::Value* x_rounding_bias = ir_builder_->CreateAdd( - x_last_mantissa_bit, - llvm::ConstantInt::get(int_type, base_rounding_bias)); - - // Add rounding bias, and mask out truncated bits. Note that the case - // where adding the rounding bias overflows into the exponent bits is - // correct; the non-masked mantissa bits will all be zero, and the - // exponent will be incremented by one. - const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); - x_as_int = ir_builder_->CreateAdd(x_as_int, x_rounding_bias); - x_as_int = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); - } - - if (hlo->exponent_bits() < 8) { - // Masks for f32 values. - const uint32_t f32_sign_bit_mask = 1u << 31; - const uint32_t f32_exp_bits_mask = 0xffu << 23; - - // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- - // significant bit -- is equal to 1.0f for all exponent sizes. Adding - // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- - // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' - // exponent (corresponding to 0.0f). 
- // - // Thus, the f32 exponent corresponding to the highest non-infinite - // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 - // exponent corresponding to the lowest exponent for a bit size of n is - // (2^7-1) - 2^(n-1)-1. - // - // Note that we have already checked that exponents_bits >= 1. - const uint32_t f32_exponent_bias = (1 << 7) - 1; - const uint32_t reduced_exponent_bias = - (1 << (hlo->exponent_bits() - 1)) - 1; - const uint32_t reduced_max_exponent = - f32_exponent_bias + reduced_exponent_bias; - const uint32_t reduced_min_exponent = - f32_exponent_bias - reduced_exponent_bias; - - // Do we overflow or underflow? - llvm::Value* x_exponent = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); - llvm::Value* x_overflows = ir_builder_->CreateICmpUGT( - x_exponent, - llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); - llvm::Value* x_underflows = ir_builder_->CreateICmpULE( - x_exponent, - llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); - - // Compute appropriately-signed values of zero and infinity. - llvm::Value* x_signed_zero = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); - llvm::Value* x_signed_inf = ir_builder_->CreateOr( - x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); - - // Force to zero or infinity if overflow or underflow. (Note that this - // truncates all denormal values to zero, rather than rounding them.) - x_as_int = ir_builder_->CreateSelect(x_overflows, x_signed_inf, x_as_int); - x_as_int = ir_builder_->CreateSelect(x_underflows, x_signed_zero, x_as_int); - } - - // Cast the result back to a floating-point type. - llvm::Value* result = ir_builder_->CreateBitCast(x_as_int, float_type); - - // Correct result for NaN inputs. 
- // - // The exponent handling will "normalize" NaN values to infinities, which is - // undesirable (except in the case with no mantissa bits, in which case it - // is mandatory). This logic also handles cases where mantissa-rounding - // causes a NaN's mantissa to overflow into the exponent bits, which would - // otherwise create an erroneous zero value. - // - // If the fast-math flags are set to assume no NaNs, the comparison is likely - // to be optimized away, so there's no point in even emitting it. - if (!ir_builder_->getFastMathFlags().noNaNs()) { - llvm::Value* x_is_nan = ir_builder_->CreateFCmpUNO(x, x); - - if (hlo->mantissa_bits() > 0) { - result = ir_builder_->CreateSelect(x_is_nan, x, result); - } else { - result = ir_builder_->CreateSelect( - x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); - } - } - return result; + return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(), + /*mantissa_bits=*/hlo->mantissa_bits(), + ir_builder_); } StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( @@ -847,7 +964,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( // If no implicit broadcast is needed for this operand, returns the target // index as the source index. 
- if (ShapeUtil::Compatible(operand_shape, hlo.shape())) { + if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) { return target_index; } @@ -1055,6 +1172,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kRoundNearestAfz: case HloOpcode::kCeil: case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: @@ -1063,11 +1181,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kNegate: + case HloOpcode::kNot: case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kTanh: - case HloOpcode::kNot: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, @@ -1076,6 +1194,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return EmitUnaryOp(hlo, operand_value); }; case HloOpcode::kAdd: + case HloOpcode::kAnd: case HloOpcode::kAtan2: case HloOpcode::kComplex: case HloOpcode::kDivide: @@ -1088,14 +1207,13 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kMinimum: case HloOpcode::kMultiply: case HloOpcode::kNe: + case HloOpcode::kOr: case HloOpcode::kPower: case HloOpcode::kRemainder: - case HloOpcode::kSubtract: - case HloOpcode::kAnd: - case HloOpcode::kOr: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: + case HloOpcode::kSubtract: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { const HloInstruction* lhs = hlo->operand(0); @@ -1565,25 +1683,25 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); llvm::Value* next_accumulator; if (primitive_util::IsComplexType(primitive_type)) { - auto real = [&](llvm::Value* x) { - 
return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; llvm::Value* product_real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))); + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); llvm::Value* product_imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value))); + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value))); next_accumulator = ir_builder_->CreateInsertValue( current_accumulator, - ir_builder_->CreateFAdd(real(current_accumulator), product_real), + ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), + product_real), {0}); next_accumulator = ir_builder_->CreateInsertValue( next_accumulator, - ir_builder_->CreateFAdd(imag(current_accumulator), product_imag), + ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), + product_imag), {1}); } else if (primitive_util::IsFloatingPointType(primitive_type)) { next_accumulator = ir_builder_->CreateFAdd( @@ -1607,9 +1725,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( } } -llvm::Value* ElementalIrEmitter::ComposeComplex(const HloInstruction* op, - llvm::Value* real, - llvm::Value* imag) const { +llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const { + return ir_builder_->CreateExtractValue(value, {0}); +} + +llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const { + return ir_builder_->CreateExtractValue(value, {1}); +} + +llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, + llvm::Value* real, + 
llvm::Value* imag) const { auto cplx_type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); auto complex = ir_builder_->CreateInsertValue( diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 9d32436e38fa2fb3e27d09f01b860cd2edf2c8ac..cccb498f82936283a215370787907b293827ff2d 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -95,6 +95,13 @@ class ElementalIrEmitter { virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x) const; + virtual llvm::Value* EmitExtractReal(llvm::Value* value) const; + virtual llvm::Value* EmitExtractImag(llvm::Value* value) const; + + // Composes a complex struct. imag may be nullptr for simple cast operations. + llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, + llvm::Value* imag) const; + // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and // the target array index, computes the source array index of its // `operand_no`-th operand. @@ -117,11 +124,6 @@ class ElementalIrEmitter { // compiled executable outside of the HLO code itself. const HloModuleConfig& hlo_module_config_; - protected: - // Composes a complex struct. imag may be nullptr for simple cast operations. - llvm::Value* ComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const; - private: // Returns a ElementGenerator for a RNG HloInstruction. 
llvm_ir::ElementGenerator MakeRngElementGenerator( diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 9c96d9eb30b5f9e51b7f5d82391c6b9f366898d6..ad5d5ead00eaae558912537a17ea53394a1a20a0 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -52,7 +52,7 @@ Executable::ExecuteOnStreams( } for (const auto& options : run_options) { TF_RET_CHECK(options.stream() != nullptr); - options.stream()->BlockHostUntilDone(); + TF_RETURN_IF_ERROR(options.stream()->BlockHostUntilDone()); } return return_values; } diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 7e0d182b365c35788195e70dc35c3923ed8991bb..cb9ee47dc6885789bbf9a718fce9aebd6fa81fc8 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -44,8 +44,15 @@ namespace xla { // interface that is used for launching compiled programs across platforms. 
class Executable { public: - explicit Executable(std::unique_ptr hlo_module) - : hlo_module_(std::move(hlo_module)) {} + explicit Executable(std::unique_ptr hlo_module, + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map) + : hlo_module_(std::move(hlo_module)), + hlo_profile_printer_(std::move(hlo_profile_printer)), + hlo_profile_index_map_(std::move(hlo_profile_index_map)) { + CHECK_EQ(hlo_profile_printer_.get() == nullptr, + hlo_profile_index_map_.get() == nullptr); + } virtual ~Executable() {} // Enqueues the compilation result on the provided stream, passing the given @@ -123,12 +130,20 @@ class Executable { "Equality test on this executable is not implemented."); } + const HloProfilePrinter& hlo_profile_printer() const { + CHECK(hlo_profiling_enabled()); + return *hlo_profile_printer_; + } + + const HloProfileIndexMap& hlo_profile_index_map() const { + CHECK(hlo_profiling_enabled()); + return *hlo_profile_index_map_; + } + // Returns whether this executable was compiled with HLO profilings support // enabled. If not, the caller should not expect an hlo_execution_profile // passed to ExecuteOnStream above to be populated during execution. - bool hlo_profiling_enabled() const { - return hlo_module_->config().hlo_profiling_enabled(); - } + bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; } const HloModule& module() const { return *hlo_module_; } @@ -160,10 +175,6 @@ class Executable { static Status DumpToDirectory(const string& directory_path, string filename, const SessionModule& session_module); - // Returns a cost analysis object appropriate for the platform on which this - // executable can run. - virtual std::unique_ptr CreateCostAnalysis() const = 0; - protected: mutable tensorflow::mutex mutex_; @@ -181,6 +192,9 @@ class Executable { // Execution count, used to generate a unique filename for each dumped // execution. 
int64 execution_count_ = 0; + + std::unique_ptr hlo_profile_printer_; + std::unique_ptr hlo_profile_index_map_; }; template @@ -197,25 +211,31 @@ StatusOr Executable::ExecuteOnStreamWrapper( VLOG(1) << "enqueueing executable on stream..."; // If the profiling flag isn't enabled, we pass nullptr as the profile to // indicate profiling is not requested. - HloExecutionProfile hlo_execution_profile; - HloExecutionProfile* profile_ptr = + std::unique_ptr profile_ptr = module_config().debug_options().xla_hlo_profile() && hlo_profiling_enabled() - ? &hlo_execution_profile + ? MakeUnique(&hlo_profile_printer(), + &hlo_profile_index_map()) : nullptr; - auto return_value = ExecuteOnStream(run_options, arguments, profile_ptr); + auto return_value = + ExecuteOnStream(run_options, arguments, profile_ptr.get()); if (profile != nullptr) { VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; - stream->ThenStopTimer(timer.get()).BlockHostUntilDone(); + stream->ThenStopTimer(timer.get()); + SE_CHECK_OK(stream->BlockHostUntilDone()); VLOG(1) << "done with block-host-until-done"; // Merge in run-time profile information from execution_profile. profile->MergeFrom(execution_profile()); // Overall execution time (in nanoseconds) from the executor timer. - profile->set_compute_and_transfer_time_ns(timer->Nanoseconds()); + if (stream->ok()) { + // Don't read timer->Nanoseconds() if the stream isn't OK -- that's + // illegal. + profile->set_compute_and_transfer_time_ns(timer->Nanoseconds()); + } // TODO(b/28123297): On GPU we end up including transfer time in // the compute time this way. Instead, we should get the correct @@ -232,24 +252,11 @@ StatusOr Executable::ExecuteOnStreamWrapper( } if (profile_ptr != nullptr) { - std::unordered_set profiled_computations = - profile_ptr->profiled_computations(); - // To ensure we have print the profiles in a stable order, iterate over the - // computations in post order. 
- std::list all_computations = - module().MakeComputationPostOrder(); - for (xla::HloComputation* computation : all_computations) { - if (profiled_computations.count(computation) > 0) { - string profile_string = profile_ptr->ToString( - *computation, stream->parent()->GetDeviceDescription(), - CreateCostAnalysis().get()); - if (!profile_string.empty()) { - XLA_LOG_LINES(tensorflow::INFO, profile_string); - } - } - } + XLA_LOG_LINES( + tensorflow::INFO, + profile_ptr->ToString(stream->parent()->GetDeviceDescription())); hlo_graph_dumper::MaybeDumpHloModule(module(), "Service::Execute", - profile_ptr); + profile_ptr.get()); } return return_value; diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index c225e62e3e11d2d01251b0f92272b0949eff8af1..2f0b9ed2bd98fbea4e67c0a30d5aa41ff6a06979 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -39,9 +39,7 @@ AsyncExecution::AsyncExecution(Backend* backend, tensorflow::Status AsyncExecution::BlockUntilDone() const { for (auto& stream : streams_) { - if (!stream->BlockHostUntilDone()) { - return InternalError("failed to block until done"); - } + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); } return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc index dfba22a6c4c5cf071c2cd8621643b8da6587ee3b..2b6caa149439a86d6d047605099bc3ff7b295a8e 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -26,7 +26,10 @@ namespace xla { namespace { -// Helper to replace the called computation at a while- or call-instruction. +// Helper to replace the called computation at a while-, call-, or +// conditional-instruction. 
This function replaces exactly one instance of +// 'computation' with 'new_computation' even if 'instruction' calls +// 'computation' more than once. void ReplaceCalledComputation(HloInstruction* instruction, HloComputation* computation, HloComputation* new_computation) { @@ -45,6 +48,15 @@ void ReplaceCalledComputation(HloInstruction* instruction, instruction->set_to_apply(new_computation); break; } + case HloOpcode::kConditional: { + if (computation == instruction->true_computation()) { + instruction->set_true_computation(new_computation); + } else { + CHECK_EQ(computation, instruction->false_computation()); + instruction->set_false_computation(new_computation); + } + break; + } default: LOG(FATAL) << "unexpected opcode: " << HloOpcodeString(instruction->opcode()); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index a68e90b7d009890012f94baa790d911871c9c960..d3854b40de3572a60df1ad99d8a4589f59ad7194 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -223,5 +223,35 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { EXPECT_EQ(1, b_node.caller_callsites().size()); } +TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { + auto module = CreateNewModule(); + HloComputation* sub_computation = + module->AddEmbeddedComputation(MakeScalarComputation()); + + // Create entry computation, which is a conditional that has the same + // computation in the true and false branch. 
+ HloComputation::Builder builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + builder.AddInstruction(HloInstruction::CreateConditional( + kScalarShape, pred, constant1, sub_computation, constant2, + sub_computation)); + module->AddEntryComputation(builder.Build()); + EXPECT_EQ(2, module->computation_count()); + + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + EXPECT_TRUE(result); + std::unique_ptr call_graph = CallGraph::Build(module.get()); + // The true and false computations must now be different. + EXPECT_EQ(3, module->computation_count()); + + const CallGraphNode& sub_node = call_graph->GetNode(sub_computation); + EXPECT_EQ(1, sub_node.caller_callsites().size()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index b4fbed1562945adeb52a9471453ed4fee0e35180..74aa77b4f165be76fbc0a8aa1a4a7e90a8e9acec 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -103,8 +104,7 @@ GenericTransferManager::ShallowCopyTupleFromDevice( // a vector of void* pointers. 
std::vector element_pointers(ShapeUtil::TupleElementCount(shape), nullptr); - int64 tuple_size = - ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); + int64 tuple_size = ShapeUtil::ByteSizeOf(shape, pointer_size_); auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size, element_pointers.data()); if (!copy_status.ok()) { @@ -121,9 +121,8 @@ GenericTransferManager::ShallowCopyTupleFromDevice( !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) { return FailedPrecondition("tuple contains nullptr at element %lu", i); } - int64 buffer_size = ShapeUtil::ByteSizeOf(shape.tuple_shapes(i), - /*pointer_size=*/sizeof(void*)); - destination.emplace_back(element_pointers[i], buffer_size); + destination.emplace_back(element_pointers[i], + GetByteSizeRequirement(shape.tuple_shapes(i))); } return std::move(destination); } @@ -138,11 +137,79 @@ Status GenericTransferManager::WriteTuplePointersToDevice( for (const se::DeviceMemoryBase& element : elements) { element_pointers.push_back(element.opaque()); } - int64 tuple_size = - ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); + return TransferBufferToDevice(executor, GetByteSizeRequirement(shape), + element_pointers.data(), region); +} + +StatusOr> +GenericTransferManager::TransferLiteralFromDevice( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { + VLOG(2) << "transferring literal from device ordinal " + << executor->device_ordinal() << "; device shape: " + << ShapeUtil::HumanStringWithLayout(device_buffer.shape()) + << "; opaque: " << device_buffer.buffer(/*index=*/{}).opaque(); + TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + + std::unique_ptr literal = + Literal::CreateFromShape(device_buffer.shape()); + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + device_buffer.shape(), + [&](const Shape& subshape, const ShapeIndex& index) -> Status { + if (!ShapeUtil::IsTuple(subshape)) { + TF_RETURN_IF_ERROR(TransferBufferFromDevice( + 
executor, + /*source=*/device_buffer.buffer(index), + /*size=*/GetByteSizeRequirement(subshape), + /*destination=*/ + literal->GetSubliteral(index).MutableInternalData())); + } + + return Status::OK(); + })); + return std::move(literal); +} + +Status GenericTransferManager::TransferLiteralToDevice( + se::StreamExecutor* executor, const Literal& literal, + const ShapedBuffer& device_buffer) { + const Shape& shape = literal.shape(); + VLOG(2) << "transferring literal shape to device: " + << ShapeUtil::HumanString(shape) << "; device location: " + << device_buffer.buffer(/*index=*/{}).opaque(); + + TF_RET_CHECK(ShapeUtil::Compatible(literal.shape(), device_buffer.shape())); + TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + + TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer)); - return TransferBufferToDevice(executor, tuple_size, element_pointers.data(), - region); + return ShapeUtil::ForEachSubshapeWithStatus( + device_buffer.shape(), + [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { + se::DeviceMemoryBase device_memory = device_buffer.buffer(index); + if (ShapeUtil::IsArray(device_subshape)) { + TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == + device_memory.size()); + // Element is array-shaped: transfer array data to device buffer. + const Literal& subliteral = literal.GetSubliteral(index); + std::unique_ptr relayed_out_literal; + const void* source; + if (LayoutUtil::Equal(device_subshape.layout(), + subliteral.shape().layout())) { + source = subliteral.InternalData(); + } else { + // Relayout data before transferring. 
+ relayed_out_literal = subliteral.Relayout(device_subshape.layout(), + /*shape_index=*/{}); + source = relayed_out_literal->InternalData(); + } + return TransferBufferToDevice( + executor, + /*size=*/GetByteSizeRequirement(device_subshape), source, + &device_memory); + } + return Status::OK(); + }); } Status GenericTransferManager::TransferLiteralToDevice( @@ -198,7 +265,7 @@ Status GenericTransferManager::ResetDevices( } int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) const { - return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); + return ShapeUtil::ByteSizeOf(shape, pointer_size_); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index ef9a50676a4171b56e8a77d2dedc05b1580e5ea5..50dca6aec5012f0b02cb54846b622f008600e48e 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -52,6 +52,14 @@ class GenericTransferManager : public TransferManager { perftools::gputools::StreamExecutor* executor, const Literal& literal, perftools::gputools::DeviceMemoryBase* destination) override; + StatusOr> TransferLiteralFromDevice( + perftools::gputools::StreamExecutor* executor, + const ShapedBuffer& device_buffer) override; + + Status TransferLiteralToDevice(perftools::gputools::StreamExecutor* executor, + const Literal& literal, + const ShapedBuffer& device_buffer) override; + Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, @@ -71,6 +79,9 @@ class GenericTransferManager : public TransferManager { const perftools::gputools::DeviceMemoryBase& source, const Shape& shape) override; + int64 GetByteSizeRequirement(const Shape& shape) const override; + + protected: Status WriteTuplePointersToDevice( 
perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice @@ -78,8 +89,6 @@ class GenericTransferManager : public TransferManager { const Shape& shape, perftools::gputools::DeviceMemoryBase* region) override; - int64 GetByteSizeRequirement(const Shape& shape) const override; - private: // The platform this transfer manager targets. const perftools::gputools::Platform::Id platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 364b76b93c288f13f2bf447cebfc25f705d77826..4a72f87efdd92497ac4c2cd73b56c4990ed5b04c 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -343,15 +343,16 @@ tf_cc_test( ) cc_library( - name = "copy_insertion", - srcs = ["copy_insertion.cc"], - hdrs = ["copy_insertion.h"], + name = "gpu_copy_insertion", + srcs = ["gpu_copy_insertion.cc"], + hdrs = ["gpu_copy_insertion.h"], deps = [ ":ir_emission_utils", + "//tensorflow/compiler/xla/service:call_graph", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:logical_buffer", - "//tensorflow/compiler/xla/service:tuple_points_to_analysis", + "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", + "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", ], ) @@ -427,14 +428,14 @@ cc_library( hdrs = ["gpu_compiler.h"], deps = [ ":convolution_folding", - ":copy_insertion", ":fusion_merger", + ":gpu_copy_insertion", ":gpu_executable", + ":gpu_layout_assignment", ":hlo_schedule", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", - ":layout_assignment", ":pad_insertion", ":partition_assignment", ":stream_assignment", @@ -448,6 +449,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:dot_decomposer", 
"//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -491,9 +493,9 @@ cc_library( ) cc_library( - name = "layout_assignment", - srcs = ["layout_assignment.cc"], - hdrs = ["layout_assignment.h"], + name = "gpu_layout_assignment", + srcs = ["gpu_layout_assignment.cc"], + hdrs = ["gpu_layout_assignment.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", @@ -507,10 +509,10 @@ cc_library( ) tf_cc_test( - name = "layout_assignment_test", - srcs = ["layout_assignment_test.cc"], + name = "gpu_layout_assignment_test", + srcs = ["gpu_layout_assignment_test.cc"], deps = [ - ":layout_assignment", + ":gpu_layout_assignment", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -574,11 +576,14 @@ tf_cc_test( deps = [ ":instruction_fusion", ":while_transformer", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:copy_insertion", + "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc index 5aaf072f9d2c95e2fff70a1c5337432a12a1aa48..f198c4c08e93277b3a14a32d906b8083a94a8a2c 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc @@ -55,28 +55,20 @@ MatchBackwardFilter(HloInstruction* conv) { // v v // Convolution // conv - // | - // v - // Transpose (optional if identity transposition) CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); - // If the forward convolution is followed by a transpose, we can fuse the - // transpose into the backward 
convolution as well. - HloInstruction* transpose = nullptr; - if (conv->user_count() == 1) { - HloInstruction* single_user = *conv->users().begin(); - if (single_user->opcode() == HloOpcode::kTranspose) { - transpose = single_user; - } - } // Step 2: match paddings and dimension numbers of the forward convolution. const ConvolutionDimensionNumbers& conv_dnums = conv->convolution_dimension_numbers(); auto input_batch_dim = conv_dnums.input_batch_dimension(); auto input_feature_dim = conv_dnums.input_feature_dimension(); + auto input_spatial_dims = conv_dnums.input_spatial_dimensions(); + auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension(); + auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension(); + auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions(); auto output_batch_dim = conv_dnums.output_batch_dimension(); auto output_feature_dim = conv_dnums.output_feature_dimension(); - auto spatial_dims = conv_dnums.spatial_dimensions(); + auto output_spatial_dims = conv_dnums.output_spatial_dimensions(); for (const WindowDimension& window_dim : conv->window().dimensions()) { if (window_dim.stride() != 1) { @@ -97,7 +89,8 @@ MatchBackwardFilter(HloInstruction* conv) { } // Padding high will be checked in Step 3. } - if (transpose == nullptr && !window_util::HasWindowDilation(conv->window())) { + if (input_batch_dim == output_batch_dim && + !window_util::HasWindowDilation(conv->window())) { VLOG(1) << conv->ToString() << " is a regular forward convolution. No need " "to fold it to a backward filter convolution."; @@ -108,11 +101,11 @@ MatchBackwardFilter(HloInstruction* conv) { // // Compute the window of the backward convolution. 
Window backward_conv_window; - for (int i = 0; i < spatial_dims.size(); ++i) { + for (int i = 0; i < input_spatial_dims.size(); ++i) { WindowDimension* dim = backward_conv_window.add_dimensions(); // The window size of the backward convolution equals the output size of the // forward convolution. - int64 filter_size = conv->shape().dimensions(spatial_dims[i]); + int64 filter_size = conv->shape().dimensions(output_spatial_dims[i]); dim->set_size(filter_size); // The window stride equals the window dilation of the forward convolution. dim->set_stride(conv->window().dimensions(i).window_dilation()); @@ -120,7 +113,8 @@ MatchBackwardFilter(HloInstruction* conv) { // activations. dim->set_padding_low(conv->window().dimensions(i).padding_low()); - int64 input_size = conv->operand(0)->shape().dimensions(spatial_dims[i]); + int64 input_size = + conv->operand(0)->shape().dimensions(input_spatial_dims[i]); int64 output_size = conv->window().dimensions(i).size(); // Compute the range of the amount of valid high padding. We first compute // min_padding_high, the amount of padding on the right/bottom to ensure the @@ -167,50 +161,32 @@ MatchBackwardFilter(HloInstruction* conv) { } } - // To make future HLO passes easier, we canonicalize the fused expression by - // adding an identity transposition if it's omitted in the pattern. - if (transpose == nullptr) { - // Create an identity transposition with the same rank as the forward - // convolution. - HloComputation* parent_computation = conv->parent(); - std::vector transpose_dimensions(ShapeUtil::Rank(conv->shape())); - std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 0); - transpose = - parent_computation->AddInstruction(HloInstruction::CreateTranspose( - conv->shape(), conv, transpose_dimensions)); - TF_CHECK_OK(conv->ReplaceAllUsesWith(transpose)); - } - // Restore the dimension numbers of the backward convolution from the forward // convolution. 
The two activation dimensions are reversed (batch and // feature). ConvolutionDimensionNumbers backward_conv_dnums; backward_conv_dnums.set_input_batch_dimension(input_feature_dim); backward_conv_dnums.set_input_feature_dimension(input_batch_dim); - backward_conv_dnums.set_output_batch_dimension(output_feature_dim); - backward_conv_dnums.set_output_feature_dimension(output_batch_dim); - for (int i = 0; i < spatial_dims.size(); ++i) { - backward_conv_dnums.add_spatial_dimensions(spatial_dims[i]); + for (int i = 0; i < input_spatial_dims.size(); ++i) { + backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]); + } + backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim); + backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim); + for (int i = 0; i < kernel_spatial_dims.size(); ++i) { + backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]); } // The dimension numbering of the output of the forward convolution (before // transposition) is the same as that of the activations (according to the // semantics of kConvolution). The batch dimension of the activations should // be treated as the input feature dimension, and the feature dimension should // be treated as the output feature. - // - // The output of the forward convolution needs to be transposed to fit into - // the dimension numbering of the weight gradients. This transposition maps - // dimension i to PositionInContainer(transpose->dimensions(), i). 
- backward_conv_dnums.set_kernel_input_feature_dimension( - PositionInContainer(transpose->dimensions(), output_batch_dim)); - backward_conv_dnums.set_kernel_output_feature_dimension( - PositionInContainer(transpose->dimensions(), output_feature_dim)); - for (int i = 0; i < spatial_dims.size(); ++i) { - backward_conv_dnums.add_kernel_spatial_dimensions( - PositionInContainer(transpose->dimensions(), spatial_dims[i])); + backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim); + backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim); + for (int i = 0; i < output_spatial_dims.size(); ++i) { + backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]); } - return std::make_tuple(true, std::vector({transpose, conv}), + return std::make_tuple(true, std::vector({conv}), backward_conv_window, backward_conv_dnums); } @@ -272,12 +248,14 @@ MatchBackwardInput(HloInstruction* conv) { } } - const auto& spatial_dims = dnums.spatial_dimensions(); - CHECK_EQ(conv->window().dimensions().size(), spatial_dims.size()); + const auto& input_spatial_dims = dnums.input_spatial_dimensions(); + const auto& output_spatial_dims = dnums.output_spatial_dimensions(); + CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size()); + CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size()); const Window& old_window = conv->window(); Window new_window = old_window; - for (size_t i = 0; i < spatial_dims.size(); ++i) { + for (size_t i = 0; i < input_spatial_dims.size(); ++i) { // Restore backward convolution's padding config from the matched pattern. // See the comment in tensorflow/core/kernels/conv_grad_tuple_ops.cc // for how we convert backward input convolution to a variant of forward @@ -310,8 +288,9 @@ MatchBackwardInput(HloInstruction* conv) { // end at the border. The maximum amount (max_padding_high) equals // min_padding_high+stride-1 -- max_padding_high+1 would cause the output // size to change. 
- auto unpadded_input_size = conv->shape().dimensions(spatial_dims[i]); - auto output_size = conv->operand(0)->shape().dimensions(spatial_dims[i]); + auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]); + auto output_size = + conv->operand(0)->shape().dimensions(input_spatial_dims[i]); auto padded_input_size = kernel_size + dim->stride() * (output_size - 1); auto total_pad_size = padded_input_size - unpadded_input_size; auto min_padding_high = total_pad_size - backward_padding_low; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc index 19b122ba0603b4ec08d73e05da4c2ae11a760553..34e6bdb117d47a3d7e1eb3bae5806e130e94ea79 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc @@ -46,23 +46,27 @@ class ConvolutionFoldingTest : public HloTestBase { // // TODO(jingyue): Add more tests on NCHW input order which TF also supports. 
tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); - tf_default_dnums_for_backward_filter_.set_output_batch_dimension(3); tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); - tf_default_dnums_for_backward_filter_.set_output_feature_dimension(0); - tf_default_dnums_for_backward_filter_.add_spatial_dimensions(1); - tf_default_dnums_for_backward_filter_.add_spatial_dimensions(2); + tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2); tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0); tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension( 3); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2); + tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3); tf_default_dnums_for_backward_input_.set_input_batch_dimension(0); tf_default_dnums_for_backward_input_.set_output_batch_dimension(0); tf_default_dnums_for_backward_input_.set_input_feature_dimension(3); tf_default_dnums_for_backward_input_.set_output_feature_dimension(3); - tf_default_dnums_for_backward_input_.add_spatial_dimensions(1); - tf_default_dnums_for_backward_input_.add_spatial_dimensions(2); + tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1); + tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1); + tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2); + tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2); tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3); 
tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2); tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0); @@ -82,7 +86,7 @@ class ConvolutionFoldingTest : public HloTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithoutTranspose) { +TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -132,7 +136,7 @@ TEST_F(ConvolutionFoldingTest, auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + EXPECT_TRUE(FoldConvolution(module.get())); } // Extracted from block35 training. @@ -151,13 +155,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { conv_window.mutable_dimensions(i)->set_padding_low(1); conv_window.mutable_dimensions(i)->set_padding_high(1); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 3, 32, 32}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -185,13 +185,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { conv_window.mutable_dimensions(i)->set_padding_high(-1); conv_window.mutable_dimensions(i)->set_window_dilation(2); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - 
conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -218,13 +214,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { // Uneven padding: padding_low=0, padding_high=1 conv_window.mutable_dimensions(i)->set_padding_high(1); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {2, 2, 32, 32}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -258,8 +250,10 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { conv_dnums.set_output_batch_dimension(0); conv_dnums.set_input_feature_dimension(1); conv_dnums.set_output_feature_dimension(1); - conv_dnums.add_spatial_dimensions(2); - conv_dnums.add_spatial_dimensions(3); + conv_dnums.add_input_spatial_dimensions(2); + conv_dnums.add_output_spatial_dimensions(2); + conv_dnums.add_input_spatial_dimensions(3); + conv_dnums.add_output_spatial_dimensions(3); conv_dnums.set_kernel_input_feature_dimension(0); conv_dnums.set_kernel_output_feature_dimension(1); conv_dnums.add_kernel_spatial_dimensions(2); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc 
b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 536b96dcf620e908e25a775bc2efb57ba5f5edd6..899cc5c83b99f1bb6154f883ca17871863e1f457 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -28,12 +29,12 @@ namespace se = ::perftools::gputools; namespace xla { namespace gpu { +using se::dnn::AlgorithmDesc; using se::dnn::BatchDescriptor; using se::dnn::ConvolutionDescriptor; using se::dnn::DataLayout; using se::dnn::FilterDescriptor; using se::dnn::FilterLayout; -using se::dnn::AlgorithmDesc; ConvolveScratchAllocator::ConvolveScratchAllocator( int device_ordinal, DeviceMemoryAllocator* memory_allocator) @@ -130,8 +131,9 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( const int effective_num_dimensions = std::max(2, num_dimensions); CHECK_EQ(F32, output_shape_.element_type()); - CHECK_EQ(num_dimensions, dim_nums_.spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dim_nums_.input_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dim_nums_.kernel_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dim_nums_.output_spatial_dimensions_size()); for (const WindowDimension& dim : window_.dimensions()) { CHECK_EQ(dim.padding_low(), dim.padding_high()); } @@ -147,7 +149,7 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( // Note that the dimensions are reversed. The same holds below. 
input_descriptor.set_spatial_dim( static_cast(effective_num_dimensions - dim - 1), - input_shape_.dimensions(dim_nums_.spatial_dimensions(dim))); + input_shape_.dimensions(dim_nums_.input_spatial_dimensions(dim))); } FilterDescriptor filter_descriptor(effective_num_dimensions); @@ -181,7 +183,7 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( for (int dim = 0; dim < num_dimensions; ++dim) { output_descriptor.set_spatial_dim( static_cast(effective_num_dimensions - dim - 1), - output_shape_.dimensions(dim_nums_.spatial_dimensions(dim))); + output_shape_.dimensions(dim_nums_.output_spatial_dimensions(dim))); } // Add a singleton dimension in the 1D convolution case. @@ -257,28 +259,52 @@ tensorflow::Status ConvolutionThunk::Convolve( } std::vector ConvolutionThunk::GetAlgorithms( - se::StreamExecutor* stream_exec) const { + bool with_winograd_nonfused, se::StreamExecutor* stream_exec) const { std::vector algorithms; - // TODO(yangzihao): Currently disable the use of winograd nonfused in XLA - // by default. Should send in conv parameters and enable it when - // ShouldIncludeWinogradNonfusedAlgo() returns true. 
switch (convolution_kind_) { case ConvolutionKind::kBackwardFilter: CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( - /*with_winograd_nonfused=*/false, &algorithms)); + with_winograd_nonfused, &algorithms)); break; case ConvolutionKind::kBackwardInput: CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( - /*with_winograd_nonfused=*/false, &algorithms)); + with_winograd_nonfused, &algorithms)); break; case ConvolutionKind::kForward: - CHECK(stream_exec->GetConvolveAlgorithms(/*with_winograd_nonfused=*/false, + CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused, &algorithms)); break; } return algorithms; } +static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) { + if (algo.tensor_ops_enabled()) { + return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); + } + return tensorflow::strings::StrCat(algo.algo_id()); +} + +// Determines whether we can safely perform a winograd non-fused convolution for +// the given input and output descriptors. This works around b/68264959, an +// integer overflow in cuDNNv5 and cuDNNv6. 
+static bool ShouldIncludeWinogradNonfusedAlgo( + const BatchDescriptor& input_descriptor, + const BatchDescriptor& output_descriptor) { + int64 batch = input_descriptor.count(); + int64 in_depths = input_descriptor.feature_map_count(); + int64 in_rows = input_descriptor.height(); + int64 in_cols = input_descriptor.width(); + int64 out_depths = output_descriptor.feature_map_count(); + + int64 total_size = 16 * std::ceil(batch / 16.0) * + std::max(in_depths, out_depths) * in_cols * in_rows * + sizeof(float); + int64 threshold = 1L << 31; + + return total_size < threshold; +} + tensorflow::Status ConvolutionThunk::ConvolveWithTune( const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, const FilterDescriptor& filter_descriptor, @@ -288,21 +314,29 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( const ConvolutionDescriptor& convolution_descriptor, const BufferAllocations& buffer_allocations, se::Stream* stream) { // TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out. - if (best_algorithm_.algorithm().is_default()) { + if (!best_algorithm_.has_value()) { + best_algorithm_.emplace(); + // Auto-tuning either is disabled or only happens in the first run of this // function. 
VLOG(2) << "Profiling for best convolution algorithm used for " "ConvolutionThunk: " << this; + bool with_winograd_nonfused = + ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor); + se::dnn::ProfileResult best_result; se::dnn::ProfileResult best_result_without_scratch; - std::vector algorithms = GetAlgorithms(stream->parent()); + std::vector algorithms = + GetAlgorithms(with_winograd_nonfused, stream->parent()); for (auto algorithm : algorithms) { ConvolveScratchAllocator scratch_allocator( buffer_allocations.device_ordinal(), buffer_allocations.memory_allocator()); se::dnn::ProfileResult profile_result; + VLOG(3) << "Trying algorithm " << AlgorithmToString(algorithm) + << " for ConvolutionThunk: " << this; bool launch_ok = Convolve(input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, convolution_descriptor, @@ -310,6 +344,11 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( &scratch_allocator, &profile_result) .ok(); if (launch_ok && profile_result.is_valid()) { + VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) + << " for ConvolutionThunk " << this << " succeeded, taking " + << profile_result.elapsed_time_in_ms() + << "ms. (Best result: " << best_result.elapsed_time_in_ms() + << "ms)"; if (profile_result.elapsed_time_in_ms() < best_result.elapsed_time_in_ms()) { best_result = profile_result; @@ -319,39 +358,42 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( best_result_without_scratch.elapsed_time_in_ms()) { best_result_without_scratch = profile_result; } + } else { + VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) + << " for ConvolutionThunk " << this << " failed."; } } if (best_result.is_valid()) { - best_algorithm_.set_algorithm(best_result.algorithm()); + best_algorithm_->set_algorithm(best_result.algorithm()); } else { LOG(ERROR) << "No convolution algorithm works with profiling. 
Fall back " "to the default algorithm."; - best_algorithm_.set_algorithm(AlgorithmDesc()); + best_algorithm_->set_algorithm(AlgorithmDesc()); } if (best_result_without_scratch.is_valid()) { - best_algorithm_.set_algorithm_no_scratch( + best_algorithm_->set_algorithm_no_scratch( best_result_without_scratch.algorithm()); } else { LOG(ERROR) << "No convolution algorithm without scratch works with " "profiling. Fall back " "to the default algorithm."; - best_algorithm_.set_algorithm_no_scratch(AlgorithmDesc()); + best_algorithm_->set_algorithm_no_scratch(AlgorithmDesc()); } } { VLOG(2) << "Using convolution algorithm (" - << best_algorithm_.algorithm().algo_id() << ", " - << best_algorithm_.algorithm_no_scratch().algo_id() + << AlgorithmToString(best_algorithm_->algorithm()) << ", " + << AlgorithmToString(best_algorithm_->algorithm_no_scratch()) << ") for ConvolutionThunk: " << this; ConvolveScratchAllocator scratch_allocator( buffer_allocations.device_ordinal(), buffer_allocations.memory_allocator()); return Convolve(input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, - convolution_descriptor, best_algorithm_, stream, + convolution_descriptor, *best_algorithm_, stream, &scratch_allocator, nullptr); } } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 13432301b2af34ab4bd0864e39ce22366cc1d11d..7c25a2e6450e30292667ecd7de54b50ac2450767 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -87,6 +88,14 @@ class ConvolutionThunk : public Thunk { const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) override; + // Returns true if the next run of ExecuteOnStream will do autotuning. If so, + // we want the GPU to be quiescent during autotuning, so as not to introduce + // noise in our results. + bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream*) override { + return !best_algorithm_.has_value(); + } + private: tensorflow::Status ConvolveWithTune( const perftools::gputools::dnn::BatchDescriptor& input_descriptor, @@ -116,13 +125,15 @@ class ConvolutionThunk : public Thunk { // Returns the convolve algorithms that can be used for this ConvolutionThunk. std::vector GetAlgorithms( + bool with_winograd_nonfused, perftools::gputools::StreamExecutor* stream_exec) const; // Fastest cuDNN convolution algorithm for this thunk learned from // auto-tuning. If auto-tuning is disabled or failed, best_algorithm_ is set - // to the default value indicating cuDNN's convolution will choose - // the best algorithm from some heuristics based on its parameters. - perftools::gputools::dnn::AlgorithmConfig best_algorithm_; + // to the default value, indicating cuDNN's convolution will choose the best + // algorithm from some heuristics based on its parameters. 
+ tensorflow::gtl::optional + best_algorithm_; const ConvolutionKind convolution_kind_; diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/copy_insertion.cc deleted file mode 100644 index 3dc85552015be67c20db9099704334c864b44b51..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/copy_insertion.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h" - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/copy_insertion.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { -namespace gpu { - -StatusOr GpuCopyInsertion::Run(HloModule* module) { - TF_ASSIGN_OR_RETURN(bool changed, CopyInsertion::Run(module)); - - TF_ASSIGN_OR_RETURN(auto points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // Make sure all operands of a library call are in memory instead of constants - // in IR. The top-level (index {}) of the points-to set of each operand - // indicates the source(s) of the array buffer. If any of these are constant, - // then add a copy to materialize the array. 
- HloComputation* computation = module->entry_computation(); - for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) { - if (ImplementedAsLibraryCall(*hlo)) { - for (int64 i = 0; i < hlo->operand_count(); ++i) { - HloInstruction* operand = hlo->mutable_operand(i); - const PointsToSet& points_to = - points_to_analysis->GetPointsToSet(operand); - const auto& element = points_to.element(/*index=*/{}); - if (std::any_of(element.begin(), element.end(), - [](const LogicalBuffer* buffer_source) { - return buffer_source->instruction()->opcode() == - HloOpcode::kConstant; - })) { - TF_ASSIGN_OR_RETURN(HloInstruction * copy, - CopyInsertion::FindOrInsertCopy(operand)); - TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy)); - changed = true; - } - } - } - } - - return changed; -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 1b94499bc6ef6d587cdb1fafec48bc4e5b917c51..6bf00cfb8a53723ae9608093480bf2eed10144dd 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -230,6 +230,66 @@ StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( } } +StatusOr GpuElementalIrEmitter::EmitComplexBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const { + PrimitiveType input_type = op->operand(0)->shape().element_type(); + TF_RET_CHECK(primitive_util::IsComplexType(input_type)); + PrimitiveType component_type = + primitive_util::ComplexComponentType(input_type); + switch (op->opcode()) { + case HloOpcode::kPower: { + // (a+bi)^(c+di) = + // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), + // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) + auto a = EmitExtractReal(lhs_value); + auto b = EmitExtractImag(lhs_value); + auto c = EmitExtractReal(rhs_value); + auto d = EmitExtractImag(rhs_value); + auto aa_p_bb = 
ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); + auto half_c = ir_builder_->CreateFMul(one_half, c); + + TF_ASSIGN_OR_RETURN( + auto aa_p_bb_to_half_c, + EmitLibdeviceMathCall("__nv_pow", {aa_p_bb, half_c}, + {component_type, component_type}, + component_type)); + auto neg_d = ir_builder_->CreateFNeg(d); + TF_ASSIGN_OR_RETURN( + auto arg_lhs, EmitLibdeviceMathCall("__nv_atan2", {b, a}, + {component_type, component_type}, + component_type)); + auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); + TF_ASSIGN_OR_RETURN( + auto e_to_neg_d_arg_lhs, + EmitLibdeviceMathCall("__nv_exp", {neg_d_arg_lhs}, {component_type}, + component_type)); + auto coeff = + ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + TF_ASSIGN_OR_RETURN( + auto ln_aa_p_bb, + EmitLibdeviceMathCall("__nv_log", {aa_p_bb}, {component_type}, + component_type)); + auto half_d = ir_builder_->CreateFMul(one_half, d); + auto q = + ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), + ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); + TF_ASSIGN_OR_RETURN( + auto cos_q, EmitLibdeviceMathCall("__nv_cos", {q}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_q, EmitLibdeviceMathCall("__nv_sin", {q}, {component_type}, + component_type)); + return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), + ir_builder_->CreateFMul(coeff, sin_q)); + } + default: + return ElementalIrEmitter::EmitComplexBinaryOp(op, lhs_value, rhs_value); + } +} + StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { PrimitiveType input_type = op->operand(0)->shape().element_type(); @@ -237,18 +297,12 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( primitive_util::IsComplexType(input_type) ? 
primitive_util::ComplexComponentType(input_type) : input_type; - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; switch (op->opcode()) { case HloOpcode::kLog: { // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) - auto a = real(operand_value); - auto b = imag(operand_value); + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), ir_builder_->CreateFMul(b, b)); @@ -261,34 +315,33 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( {component_type, component_type}, component_type)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return ComposeComplex(op, ir_builder_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex( + op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); } - // TODO(b/65408531): Implement kPower on GPU, where atan2 is available. 
- // case HloOpcode::kPower: - // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(0.5(c+di)) case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) - auto b = imag(operand_value); + auto b = EmitExtractImag(operand_value); TF_ASSIGN_OR_RETURN( - auto exp_a, EmitLibdeviceMathCall("__nv_exp", {real(operand_value)}, - {component_type}, component_type)); + auto exp_a, + EmitLibdeviceMathCall("__nv_exp", {EmitExtractReal(operand_value)}, + {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, component_type)); - return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); + return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), + ir_builder_->CreateFMul(exp_a, sin_b)); } case HloOpcode::kCos: { // cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) - auto a = real(operand_value); + auto a = EmitExtractReal(operand_value); auto llvm_ty = a->getType(); TF_ASSIGN_OR_RETURN( - auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)}, - {component_type}, component_type)); + auto exp_b, + EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, + {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, component_type)); @@ -299,7 +352,7 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), @@ -309,11 +362,12 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( case HloOpcode::kSin: { // sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + 
i*cos(a)*(e^b-e^-b) - auto a = real(operand_value); + auto a = EmitExtractReal(operand_value); auto llvm_ty = a->getType(); TF_ASSIGN_OR_RETURN( - auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)}, - {component_type}, component_type)); + auto exp_b, + EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, + {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, component_type)); @@ -324,13 +378,71 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); } + case HloOpcode::kTanh: { + /* + tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x)) + e^(a+bi) = e^a*(cos(b)+sin(b)i) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a)) + cos(b)=cos(-b), sin(-b)=-sin(b) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a)) + =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) / + (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a)) + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) / + (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a)) + This is a complex division, so we can multiply by denom_conj/denom_conj + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) * + (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) + + i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + */ + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + 
TF_ASSIGN_OR_RETURN( + auto exp_a, EmitLibdeviceMathCall("__nv_exp", {a}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, + component_type)); + auto exp_neg_a = ir_builder_->CreateFDiv( + llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(exp_a, exp_a), + ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); + auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); + auto real_num = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); + auto exp_a_plus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); + auto exp_a_minus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = ir_builder_->CreateFMul( + cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, + exp_a_minus_exp_neg_a_sq)); + auto denom = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom), + ir_builder_->CreateFDiv(imag_num, denom)); + } default: return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value); } diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 
3defa1b696d3addc012702e23102bb1fa140170d..6a537d015209bc507af36b13eeb5d69ce58d8fea 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -61,6 +61,10 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const override; + StatusOr EmitComplexBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const override; + StatusOr EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const override; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 983cb872924f22be0dfad8aa9ad86f233b909c46..8c6a1f51a8a09ef78950dfe7e89994a3fe247f49 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -52,6 +52,15 @@ class GemmThunk : public Thunk { const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) override; + // Returns true if we'll perform autotuning if run on the given stream. If + // so, we want the GPU to be quiescent during autotuning, so as not to + // introduce noise in our results. + bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream* stream) override { + return autotune_results_.count( + stream->parent()->GetDeviceDescription().name()) != 0; + } + private: const BufferAllocation::Slice lhs_buffer_; const BufferAllocation::Slice rhs_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index ceb0e530c151219c7fef4dd6bfa36013cb53d63c..1ccfe323c58422c99fab5efa578be2a1e23e3d1b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" #include +#include #include #include @@ -30,17 +31,18 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" -#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" @@ -75,6 +77,7 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/platform/tracing.h" namespace se = ::perftools::gputools; @@ -87,6 +90,7 @@ namespace gpu { namespace { +using tensorflow::port::Tracing; using tensorflow::strings::StrCat; // Any address of a variable residing in global memory or returned by one of the @@ -123,7 +127,7 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { // Runs optimization passes on the given HLO module. tensorflow::Status OptimizeHloModule( - HloModule* hlo_module, const se::DeviceDescription& device_desc, + HloModule* hlo_module, const HloCostAnalysis::ShapeSizeFunction& shape_size_function) { { HloPassPipeline pipeline("optimization"); @@ -134,7 +138,7 @@ tensorflow::Status OptimizeHloModule( // TODO(b/64094172): make Call work on GPU instead of inlining. pipeline.AddPass(); - + pipeline.AddPass(); { auto& pass = pipeline.AddPass>("simplification"); @@ -221,9 +225,8 @@ tensorflow::Status PrepareHloModuleForIrEmitting( // (and sometime after) copy insertion, to avoid dead code from interfering // with the rewrites. pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); return pipeline.Run(hlo_module).status(); } @@ -231,6 +234,7 @@ tensorflow::Status PrepareHloModuleForIrEmitting( // code (i.e. a cubin) as a byte array. 
StatusOr> CompilePtx(const string& ptx, int cc_major, int cc_minor) { + Tracing::TraceMe annotation("Compile PTX", /*is_expensive=*/true); const string ptxas_path = tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); VLOG(2) << "Using ptxas at " << ptxas_path; @@ -255,7 +259,9 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, return InternalError("couldn't get temp CUBIN file name"); } auto cubin_cleaner = tensorflow::gtl::MakeCleanup([&cubin_path] { - TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(cubin_path)); + // CUBIN file may never be created, so the failure to delete it should not + // produce TF error. + tensorflow::Env::Default()->DeleteFile(cubin_path).IgnoreError(); }); tensorflow::SubProcess ptxas_info_dumper; std::vector ptxas_args = {ptxas_path, ptx_path, "-o", cubin_path, @@ -289,15 +295,24 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, } // namespace GpuCompiler::GpuCompiler() - : pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {} + : pointer_size_(llvm::DataLayout(kDataLayout) + .getPointerSize(0 /* default address space */)) {} + +StatusOr> GpuCompiler::RunHloPasses( + std::unique_ptr module, se::StreamExecutor* /*stream_exec*/) { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); + Tracing::TraceMe annotation("HLO Transforms", module->name(), + /*is_expensive=*/true); + TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), ShapeSizeBytesFunction())); + return std::move(module); +} -StatusOr> GpuCompiler::Compile( +StatusOr> GpuCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec) { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend"); + TF_RET_CHECK(stream_exec != nullptr); - TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), - stream_exec->GetDeviceDescription(), - ShapeSizeBytesFunction())); TF_RETURN_IF_ERROR( PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction())); @@ -352,8 +367,11 @@ StatusOr> GpuCompiler::Compile( HloComputation* entry_computation 
= module->entry_computation(); IrEmitterUnnested ir_emitter(module->config(), entry_computation, &ir_emitter_context); - TF_RETURN_IF_ERROR( - entry_computation->root_instruction()->Accept(&ir_emitter)); + { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission"); + TF_RETURN_IF_ERROR( + entry_computation->root_instruction()->Accept(&ir_emitter)); + } if (user_pre_optimization_hook_) { TF_CHECK_OK(user_pre_optimization_hook_(llvm_module)); @@ -402,9 +420,12 @@ StatusOr> GpuCompiler::Compile( cc_minor = 0; } - TF_ASSIGN_OR_RETURN(string ptx, - CompileToPtx(&llvm_module, {cc_major, cc_minor}, - module->config(), libdevice_dir)); + string ptx; + { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - CompileToPtx"); + TF_ASSIGN_OR_RETURN(ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, + module->config(), libdevice_dir)); + } if (!ir_dump_directory.empty()) { TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory( @@ -421,6 +442,22 @@ StatusOr> GpuCompiler::Compile( VLOG(2) << "PTX:"; XLA_VLOG_LINES(2, ptx); + // Write PTX to IR dump directory, if IR dumping was requested. 
+ if (!ir_dump_directory.empty()) { + const string ptx_outfile = tensorflow::io::JoinPath( + ir_dump_directory, StrCat(module->name(), ".ptx")); + auto status = [&] { + auto* env = tensorflow::Env::Default(); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory)); + TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, ptx_outfile, ptx)); + return Status::OK(); + }(); + if (!status.ok()) { + LOG(WARNING) << "Couldn't dump PTX for module " << module->name() + << " to " << ptx_outfile << ": " << status; + } + } + const std::vector cubin = CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); @@ -430,10 +467,20 @@ StatusOr> GpuCompiler::Compile( VLOG(2) << "Printing the thunk schedule..."; XLA_VLOG_LINES(2, thunk_schedule->ToString()); - auto* gpu_executable = - new GpuExecutable(ptx, cubin, {cc_major, cc_minor}, - std::move(thunk_schedule), std::move(module), - std::move(buffer_assignment), ShapeSizeBytesFunction()); + std::unique_ptr profile_index_map; + std::unique_ptr profile_printer; + + if (module->config().hlo_profiling_enabled()) { + HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); + profile_index_map = MakeUnique(*module); + profile_printer = + CreateHloProfilePrinter(*profile_index_map, cost_analysis); + } + + auto* gpu_executable = new GpuExecutable( + ptx, cubin, {cc_major, cc_minor}, std::move(thunk_schedule), + std::move(module), std::move(buffer_assignment), + std::move(profile_printer), std::move(profile_index_map)); if (embed_ir_in_executable) { DCHECK_NE("", ir_module_string_before_opt); gpu_executable->set_ir_module_string(ir_module_string_before_opt); @@ -444,6 +491,8 @@ StatusOr> GpuCompiler::Compile( std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, int cc_major, int cc_minor) { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::CompilePtxOrGetCachedResult"); + Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true); bool inserted; decltype(compilation_cache_.begin()) iter; // Pointers into 
compilation_cache_ where the ptx and (optional) cubin are @@ -476,10 +525,24 @@ std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, VLOG(2) << "Compiled PTX size:" << ptx.size() << " CUBIN size: " << cache_value->cubin_data.size(); } else { - LOG(WARNING) - << "Failed to compile ptx to cubin. Will attempt to let " - "GPU driver compile the ptx. " - << maybe_cubin.status(); + bool log_warning = true; + if (maybe_cubin.status().code() == + tensorflow::error::Code::NOT_FOUND) { + // Missing ptxas is expected in some environments where CUDA SDK + // binaries are not available. We don't want to spam logs with + // identical warnings in this case. + + // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N + // for more general usage. + static std::atomic warning_done(false); + log_warning = !warning_done.exchange(true); + } + if (log_warning) { + LOG(WARNING) + << "Failed to compile ptx to cubin. Will attempt to let " + "GPU driver compile the ptx. " + << maybe_cubin.status(); + } } } cache_value->compilation_done = true; @@ -496,13 +559,6 @@ std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, return cache_value->cubin_data; } -StatusOr>> GpuCompiler::Compile( - std::vector> modules, - std::vector> stream_execs) { - return Unimplemented( - "Compilation of multiple HLO modules is not yet supported on GPU."); -} - StatusOr>> GpuCompiler::CompileAheadOfTime(std::vector> module, const AotCompilationOptions& options) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index ee67e65caf2434fc74503d07c6fccb98de70d96c..18e34340205b6f51497e26c45520799d21c55a46 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -42,14 +42,20 @@ class GpuCompiler : public LLVMCompiler { GpuCompiler(); ~GpuCompiler() override {} - StatusOr> Compile( + // Bring in + // StatusOr>> Compile( + // std::vector> 
modules, + // std::vector> + // stream_execs) + using LLVMCompiler::Compile; + + StatusOr> RunHloPasses( std::unique_ptr module, perftools::gputools::StreamExecutor* stream_exec) override; - StatusOr>> Compile( - std::vector> modules, - std::vector> - stream_execs) override; + StatusOr> RunBackend( + std::unique_ptr module, + perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> CompileAheadOfTime(std::vector> module, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc new file mode 100644 index 0000000000000000000000000000000000000000..33d739b79d3664fec3586bbc924b7fa2e10d3256 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace gpu { + +StatusOr GpuCopyInsertion::FindOrInsertCopy( + HloInstruction* hlo) { + HloInstruction*& copy = inserted_copies_[hlo]; + if (copy == nullptr) { + TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo)); + } + return copy; +} + +StatusOr GpuCopyInsertion::Run(HloModule* module) { + CopyInsertion generic_copy_insertion; + + TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, + HloDataflowAnalysis::Run(module)); + + // Make sure all operands of a library call are in memory instead of constants + // in IR. 
+ for (HloInstruction* hlo : + module->entry_computation()->MakeInstructionPostOrder()) { + if (ImplementedAsLibraryCall(*hlo)) { + for (int64 i = 0; i < hlo->operand_count(); ++i) { + HloInstruction* operand = hlo->mutable_operand(i); + TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); + const auto& values = dataflow->GetValueSet(operand).values(); + if (std::any_of(values.begin(), values.end(), + [](const HloValue* value) { + return value->defining_instruction()->opcode() == + HloOpcode::kConstant; + })) { + TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand)); + TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy)); + changed = true; + } + } + } + } + + // Init values of a while node cannot be constants. Insert copies for any + // constants found at the operand of a while. + tensorflow::gtl::FlatSet copied_constants; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kWhile) { + continue; + } + for (auto& pair : + dataflow->GetInstructionValueSet(instruction->operand(0))) { + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + if (value->defining_instruction()->opcode() == + HloOpcode::kConstant && + !ContainsKey(copied_constants, value->defining_instruction())) { + HloInstruction* constant = value->defining_instruction(); + TF_ASSIGN_OR_RETURN(HloInstruction * copy, + FindOrInsertCopy(constant)); + TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy)); + copied_constants.insert(constant); + changed = true; + } + } + } + } + } + + // The GPU backend needs additional copies added due to deficiencies in + // buffer assignment. 
+ TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed, + CopyInsertion::AddCopiesForBufferAssignment(module)); + + return changed || buffer_assignment_changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h similarity index 56% rename from tensorflow/compiler/xla/service/gpu/copy_insertion.h rename to tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 11077dad2e5506eab4fa84d47ad13a26ed1c035a..4d77f337e6eb20f7d79acc0829fde26bbe443f25 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_ -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { namespace gpu { @@ -25,12 +25,23 @@ namespace gpu { // Besides the modifications made by the generic xla::CopyInsertion, this // GPU-specific copy insertion also materializes operands of library calls by // inserting kCopy instructions. -class GpuCopyInsertion : public CopyInsertion { +class GpuCopyInsertion : public HloPassInterface { public: + tensorflow::StringPiece name() const override { return "copy-insertion"; } + StatusOr Run(HloModule* module) override; + + protected: + // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making + // duplicate copies. 
+ StatusOr FindOrInsertCopy(HloInstruction* hlo); + + // A map containing all copies inserted to materialize operands of library + // calls. The key is the copied instruction and the value is the copy. + tensorflow::gtl::FlatMap inserted_copies_; }; } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index c6f23f9b0506186c4f76a887e6a540dafdd79962..b802ae9c7aba4e94bb37b0e4c6a2ba157f9be7d5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -69,7 +69,7 @@ class HloExecutionProfiler { ~HloExecutionProfiler() { if (do_profile_) { stream_->ThenStopTimer(execution_timer_.get()); - stream_->BlockHostUntilDone(); + stream_->BlockHostUntilDone().IgnoreError(); profile_->set_total_cycles_executed( *computation_, execution_timer_->Nanoseconds() * clock_rate_ghz_); } @@ -87,7 +87,7 @@ class HloExecutionProfiler { void FinishOperation(const HloInstruction* hlo_instruction) { if (do_profile_) { stream_->ThenStopTimer(per_op_timer_.get()); - stream_->BlockHostUntilDone(); + stream_->BlockHostUntilDone().IgnoreError(); profile_->SetCyclesTakenBy( hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_); } @@ -113,14 +113,15 @@ GpuExecutable::GpuExecutable( std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, std::unique_ptr assignment, - HloCostAnalysis::ShapeSizeFunction shape_size_function) - : Executable(std::move(hlo_module)), + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map) + : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + std::move(hlo_profile_index_map)), ptx_(ptx), cubin_(cubin), compute_capability_(compute_capability), thunk_schedule_(std::move(thunk_schedule)), - 
assignment_(std::move(assignment)), - shape_size_function_(std::move(shape_size_function)) {} + assignment_(std::move(assignment)) {} Status GpuExecutable::ExecuteThunks( const ServiceExecutableRunOptions* run_options, @@ -166,9 +167,16 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } + // If this thunk requests it, wait for all currently-executing thunks to + // finish. This is useful e.g. if the thunk is about to perform autotuning. + if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { + TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); + } + profiler.StartOperation(); VLOG(2) << "Executing the thunk for " - << thunk->hlo_instruction()->ToString(); + << thunk->hlo_instruction()->ToString() << " on stream " + << stream_no; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); if (thunk_schedule_->Depended(thunk)) { auto finish_event = MakeUnique(main_stream->parent()); @@ -183,9 +191,13 @@ Status GpuExecutable::ExecuteThunks( // Make sure kernels are completed before deallocating temporary buffers. // TODO(b/30100571): we could potentially postpone deallocating the temp // buffers until a different computation is executed. 
- if (block_host_until_done && !main_stream->BlockHostUntilDone()) { - return InternalError("Failed to complete all kernels launched on stream %p", - main_stream); + if (block_host_until_done) { + Status block_status = main_stream->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError( + "Failed to complete all kernels launched on stream %p: %s", + main_stream, block_status.error_message().c_str()); + } } return Status::OK(); @@ -358,9 +370,5 @@ const PointsToSet& GpuExecutable::GetRootPointsToSet() const { module().entry_computation()->root_instruction()); } -std::unique_ptr GpuExecutable::CreateCostAnalysis() const { - return MakeUnique(shape_size_function_); -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index a3815370c19af1da612bc6d9663cc0f8896062f7..e7307e07c0b5608e31f15597d31d11c50f81c6d5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -54,7 +54,8 @@ class GpuExecutable : public Executable { std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, std::unique_ptr assignment, - HloCostAnalysis::ShapeSizeFunction shape_size_function); + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map); // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -95,8 +96,6 @@ class GpuExecutable : public Executable { return Unimplemented("Equality test on GPU executable is not implemented."); } - std::unique_ptr CreateCostAnalysis() const override; - private: // If `block_host_until_done` is false, execution will not block the host // until the kernels have completed. This is used as an optimization for @@ -140,9 +139,6 @@ class GpuExecutable : public Executable { // memory for every output/temp buffers. 
const std::unique_ptr assignment_; - // Function to compute the size of a given Shape, in bytes. - const HloCostAnalysis::ShapeSizeFunction shape_size_function_; - TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc similarity index 93% rename from tensorflow/compiler/xla/service/gpu/layout_assignment.cc rename to tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 0bbd63fb7bfc657cb7bb1de673253c198f5bd25f..50a249f448e7b4956e7bf6bd603d256eca88f71d 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include @@ -80,9 +80,9 @@ Status GpuLayoutAssignment::AddBackendConstraints( const ConvolutionDimensionNumbers& dimension_numbers = instruction->convolution_dimension_numbers(); std::vector input_layout; - for (int i = dimension_numbers.spatial_dimensions_size() - 1; i >= 0; - --i) { - input_layout.push_back(dimension_numbers.spatial_dimensions(i)); + for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; + i >= 0; --i) { + input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); } input_layout.push_back(dimension_numbers.input_feature_dimension()); input_layout.push_back(dimension_numbers.input_batch_dimension()); @@ -102,9 +102,9 @@ Status GpuLayoutAssignment::AddBackendConstraints( *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); std::vector output_layout; - for (int i = dimension_numbers.spatial_dimensions_size() - 1; i >= 0; - --i) { - 
output_layout.push_back(dimension_numbers.spatial_dimensions(i)); + for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; + i >= 0; --i) { + output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); } output_layout.push_back(dimension_numbers.output_feature_dimension()); output_layout.push_back(dimension_numbers.output_batch_dimension()); diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h similarity index 86% rename from tensorflow/compiler/xla/service/gpu/layout_assignment.h rename to tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 169041eb85c633cb4f1f679bcea127714828308f..7655a3ebf45f83c0125a4257baae7a7229ebdc6d 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -38,4 +38,4 @@ class GpuLayoutAssignment : public LayoutAssignment { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc similarity index 97% rename from tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc rename to tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index ac206b89d329d7e4ac91ee51162c9694f6899d78..f68b23c8ce969372a01ce77840e016d82ca5d2ed 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index f0f036f7f381db15b84db85d3efeec5d8141884e..ae92daef8882de2e7d64b69f68452061cb5507f2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -44,7 +44,7 @@ GpuTransferManager::GpuTransferManager() : GenericTransferManager( se::cuda::kCudaPlatformId, /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) - .getPointerSize()) {} + .getPointerSize(0 /* default address space */)) {} Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) { @@ -105,12 +105,13 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( // infeed requests, blocking on the stream might be // heavy-handed. Figure out if finer-grained acknowledgement is // possible. 
- if (!stream->BlockHostUntilDone()) { + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { for (gpu::InfeedBuffer* b : buffers) { b->Done(); } - return InternalError("Failed to complete data transfer on stream %p", - stream); + return InternalError("Failed to complete data transfer on stream %p: %s", + stream, block_status.error_message().c_str()); } infeed_manager->EnqueueBuffers(buffers); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 163a161353fdb90cee2968269d572b8414855551..c2115c49993ef71c4b6dd584e7e0498807666613 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -166,11 +166,46 @@ void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo, *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value; } +// Determines whether hlo's buffers are never modified within the execution of +// consumer. +static bool BuffersInvariantWithinConsumer( + const HloInstruction& hlo, const HloInstruction& consumer, + const BufferAssignment* buffer_assignment) { + // Check if consumer is inside a fusion node -- if so, "dereference" it until + // we get to a non-fusion node. + const HloInstruction* c = &consumer; + while (c->IsFused()) { + c = c->parent()->FusionInstruction(); + } + + // If, after dereferencing c, we end up with a node that's not inside our + // module's top-level computation (say our node is inside a while loop), we + // give up on marking array as invariant, because this HLO may be run multiple + // times (e.g. multiple while loop iterations, or multiple invocations of a + // reducer's computation). TODO(jlebar): We could relax this constraint if we + // emitted an llvm.invariant.group.barrier at the end of the computation. 
+ return c->parent() == c->GetModule()->entry_computation() && + buffer_assignment->HaveDisjointSlices(&hlo, &consumer); +} + llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, + const HloInstruction& consumer, const ShapeIndex& shape_index) { llvm_ir::IrArray ir_array(GetBasePointer(hlo, shape_index), ShapeUtil::GetSubshape(hlo.shape(), shape_index)); alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); + + // The GPU backend emits one kernel per top-level HLO, and LLVM views + // execution of one kernel as the "whole program" executed on the GPU. + // Therefore if hlo's output buffer is not modified within consumer, and if + // consumer runs hlo only once (so that it doesn't create two different + // outputs), then we can mark ir_array as invariant over the whole program. + if (BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) { + VLOG(2) << "Marking " << hlo.name() << " as invariant within " + << consumer.name(); + ir_array.MarkInvariantOverWholeProgram(&module_->getContext()); + } + return ir_array; } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index a3120f15bcbfb0f2f0bfbd806e7a4ff05316d5dd..62ae1769a1f2fb3b9acaf35bdf18a793232500b0 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -76,8 +76,15 @@ class HloToIrBindings { return it->second.element(shape_index); } - // Return the underlying IrArray of the output of the given instruction. + // Returns the IrArray which contains the output of hlo. + // + // consumer is the HLO in which this IrArray is used -- we use this to (try + // to) add metadata indicating that the array is invariant within consumer. + // + // To get the buffer into which hlo should write its own output, call + // GetIrArray(hlo, hlo). 
llvm_ir::IrArray GetIrArray(const HloInstruction& hlo, + const HloInstruction& consumer, const ShapeIndex& shape_index = {}); private: diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index e33e904692ca5ad41e17d2e165dbb40b6bd4aa33..2ac95ceb692447c7ac6dbbcd8b9a38876f7a77b6 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -30,9 +30,8 @@ InfeedThunk::InfeedThunk( tuple_element_buffers.end()), destination_buffer_(destination_buffer) {} -tensorflow::Status InfeedThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { +Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { VLOG(2) << "Infeeding to GPU "; perftools::gputools::DeviceMemoryBase destination_address = @@ -66,15 +65,16 @@ tensorflow::Status InfeedThunk::ExecuteOnStream( buffer->length()); } - if (!stream->BlockHostUntilDone()) { - return InternalError("Failed to complete data transfer on stream %p", - stream); + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError("Failed to complete data transfer on stream %p: %s", + stream, block_status.error_message().c_str()); } infeed_manager->ReleaseBuffers(infeed_buffers); VLOG(2) << "Infeeding to GPU complete"; - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 371d71f9dbdd21cb5f36cc3108c8f398a4a91c29..86918705fa0305217f11753e383200c7bd71474b 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -43,9 +43,8 @@ class InfeedThunk : public Thunk { InfeedThunk(const InfeedThunk&) = delete; InfeedThunk& operator=(const InfeedThunk&) = delete; 
- tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 9a4bfd0905bb62c02c70e7f2eea46872c07bca89..1d47ffde4331868cbc8a8afb2d01b11e77a7fab0 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -156,8 +156,10 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) { conv_dnums.set_output_batch_dimension(0); conv_dnums.set_input_feature_dimension(1); conv_dnums.set_output_feature_dimension(1); - conv_dnums.add_spatial_dimensions(2); - conv_dnums.add_spatial_dimensions(3); + conv_dnums.add_input_spatial_dimensions(2); + conv_dnums.add_output_spatial_dimensions(2); + conv_dnums.add_input_spatial_dimensions(3); + conv_dnums.add_output_spatial_dimensions(3); conv_dnums.set_kernel_output_feature_dimension(0); conv_dnums.set_kernel_input_feature_dimension(1); conv_dnums.add_kernel_spatial_dimensions(2); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 8fb7a6adda9dc7c36eb9aabcbcdc9d77e6c22c4a..658fd05cd4b63c923d21b4a1de16468c0aeec65d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -100,7 +100,7 @@ bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { if (hlo.opcode() == HloOpcode::kConvolution) { const ConvolutionDimensionNumbers& dnums = hlo.convolution_dimension_numbers(); - if (dnums.spatial_dimensions_size() > 3) { + if (dnums.input_spatial_dimensions_size() > 3) { return false; } diff 
--git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 57a3f713e35b506ad9d5caab1ced2c7b74f8efcf..f64e93024fe134e585411f555810711763f6fcb5 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -68,7 +68,8 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_); + return GetIrArray(*operand, *hlo) + .EmitReadArrayElement(index, &ir_builder_); }; } return EmitTargetElementLoop( @@ -128,16 +129,25 @@ Status IrEmitter::HandleSend(HloInstruction*) { return Unimplemented("Send is not implemented on GPU"); } +Status IrEmitter::HandleSendDone(HloInstruction*) { + return Unimplemented("Send-Done is not implemented on GPU"); +} + Status IrEmitter::HandleRecv(HloInstruction*) { return Unimplemented("Recv is not implemented on GPU"); } +Status IrEmitter::HandleRecvDone(HloInstruction*) { + return Unimplemented("Recv-done is not implemented on GPU"); +} + Status IrEmitter::HandleTuple(HloInstruction* tuple) { std::vector base_ptrs; for (const HloInstruction* operand : tuple->operands()) { base_ptrs.push_back(GetBasePointer(*operand)); } - llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_, module_); + llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &ir_builder_, + module_); return Status::OK(); } @@ -163,7 +173,7 @@ Status IrEmitter::EmitCallToNestedComputation( return Status::OK(); } -bool IrEmitter::MaybeEmitSpecialAtomicOperation( +bool IrEmitter::MaybeEmitDirectAtomicOperation( const HloComputation& computation, llvm::Value* output_address, llvm::Value* source_address) { CHECK_EQ(2, computation.num_parameters()); @@ -223,101 +233,189 @@ bool 
IrEmitter::MaybeEmitSpecialAtomicOperation( return false; } -Status IrEmitter::EmitAtomicOperationForNestedComputation( - const HloComputation& computation, llvm::Value* output_address, - llvm::Value* source_address) { - if (computation.num_parameters() != 2) { - // TODO(b/30258929): We only accept binary computations so far. - return Unimplemented( - "We only support atomic functions with exactly two parameters, but " - "computation %s has %lld.", - computation.name().c_str(), computation.num_parameters()); - } - - if (MaybeEmitSpecialAtomicOperation(computation, output_address, - source_address)) { - return Status::OK(); - } +// Implements atomic binary operations using atomic compare-and-swap +// (atomicCAS) as follows: +// 1. Reads the value from the memory pointed to by output_address and +// records it as old_output. +// 2. Uses old_output as one of the source operand to perform the binary +// operation and stores the result in new_output. +// 3. Calls atomicCAS which implements compare-and-swap as an atomic +// operation. In particular, atomicCAS reads the value from the memory +// pointed to by output_address, and compares the value with old_output. If +// the two values equal, new_output is written to the same memory location +// and true is returned to indicate that the atomic operation succeeds. +// Otherwise, the new value read from the memory is returned. In this case, +// the new value is copied to old_output, and steps 2. and 3. are repeated +// until atomicCAS succeeds. +// +// On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers. If +// the element type of the binary operation is 32 bits or 64 bits, the integer +// type of the same size is used for the atomicCAS operation. On the other hand, +// if the element type is smaller than 32 bits, int32 is used for the atomicCAS +// operation. 
In this case, atomicCAS reads and writes 32 bit values from +// the memory, which is larger than the memory size required by the original +// atomic binary operation. We mask off the last two bits of the output_address +// and use the result as an address to read the 32 bit values from the memory. +// This can avoid out of bound memory accesses if tensor buffers are 4 byte +// aligned and have a size of 4N, an assumption that the runtime can guarantee. +// +// The pseudo code is shown below. Variables *_address are pointers to a memory +// region with a size equal to the size of the atomicCAS operation, with the +// exception that new_output_address is a pointer to a memory region with a size +// equal to the element size of the binary operation. +// +// element_size = sizeof(element_type); +// atomic_size = max(32, element_size); +// cas_new_output_address = alloca(atomic_size); +// cas_old_output_address = alloca(atomic_size); +// if (atomic_size != element_size) { +// atomic_address = output_address & ((int64)(-2)); +// new_output_address = cas_new_output_address + (output_address & 3); +// } else { +// atomic_address = output_address; +// new_output_address = cas_new_output_address; +// } +// +// *cas_old_output_address = *atomic_address; +// do { +// *cas_new_output_address = *cas_old_output_address; +// *new_output_address = operation(*new_output_address, *source_address); +// (*cas_old_output_address, success) = +// atomicCAS(atomic_address, *cas_old_output_address, +// *cas_new_output_address); +// } while (!success); +// +Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address) { + llvm::PointerType* output_address_type = + llvm::dyn_cast(output_address->getType()); + CHECK_NE(output_address_type, nullptr); + + // element_type is the data type for the binary operation. 
+ llvm::Type* element_type = output_address_type->getPointerElementType(); + int element_size = llvm_ir::GetSizeInBits(element_type); + llvm::Type* element_address_type = element_type->getPointerTo(); + + int atomic_size = (element_size < 32) ? 32 : element_size; + llvm::Type* atomic_type = ir_builder_.getIntNTy(atomic_size); + llvm::Type* atomic_address_type = + atomic_type->getPointerTo(output_address_type->getPointerAddressSpace()); + + // cas_old_output_address and cas_new_output_address point to the scratch + // memory where we store the old and new values for the repeated atomicCAS + // operations. + llvm::Value* cas_old_output_address = ir_builder_.CreateAlloca( + atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); + llvm::Value* cas_new_output_address = ir_builder_.CreateAlloca( + atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); - // Other binary computations can be made atomic as following (labels are basic - // block names used in the IR emitting code later). - // - // atomic_op_loop_preheader: - // ... - // source = *source_address; - // old_output = *output_address; - // do { - // atomic_op_loop_body_entry: - // new_output = computation(old_output, source); - // (old_output, success) = - // atomicCAS(output_address, old_output, new_output); - // } while (!success); - // - // atomic_op_loop_exit: - // ... - // - // TODO(jingyue): Consider encapsulate the logic of emitting control flow to - // something similar to llvm_ir::ForLoop. - // // Emit preparation code to the preheader. 
llvm::BasicBlock* loop_preheader_bb = ir_builder_.GetInsertBlock(); - llvm::Type* element_ir_type = - output_address->getType()->getPointerElementType(); - // old_output = *output_address; - llvm::Value* old_output_location = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "old_output_location"); - ir_builder_.CreateStore(ir_builder_.CreateLoad(output_address, "old_output"), - old_output_location); + + llvm::Value* atomic_memory_address; + // binop_output_address points to the scratch memory that stores the + // result of the binary operation. + llvm::Value* binop_output_address; + if (element_size < 32) { + // Assume the element size is an integer number of bytes. + CHECK_EQ((element_size % sizeof(char)), 0); + llvm::Type* address_int_type = + module_->getDataLayout().getIntPtrType(output_address_type); + atomic_memory_address = + ir_builder_.CreatePtrToInt(output_address, address_int_type); + llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3); + llvm::Value* offset = ir_builder_.CreateAnd(atomic_memory_address, mask); + mask = llvm::ConstantInt::get(address_int_type, -2); + atomic_memory_address = ir_builder_.CreateAnd(atomic_memory_address, mask); + atomic_memory_address = + ir_builder_.CreateIntToPtr(atomic_memory_address, atomic_address_type); + binop_output_address = ir_builder_.CreateAdd( + ir_builder_.CreatePtrToInt(cas_new_output_address, address_int_type), + offset); + binop_output_address = + ir_builder_.CreateIntToPtr(binop_output_address, element_address_type); + } else { + atomic_memory_address = + ir_builder_.CreateBitCast(output_address, atomic_address_type); + binop_output_address = + ir_builder_.CreateBitCast(cas_new_output_address, element_address_type); + } + + // Use the value from the memory that atomicCAS operates on to initialize + // cas_old_output. 
+ llvm::Value* cas_old_output = + ir_builder_.CreateLoad(atomic_memory_address, "cas_old_output"); + ir_builder_.CreateStore(cas_old_output, cas_old_output_address); + llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock( ir_builder_.GetInsertPoint(), "atomic_op_loop_exit"); - - // Emit the body of the loop that repeatedly invokes atomicCAS. llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(ir_builder_.getContext(), "atomic_op_loop_body", ir_builder_.GetInsertBlock()->getParent()); ir_builder_.SetInsertPoint(loop_body_bb); // Change preheader's successor from loop_exit_bb to loop_body_bb. loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb); - // new_output = computation(old_output, source); - llvm::Value* new_output_location = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "new_output_location"); + + // Emit the body of the loop that repeatedly invokes atomicCAS. + // + // Use cas_old_output to initialize cas_new_output. + cas_old_output = + ir_builder_.CreateLoad(cas_old_output_address, "cas_old_output"); + ir_builder_.CreateStore(cas_old_output, cas_new_output_address); + // Emits code to calculate new_output = operation(old_output, source); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - computation, {old_output_location, source_address}, new_output_location)); - - // (old_output, success) = atomicCAS(output_address, old_output, new_output); - llvm::Type* element_int_ir_type = - ir_builder_.getIntNTy(element_ir_type->getScalarSizeInBits()); - // cmpxchg accetps integer only, so we bitcast the operands (old_output and - // new_output) to integers of the same bit width, and bitcast the result - // back to the original element type. 
- llvm::Value* old_output = - ir_builder_.CreateLoad(old_output_location, "old_output"); - llvm::Value* new_output = - ir_builder_.CreateLoad(new_output_location, "new_output"); + computation, {binop_output_address, source_address}, + binop_output_address)); + + llvm::Value* cas_new_output = + ir_builder_.CreateLoad(cas_new_output_address, "cas_new_output"); + + // Emit code to perform the atomicCAS operation + // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output, + // cas_new_output); llvm::Value* ret_value = ir_builder_.CreateAtomicCmpXchg( - ir_builder_.CreateBitCast(output_address, - element_int_ir_type->getPointerTo()), - ir_builder_.CreateBitCast(old_output, element_int_ir_type), - ir_builder_.CreateBitCast(new_output, element_int_ir_type), + atomic_memory_address, cas_old_output, cas_new_output, llvm::AtomicOrdering::SequentiallyConsistent, llvm::AtomicOrdering::SequentiallyConsistent); - // cmpxchg returns a pair. The first element is the original value at - // output_address and the second element is whether the swap is successful. + + // Extract the memory value returned from atomicCAS and store it as + // cas_old_output. ir_builder_.CreateStore( - ir_builder_.CreateBitCast( - ir_builder_.CreateExtractValue(ret_value, 0, "old_output"), - element_ir_type), - old_output_location); + ir_builder_.CreateExtractValue(ret_value, 0, "cas_old_output"), + cas_old_output_address); + // Extract the success bit returned from atomicCAS and generate a + // conditional branch on the success bit. ir_builder_.CreateCondBr( ir_builder_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb); - // Restore the insertion point to the exit basic block so that the caller of + // Set the insertion point to the exit basic block so that the caller of // this method can continue emitting code to the right place. 
SetToFirstInsertPoint(loop_exit_bb, &ir_builder_); return Status::OK(); } +Status IrEmitter::EmitAtomicOperationForNestedComputation( + const HloComputation& computation, llvm::Value* output_address, + llvm::Value* source_address) { + if (computation.num_parameters() != 2) { + // TODO(b/30258929): We only accept binary computations so far. + return Unimplemented( + "We only support atomic functions with exactly two parameters, but " + "computation %s has %lld.", + computation.name().c_str(), computation.num_parameters()); + } + + if (MaybeEmitDirectAtomicOperation(computation, output_address, + source_address)) { + return Status::OK(); + } + + return EmitAtomicOperationUsingCAS(computation, output_address, + source_address); +} + Status IrEmitter::HandleSelect(HloInstruction* select) { auto pred = select->operand(0); auto on_true = select->operand(1); @@ -325,7 +423,8 @@ Status IrEmitter::HandleSelect(HloInstruction* select) { TF_RET_CHECK(pred->shape().element_type() == PRED); if (ShapeUtil::IsTuple(select->shape())) { - llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred), + llvm_ir::EmitTupleSelect(GetIrArray(*select, *select), + GetIrArray(*pred, *select), GetBasePointer(*on_true), GetBasePointer(*on_false), &ir_builder_, module_); return Status::OK(); @@ -340,9 +439,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select) { Status IrEmitter::HandleDot(HloInstruction* dot) { auto lhs_instruction = dot->operand(0); auto rhs_instruction = dot->operand(1); - const llvm_ir::IrArray& target_array = GetIrArray(*dot); - const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction); - const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction); + const llvm_ir::IrArray& target_array = GetIrArray(*dot, *dot); + const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction, *dot); + const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction, *dot); const Shape& lhs_shape = lhs_instruction->shape(); const Shape& rhs_shape = 
rhs_instruction->shape(); @@ -562,7 +661,8 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { // Apply the reduction function to the loaded value. llvm::Value* input_address = - GetIrArray(*arg).EmitArrayElementAddress(input_index, &ir_builder_); + GetIrArray(*arg, *reduce) + .EmitArrayElementAddress(input_index, &ir_builder_); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *function, {accumulator_addr, input_address}, accumulator_addr)); @@ -578,7 +678,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand)); + parameter_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &ir_builder_, GetNestedComputer()); @@ -613,7 +713,8 @@ Status IrEmitter::HandleRng(HloInstruction* random) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : random->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_); + return GetIrArray(*operand, *random) + .EmitReadArrayElement(index, &ir_builder_); }; } // Emits a single-threaded loop because the loop body generated by the element @@ -622,10 +723,41 @@ Status IrEmitter::HandleRng(HloInstruction* random) { GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_, GetNestedComputer()) .MakeElementGenerator(random, operand_to_generator), - GetIrArray(*random), &ir_builder_) + GetIrArray(*random, *random), &ir_builder_) .EmitLoop(IrName(random)); } +Status IrEmitter::HandleConditional(HloInstruction* conditional) { + auto pred = conditional->operand(0); + auto true_arg = conditional->operand(1); + auto false_arg = conditional->operand(2); + + llvm::Value* conditional_result = GetBasePointer(*conditional); + + llvm::LoadInst* pred_value = ir_builder_.CreateLoad( + 
GetBasePointer(*pred), + llvm_ir::AsStringRef(IrName(conditional, "load_predicate_value"))); + llvm::Value* pred_cond = ir_builder_.CreateICmpNE( + pred_value, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), + llvm_ir::AsStringRef(IrName(conditional, "boolean_predicate"))); + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + pred_cond, IrName(conditional, "if_then_else"), &ir_builder_); + + SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *conditional->true_computation(), {GetBasePointer(*true_arg)}, + conditional_result)); + + SetToFirstInsertPoint(if_data.false_block, &ir_builder_); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *conditional->false_computation(), {GetBasePointer(*false_arg)}, + conditional_result)); + + SetToFirstInsertPoint(if_data.after_block, &ir_builder_); + return Status::OK(); +} + llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 263992d92544166c0d08a6c60b43e78f10f06aed..08bbbe36c72872ba68104c8f328c2f602eb30fa8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -84,7 +84,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleSort(HloInstruction* sort) override; Status HandleSend(HloInstruction* send) override; + Status HandleSendDone(HloInstruction* send_done) override; Status HandleRecv(HloInstruction* recv) override; + Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleParameter(HloInstruction* parameter) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleTuple(HloInstruction* tuple) override; 
@@ -93,6 +95,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleRng(HloInstruction* random) override; + Status HandleConditional(HloInstruction* conditional) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } @@ -103,10 +106,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { explicit IrEmitter(const HloModuleConfig& hlo_module_config, IrEmitterContext* ir_emitter_context, bool is_nested); - // A convenient helper for calling HloToIrBindings::GetIrArray. + // Helper for calling HloToIrBindings::GetIrArray. + // + // Gets the IrArray which contains inst. This array has metadata that makes + // it valid only within the IR that implements consumer. If you are + // implementing an HLO and want to get its own output buffer, call + // GetIrArray(hlo, hlo). llvm_ir::IrArray GetIrArray(const HloInstruction& inst, + const HloInstruction& consumer, const ShapeIndex& shape_index = {}) { - return bindings_.GetIrArray(inst, shape_index); + return bindings_.GetIrArray(inst, consumer, shape_index); } // A convenient helper for calling HloToIrBindings::GetBasePointer. llvm::Value* GetBasePointer(const HloInstruction& inst) const { @@ -177,9 +186,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { // be simply implemented using an LLVM atomic instruction. If "computation" is // one of this kind, emits code to do that and returns true; otherwise, // returns false. - bool MaybeEmitSpecialAtomicOperation(const HloComputation& computation, - llvm::Value* output_address, - llvm::Value* source_address); + bool MaybeEmitDirectAtomicOperation(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address); + + // A helper method for EmitAtomicOperationForNestedComputation. 
It implements + // binary atomic operations using atomicCAS with special handling to support + // small data types. + Status EmitAtomicOperationUsingCAS(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address); StatusOr ComputeNestedElement( const HloComputation& computation, @@ -219,6 +235,7 @@ class IrEmitterUnnested : public IrEmitter { // IrEmitterUnnested handles the following instructions differently from // IrEmitter. Status HandleCopy(HloInstruction* copy) override; + Status HandleConditional(HloInstruction* conditional) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleDot(HloInstruction* dot) override; Status HandleFusion(HloInstruction* fusion) override; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 5da1a130d5654b86803396b07a6501c59a182c67..5225ff36ff3a8a1b049479c34aa301de8724f73e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -115,7 +115,8 @@ Status IrEmitterNested::HandleParameter(HloInstruction* parameter) { Status IrEmitterNested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { - return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo), &ir_builder_) + return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), + &ir_builder_) .EmitLoop(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 7b4662fc80c5518135c827489a3724e477b2bad1..8dbc90ee1fb5678f070bdc8999ffa8980197188f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -123,10 +123,12 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::ConstantInt* 
threads_per_block_ir_value = llvm::ConstantInt::get( llvm::IntegerType::get(llvm_context, /*NumBits=*/32), launch_dims.threads_per_block()); + // Our launch bounds are exact, so we can specify them as reqntidx rather than + // maxntidx. nvvm_annotations_node->addOperand(llvm::MDNode::get( llvm_context, {llvm::ConstantAsMetadata::get(ir_kernel), - llvm::MDString::get(llvm_context, "maxntidx"), + llvm::MDString::get(llvm_context, "reqntidx"), llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } } // namespace @@ -246,6 +248,11 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { } Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 || + dnums.rhs_batch_dimensions_size() > 0) { + return Unimplemented("Dot with batch dimensions not implemented."); + } if (ImplementedAsGemm(*dot)) { thunk_sequence_->emplace_back(BuildGemmThunk(dot)); return Status::OK(); @@ -254,6 +261,11 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { return IrEmitter::HandleDot(dot); } +Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { + thunk_sequence_->push_back(BuildKernelThunk(conditional)); + return IrEmitter::HandleConditional(conditional); +} + Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { if (ImplementedAsDnnConvolution(*convolution)) { thunk_sequence_->emplace_back(BuildConvolutionThunk(convolution)); @@ -282,7 +294,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { MakeUnique(std::move(thunks), fusion)); std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand)); + parameter_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter( hlo_module_config_, ir_emitter_context_->llvm_module(), @@ -344,7 +356,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* 
fusion) { thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); std::vector operand_arrays; for (HloInstruction* operand : fusion->operands()) { - operand_arrays.push_back(GetIrArray(*operand)); + operand_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), @@ -355,7 +367,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // Array to write into. Because this is an in-place operation, this is the // same as operand 0's array. - llvm_ir::IrArray output_array = GetIrArray(*fusion); + llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( update_shape, ir_emitter_context_->device_description()); @@ -693,9 +705,10 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { constexpr int64 tile_size = 32; constexpr int64 num_rows = 8; int64 num_tiles = EmitTranspose021Tiled( - GetIrArray(*(copy->operand(0))) + GetIrArray(*copy->operand(0), *copy) .CastToShape(reduced_input_shape, &ir_builder_), - GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_), + GetIrArray(*copy, *copy) + .CastToShape(reduced_output_shape, &ir_builder_), tile_size, num_rows, &ir_builder_); UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size), LastThunk(), ir_emitter_context_->llvm_module()); @@ -850,9 +863,11 @@ Status IrEmitterUnnested::EmitColumnReduction( &ir_builder_); const HloInstruction* output = reduce->IsFused() ? 
reduce->parent()->FusionInstruction() : reduce; - llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress( - llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), &ir_builder_, - "output_element_address"); + llvm::Value* output_address = + GetIrArray(*output, *output) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), + &ir_builder_, "output_element_address"); return EmitAtomicOperationForNestedComputation( *reducer, output_address, partial_reduction_result_address); }; @@ -1081,16 +1096,25 @@ Status IrEmitterUnnested::EmitRowReduction( // from the warp. llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &ir_builder_); + int bit_width = llvm_ir::GetSizeInBits(element_ir_type); + // bitcast cannot be applied to aggregate types (even packed ones), so we + // instead bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() + ? ir_builder_.getIntNTy(bit_width) + : element_ir_type; for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - partial_reduction_result_address, "partial_reduction_result"); + ir_builder_.CreateBitCast(partial_reduction_result_address, + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); ir_builder_.CreateStore( EmitShuffleDown(partial_reduction_result, ir_builder_.getInt32(shuffle_distance), &ir_builder_), - result_from_other_lane); + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducer, {partial_reduction_result_address, result_from_other_lane}, partial_reduction_result_address)); @@ -1107,9 +1131,11 @@ Status IrEmitterUnnested::EmitRowReduction( "lane_id_is_zero", &ir_builder_); 
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress( - llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), &ir_builder_, - "output_element_address"); + llvm::Value* output_address = + GetIrArray(*output, *output) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), + &ir_builder_, "output_element_address"); return EmitAtomicOperationForNestedComputation( *reducer, output_address, partial_reduction_result_address); }; @@ -1249,11 +1275,12 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { MakeUnique(std::move(thunks), reduce)); return EmitReductionToVector( reduce, input->shape(), - [this, input](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*input).EmitReadArrayElement(index, &ir_builder_); + [&](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*input, *reduce) + .EmitReadArrayElement(index, &ir_builder_); }, - [this, init_value](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*init_value) + [&](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*init_value, *reduce) .EmitReadArrayElement(index, &ir_builder_); }, dimensions_to_reduce, reducer); @@ -1417,7 +1444,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); } }; - llvm_ir::IrArray operand_array(GetIrArray(*operand)); + llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &ir_builder_); ir_builder_.CreateStore(operand_data, selected_value_address); @@ -1470,9 +1497,10 @@ Status IrEmitterUnnested::HandleSelectAndScatter( ir_builder_.CreateLoad(selected_index_address_slot)); } llvm::Value* source_value_address = - GetIrArray(*source).EmitArrayElementAddress(source_index, &ir_builder_); + GetIrArray(*source, *select_and_scatter) + 
.EmitArrayElementAddress(source_index, &ir_builder_); llvm::Value* output_value_address = - GetIrArray(*select_and_scatter) + GetIrArray(*select_and_scatter, *select_and_scatter) .EmitArrayElementAddress(selected_index, &ir_builder_); return EmitAtomicOperationForNestedComputation( *select_and_scatter->scatter(), output_value_address, @@ -1749,7 +1777,7 @@ Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, return EmitTargetElementLoopInThunk( *hlo, [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*init_value) + return GetIrArray(*init_value, *hlo) .EmitReadArrayElement(index, &ir_builder_); }, thunk); @@ -1850,7 +1878,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( UpdateLaunchDimensions(launch_dimensions, thunk, ir_emitter_context_->llvm_module()); if (!hlo.IsMultiOutputFusion()) { - return ParallelLoopEmitter(element_generator, GetIrArray(hlo), + return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), launch_dimensions, &ir_builder_) .EmitLoop(IrName(&hlo)); } @@ -1858,7 +1886,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( // For multiple outputs fusion, we need to emit each operand and the root. 
std::vector output_arrays; for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { - output_arrays.push_back(GetIrArray(hlo, {i})); + output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &ir_builder_) @@ -1869,7 +1897,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_, + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, module_); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 817e95a31c546076364674fad63cdb54c3d0e147..059943d48cd34b0ac487b91c3f3079ee3f761229 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -34,7 +34,7 @@ limitations under the License. #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/CodeGen/CommandFlags.def" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" @@ -60,6 +60,7 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" namespace xla { namespace gpu { @@ -76,7 +77,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Since CUDA 9.0, all GPU versions are included in a single file const char* unified_libdevice_filename = "libdevice.10.bc"; std::vector unified_libdevice_files; - const tensorflow::Status status = + const tensorflow::Status status = tensorflow::Env::Default()->GetMatchingPaths( tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename), &unified_libdevice_files); @@ -488,9 +489,11 @@ StatusOr CompileToPtx(llvm::Module* module, string ptx; { - ScopedLoggingTimer compilation_timer( - "Compile module " + llvm_ir::AsString(module->getName()), - /*vlog_level=*/2); + tensorflow::port::Tracing::TraceMe annotation( + "Compiling IR", llvm_ir::AsString(module->getName()), + /*is_expensive=*/true); + XLA_SCOPED_LOGGING_TIMER("Compile module " + + llvm_ir::AsString(module->getName())); TF_ASSIGN_OR_RETURN( ptx, CompileModuleToPtx(module, compute_capability, hlo_module_config, libdevice_dir_path)); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 9274e16a455fc1a958cee5101b6a9ef7ce619347..c29fee0879c02021fdc23ac0e02ab398cf40f99e 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -49,8 +49,8 @@ HloInstruction* MaybePaddedAndSlicedInput( // applies positive padding and dilation. 
PaddingConfig padding_config = MakeNoPaddingConfig(input->shape().dimensions_size()); - for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) { - int64 dim = conv_dnums.spatial_dimensions(i); + for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.input_spatial_dimensions(i); padding_config.mutable_dimensions(dim)->set_edge_padding_low( std::max(0LL, conv_window.dimensions(i).padding_low())); padding_config.mutable_dimensions(dim)->set_edge_padding_high( @@ -81,8 +81,8 @@ HloInstruction* MaybePaddedAndSlicedInput( std::vector limit_indices(input->shape().dimensions().begin(), input->shape().dimensions().end()); std::vector strides(input->shape().dimensions_size(), 1); - for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) { - int64 dim = conv_dnums.spatial_dimensions(i); + for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.input_spatial_dimensions(i); // If dimension "dim" has negative padding, increase the start index or // decrement the limit index by the amount of negative padding. start_indices[dim] += @@ -117,8 +117,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) { padding_config.add_dimensions(); } - for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) { - int64 dim = conv_dnums.spatial_dimensions(i); + for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.kernel_spatial_dimensions(i); padding_config.mutable_dimensions(dim)->set_interior_padding( conv_window.dimensions(i).window_dilation() - 1); } @@ -202,8 +202,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // ABCD0 = Pad(ABCD, padding_high=1) // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1) // We choose the lesser of padding_low and padding_high as the new padding. 
- HloInstruction* transpose = backward_conv->fused_expression_root(); - HloInstruction* forward_conv = transpose->mutable_operand(0); + HloInstruction* forward_conv = backward_conv->fused_expression_root(); HloInstruction* input = backward_conv->mutable_operand(0); Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); @@ -229,7 +228,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // later. Therefore, the amount of new padding (low or high) is the minimum // of the amount of old padding low and old padding high. int64 new_conv_padding = std::min(padding_low, padding_high); - int64 dim = backward_conv_dnums.spatial_dimensions(i); + int64 dim = backward_conv_dnums.input_spatial_dimensions(i); input_padding_config.mutable_dimensions(dim)->set_edge_padding_low( padding_low - new_conv_padding); input_padding_config.mutable_dimensions(dim)->set_edge_padding_high( @@ -269,19 +268,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( .ConsumeValueOrDie(), padded_input, output, new_forward_conv_window, forward_conv_dnums)); - HloInstruction* new_transpose = - computation->AddInstruction(HloInstruction::CreateTranspose( - ShapeInference::InferTransposeShape(new_forward_conv->shape(), - transpose->dimensions()) - .ConsumeValueOrDie(), - new_forward_conv, transpose->dimensions())); - - // Fuse the new forward convolution and the new transpose to the new backward - // convolution. + // Fuse the new forward convolution to the new backward convolution. 
HloInstruction* new_backward_conv = computation->CreateFusionInstructionForBackwardConvolution( - {new_transpose, new_forward_conv}, - HloInstruction::FusionKind::kConvBackwardFilter, + {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter, new_backward_conv_window, backward_conv_dnums); VLOG(1) << "Canonicalizing backward filter conv"; @@ -369,12 +359,11 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( std::vector limit_indices( new_backward_conv->shape().dimensions().begin(), new_backward_conv->shape().dimensions().end()); - std::vector strides(new_backward_conv->shape().dimensions_size(), - 1LL); + std::vector strides(new_backward_conv->shape().dimensions_size(), 1LL); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { int64 padding_low = backward_conv->window().dimensions(i).padding_low(); int64 padding_high = backward_conv->window().dimensions(i).padding_high(); - int64 dim = backward_conv_dnums.spatial_dimensions(i); + int64 dim = backward_conv_dnums.output_spatial_dimensions(i); if (padding_low > padding_high) { // If the amount of low padding (of the old backward convolution) is // larger, we internally pad the low end of the activations and slice diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index d0d2deee24848184278e3e51dcaa3bb673b5fadc..6cf280df05496716a0780d61ded92efd9982734c 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -44,37 +44,41 @@ std::ostream& operator<<(std::ostream& out, // Calculates the launch dimensions used to invoke `hlo`. 
LaunchDimensions CalculateLaunchDimensions( - const Shape& shape, const se::DeviceDescription& device_desc, - PartitionStrategy partition_strategy) { - int64 warp_size = device_desc.threads_per_warp(); - + const Shape& shape, const se::DeviceDescription& device_desc) { int64 num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { return LaunchDimensions(); } - // Calculate the number of threads per block. - // Initialize threads_per_block as the threads-per-block limit. - int64 threads_per_block = device_desc.threads_per_block_limit(); - VLOG(2) << "Initial # of threads per block = " << threads_per_block; - - if (partition_strategy == PartitionStrategy::kLatency) { - // Limit the thread count to allow maximum number of registers per thread. - // TODO(b/28560520): We don't have to assume the emitted kernel will use up - // all the registers. We could use ptxas to examine the actual number of - // register used, and set the thread count accordingly. - int64 threads_per_block_limit_due_to_registers = - device_desc.registers_per_core_limit() / - device_desc.registers_per_thread_limit(); - CHECK_NE(0, threads_per_block_limit_due_to_registers); - if (threads_per_block_limit_due_to_registers < threads_per_block) { - threads_per_block = - // Make `threads_per_block` a multiple of warp size to use GPU - // efficiently. - warp_size * - std::max(1LL, threads_per_block_limit_due_to_registers / warp_size); - VLOG(2) << "Update # of threads per block due to register pressure = " - << threads_per_block; + // Since we don't do any inter-warp communication, we're free to choose any + // block size we want, subject to hardware constraints. We choose the + // smallest block size that allows the GPU to reach full occupancy (assuming + // the kernel uses sufficiently few registers). This gives us max performance + // when the kernel uses few registers, and lets us scale down gracefully as + // the kernel uses more registers. 
+ // + // Specifically, we choose the number of threads per block such that + // + // <threads per block> * <max blocks per core> = <max threads per core> + + auto threads_per_core = device_desc.threads_per_core_limit(); + auto blocks_per_core = device_desc.blocks_per_core_limit(); + int64 threads_per_block; + if (threads_per_core != 0 && blocks_per_core != 0) { + threads_per_block = device_desc.threads_per_core_limit() / + device_desc.blocks_per_core_limit(); + } else { + static std::atomic log_count{0}; + if (log_count.fetch_add(1) < 8) { + LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " + "without full information about its capabilities. " + "StreamExecutor's PopulateDeviceDescription should be " + "updated for this device."; + } + threads_per_block = device_desc.threads_per_warp(); + if (threads_per_block == 0) { + // Fall back to *something* if we can't even get num threads per warp. + threads_per_block = 32; } } @@ -84,8 +88,6 @@ LaunchDimensions CalculateLaunchDimensions( << threads_per_block << ") because the latter is smaller."; } - // Calculate the block count. We copy the strategy used by Eigen: - // eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h int64 block_count = CeilOfRatio(num_elements, threads_per_block); VLOG(2) << tensorflow::strings::Printf( "Initialized the block count to ceil(# of elements / threads per " diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 8f7fce884acc93fd39510ad0826b819a6d9731a7..0bf463a6ef95d5a32784838c08ad239752fd1acf 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -30,14 +30,6 @@ limitations under the License. namespace xla { namespace gpu { -enum class PartitionStrategy { - // Optimized for latency by allowing maximum number of registers per thread. - kLatency, - // Optimized for throughput. This may limit registers per thread and cause - // longer latency.
- kThroughput -}; - // Encapsulates the launch dimensions of a kernel, e.g., the block count and the // number of threads per block. class LaunchDimensions { @@ -66,8 +58,7 @@ std::ostream& operator<<(std::ostream& out, LaunchDimensions CalculateLaunchDimensions( const Shape& shape, - const perftools::gputools::DeviceDescription& device_desc, - PartitionStrategy partition_strategy = PartitionStrategy::kLatency); + const perftools::gputools::DeviceDescription& device_desc); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 0ff27888ad72f8190400c22a9086d1965448662c..486ea7d7e1dad3f7f37d50565e176fbf567f5cc4 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -70,6 +70,19 @@ class Thunk { return tensorflow::Status::OK(); } + // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) + // before calling ExecuteOnStream(stream). If it returns true, it's the + // user's responsibility to wait for all activity on the GPU to finish before + // calling ExecuteOnStream. + // + // This value is not required to be constant for a given Thunk. For example, + // a Thunk that performs autotuning may return true for its first run and + // false thereafter. + virtual bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream* /*stream*/) { + return false; + } + // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's // lifetime. Stream argument must be non-null. 
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 0d2412096abf7838b7b0e7617811c789f507a4a1..c21559af6d2e5dfb5aaf62afcdcaed514e0914c9 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -34,16 +34,14 @@ WhileThunk::WhileThunk( body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -tensorflow::Status WhileThunk::Initialize(const GpuExecutable& executable) { +Status WhileThunk::Initialize(const GpuExecutable& executable) { TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(executable)); TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status WhileThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { - +Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { perftools::gputools::DeviceMemoryBase condition_result_data = buffer_allocations.GetDeviceAddress(condition_result_buffer_index_); @@ -55,9 +53,11 @@ tensorflow::Status WhileThunk::ExecuteOnStream( // Copy the result of condition computation and break the loop if 'false'. 
bool condition_result; stream->ThenMemcpy(&condition_result, condition_result_data, sizeof(bool)); - if (!stream->BlockHostUntilDone()) { + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { return InternalError( - "Failed to complete all kernels launched on stream %p", stream); + "Failed to complete all kernels launched on stream %p: %s", stream, + block_status.error_message().c_str()); } if (!condition_result) { @@ -68,7 +68,7 @@ tensorflow::Status WhileThunk::ExecuteOnStream( TF_RETURN_IF_ERROR( body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 95ed5497cea4fa3ba5dcdc6762cbd53cec88339a..4c9f45de9e42494df58706d0a4a3eb0c4220b8b8 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -45,10 +45,9 @@ class WhileThunk : public Thunk { WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status Initialize(const GpuExecutable& executable) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; private: const BufferAllocation::Slice condition_result_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 44188473d39088923c67216facab472a4e4ee09f..f16daa0b5481474e754c880ead1945297ca50168 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -17,9 +17,12 
@@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -33,8 +36,6 @@ class WhileTransformerTest : public HloTestBase { : module_(CreateNewModule()), induction_variable_shape_(ShapeUtil::MakeShape(S32, {})), data_shape_(ShapeUtil::MakeShape(F32, {8})), - loop_state_shape_(ShapeUtil::MakeTupleShape( - {induction_variable_shape_, data_shape_})), condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {} std::unique_ptr BuildConditionComputation( @@ -42,8 +43,8 @@ class WhileTransformerTest : public HloTestBase { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(limit))); - auto loop_state = builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( + 0, GetLoopStateShape(tuple_index), "loop_state")); auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( limit_const->shape(), loop_state, tuple_index)); @@ -58,8 +59,8 @@ class WhileTransformerTest : public HloTestBase { const int64 increment) { auto builder = HloComputation::Builder(TestName() + ".Body"); // Create param instruction to access loop state. 
- auto loop_state = builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( + 0, GetLoopStateShape(ind_var_tuple_index), "loop_state")); // Update the induction variable GTE(ind_var_tuple_index). auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( @@ -73,7 +74,7 @@ class WhileTransformerTest : public HloTestBase { data_shape_, loop_state, data_tuple_index)); // Use 'induction_variable' in computation with no path to output tuple. auto update = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); + HloInstruction::CreateBroadcast(data_shape_, induction_variable, {})); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); // Create output Tuple. @@ -98,8 +99,9 @@ class WhileTransformerTest : public HloTestBase { HloInstruction::CreateTuple({induction_var_init, data_init})) : builder.AddInstruction( HloInstruction::CreateTuple({data_init, induction_var_init})); - auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( - loop_state_shape_, condition, body, loop_state_init)); + auto while_hlo = builder.AddInstruction( + HloInstruction::CreateWhile(GetLoopStateShape(ind_var_tuple_index), + condition, body, loop_state_init)); module_->AddEntryComputation(builder.Build()); return while_hlo; } @@ -115,18 +117,34 @@ class WhileTransformerTest : public HloTestBase { } void RunCopyInsertionPass() { + HloVerifier verifier([](const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); + }); + TF_ASSERT_OK(verifier.Run(module_.get()).status()); CopyInsertion copy_insertion; - EXPECT_IS_OK(copy_insertion.Run(module_.get()).status()); + TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); + } + + Shape GetLoopStateShape(const int64 ind_var_tuple_index) { + if (ind_var_tuple_index == 
0) { + return ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_}); + } else { + return ShapeUtil::MakeTupleShape( + {data_shape_, induction_variable_shape_}); + } } std::unique_ptr module_; Shape induction_variable_shape_; Shape data_shape_; - Shape loop_state_shape_; Shape condition_result_shape_; }; -TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { +// TODO(b/68830972): The while transformer is far too fragile. It patterns +// matches the exact expressions of opcodes. Re-enable when transformation is +// more general +TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { // Build computation with induction variable at tuple element 0. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -137,13 +155,16 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_TRUE(result.ok()); + TF_ASSERT_OK(result.status()); // Check results. EXPECT_THAT(result.ConsumeValueOrDie(), Eq(std::tuple(0, 10, 1))); } -TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { +// TODO(b/68830972): The while transformer is far too fragile. It patterns +// matches the exact expressions of opcodes. Re-enable when transformation is +// more general +TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { // Build computation with induction variable at tuple element 1. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(1, 10)); @@ -154,13 +175,16 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_TRUE(result.ok()); + TF_ASSERT_OK(result.status()); // Check results. 
EXPECT_THAT(result.ConsumeValueOrDie(), Eq(std::tuple(0, 10, 1))); } -TEST_F(WhileTransformerTest, InvalidLoopLimit) { +// TODO(b/68830972): The while transformer is far too fragile. It patterns +// matches the exact expressions of opcodes. Re-enable when transformation is +// more general +TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { // Build computation with invalid loop limit. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 5)); @@ -176,7 +200,10 @@ TEST_F(WhileTransformerTest, InvalidLoopLimit) { HasSubstr("Loop start must be less than loop limit.")); } -TEST_F(WhileTransformerTest, InvalidLoopIncrement) { +// TODO(b/68830972): The while transformer is far too fragile. It patterns +// matches the exact expressions of opcodes. Re-enable when transformation is +// more general +TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { // Build computation with invalid loop increment. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 049e8d80d80c835bca4a4d38592564ba82a3ecf9..05017008e2ddbe0b9e78d06275fdec5d08d94bfa 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -108,8 +108,11 @@ std::unique_ptr MakeBigGraph() { HloInstruction::CreateUnary(vshape, HloOpcode::kCopy, param_v0)); auto clamp = builder.AddInstruction(HloInstruction::CreateTernary( vshape, HloOpcode::kClamp, copy, param_v1, param_v2)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(vshape, HloOpcode::kDot, clamp, param_v0)); + HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({dot, param_s, clamp})); 
auto scalar = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 17b926c8748e45b55f380e7595711b9e7a748f64..387b649a731ebcbfd8307807469f39f22d192b06 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -259,8 +259,11 @@ TEST_F(HeapSimulatorTest, MultiplyDot) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); // The buffer for dot is the output, and it cannot be shared with the buffer // for mul, since dot isn't elementwise. 
@@ -292,8 +295,11 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA)); @@ -327,10 +333,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot0 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY)); + HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); // The buffer for dot1 is the output. No buffers can be shared. 
The buffer // for mul is freed before the end, since it's no longer used after dot0 @@ -365,10 +374,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot0 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY)); + HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1})); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 79493c4112804f8454d200f3f83aa85d718f0d0a..5d0cfba1fc8ab255c228c671fee641e9302f5ec6 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -118,6 +118,9 @@ message HloInstructionProto { // Shape of outfeed request. xla.Shape outfeed_shape = 29; + + // Describes the dimension numbers used for a dot operation + xla.DotDimensionNumbers dot_dimension_numbers = 30; } // Serialization of HloComputation. 
@@ -250,7 +253,3 @@ message HloProto { HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } - -message HloProtos { - repeated HloProto hlo_protos = 1; -} diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 6f8099475146e6bbcfb61d2e5a91a7a6f9e63e58..6d2a3aa5b531650a658502531e050702ffbd3760 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -144,8 +144,10 @@ class BufferValueMap { // Move the given value into the given buffer. void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) { BufferNumber old_buffer_number = value_to_buffer_number_.at(&value); - buffers_.at(old_buffer_number).erase(&value); - if (buffers_.at(old_buffer_number).empty()) { + tensorflow::gtl::FlatSet& old_value_set = + buffers_.at(old_buffer_number); + old_value_set.erase(&value); + if (old_value_set.empty()) { buffers_.erase(old_buffer_number); } @@ -175,7 +177,7 @@ class BufferValueMap { // Value is init of a while (use is while). std::vector aliased_buffers; for (const HloUse& use : value.uses()) { - VLOG(1) << "use of value " << value.ToShortString() << ": " << use; + VLOG(2) << "use of value " << value.ToShortString() << ": " << use; if (use.instruction->opcode() == HloOpcode::kWhile) { // Determine the while value that this shares a buffer with. 
const HloValue& while_value = @@ -411,7 +413,7 @@ string HloAliasAnalysis::ToString() const { /* static */ StatusOr> HloAliasAnalysis::Run( HloModule* module) { - VLOG(1) << "HloAliasAnalysis::Run on module " << module->name(); + VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); @@ -444,7 +446,7 @@ StatusOr> HloAliasAnalysis::Run( TF_DCHECK_OK(alias_analysis->Verify()); - XLA_VLOG_LINES(1, alias_analysis->ToString()); + XLA_VLOG_LINES(2, alias_analysis->ToString()); return std::move(alias_analysis); } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 8f595b45e9832376c4ef881065207f70d2501bee..014a851c96ed1d530cfd5fa4e854cf1df45fc4d0 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -176,10 +176,6 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) { return false; } - if (instruction->HasSideEffect()) { - return false; - } - return true; } @@ -207,7 +203,8 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( worklist.pop(); if (removed.count(item) != 0 || item->user_count() != 0 || - item == root_instruction() || !IsRemovable(item)) { + item == root_instruction() || !IsRemovable(item) || + item->HasSideEffect()) { continue; } for (int i = 0; i < item->operand_count(); ++i) { @@ -385,11 +382,6 @@ string HloComputation::ToString(int nested_level, /*include_metadata=*/true, /*include_large_constants=*/include_large_constants) << "\n"; - if (instruction->opcode() == HloOpcode::kFusion) { - s << instruction->fused_instructions_computation()->ToString( - nested_level + 1, include_large_constants) - << "\n"; - } } for (int i = 0; i < nested_level; i++) { s << " "; @@ -412,16 +404,18 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> 
HloComputation::CreateFromProto( HloModule* module, const HloComputationProto& proto, - tensorflow::gtl::FlatMap* computation_map, + const tensorflow::gtl::FlatMap& computation_map, + const std::function)>& + add_fused_computation, HloInstruction* fusion_instruction) { std::vector> instructions; tensorflow::gtl::FlatMap instruction_map; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr instruction, - HloInstruction::CreateFromProto(module, instruction_proto, - instruction_map, computation_map)); + TF_ASSIGN_OR_RETURN(std::unique_ptr instruction, + HloInstruction::CreateFromProto( + module, instruction_proto, instruction_map, + computation_map, add_fused_computation)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index c9782cc981ef067058a5b14d3d1fffdd3eb6b49b..ccedda2a03c088b93883dd79a101c832497a937a 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -152,12 +152,16 @@ class HloComputation { // computation_map: a map from computation name to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. - // fusion_instruction: if non-null then the newly created computation will be - // constructed as a fused computation with this instruction as its fusion - // parent. + // add_fused_computation: A function to call to add a fused + // computation. Used only when the instruction is a fusion instruction. + // fusion_instruction: if non-null then the newly created computation will + // be constructed as a fused computation with this instruction as its + // fusion parent. 
static StatusOr> CreateFromProto( HloModule* module, const HloComputationProto& proto, - tensorflow::gtl::FlatMap* computation_map, + const tensorflow::gtl::FlatMap& computation_map, + const std::function)>& + add_fused_computation, HloInstruction* fusion_instruction = nullptr); // Gets the instructions in this computation. @@ -309,11 +313,17 @@ class HloComputation { replacements, HloModule* module = nullptr, const string& suffix = "clone"); - // Returns true if the given instruction can be removed from the - // computation. Instructions such as parameters and send/receive instructions - // cannot be removed without violating invariants of the HLO computation or - // module with the exception of fusion computation. A parameter instruction - // is removable for a fusion computation. + // Returns true if the given instruction can be removed from the computation. + // Parameter instructions cannot be removed without violating invariants of + // the HLO computation with the exception of fusion computation. A parameter + // instruction is removable for a fusion computation. + // + // Note that IsRemovable() is a necessary condition to remove an instruction + // rather than a sufficient condition. For example, instructions with + // side-effect (e.g., Send, Infeed) may be removed from a computation, but the + // transformation must guarantee the invariants relevant to the instructions + // still hold (e.g., Send and Recv must be removed together to make each + // channel complete). bool IsRemovable(const HloInstruction* instruction); // Returns true if this computation has a side effect. A computation has a @@ -326,6 +336,9 @@ class HloComputation { // Returns the owning fusion instruction, or nullptr if this is not a fusion // computation.
HloInstruction* FusionInstruction() const { return fusion_instruction_; } + void SetFusionInstruction(HloInstruction* fusion_instruction) { + fusion_instruction_ = fusion_instruction; + } private: explicit HloComputation( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 17ba2b673ac2db2060f720139bdc52ef1e72c98a..b933695b823871c6c0174da6d6f99e618219442a 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -22,13 +22,14 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" namespace xla { constexpr char HloCostAnalysis::kFlopsKey[]; constexpr char HloCostAnalysis::kTranscendentalsKey[]; constexpr char HloCostAnalysis::kBytesAccessedKey[]; -constexpr char HloCostAnalysis::kSecondsKey[]; +constexpr char HloCostAnalysis::kOptimalSecondsKey[]; HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size) : HloCostAnalysis(shape_size, {}) {} @@ -60,16 +61,16 @@ Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) { if (current_should_compute_bottleneck_time_) { // Compute the time as the time of the bottleneck, i.e. the slowest property // given the per-second rate of each property. 
- float max_seconds = 0.0f; + float optimal_seconds = 0.0f; for (const auto& property : current_properties_) { - if (property.first != kSecondsKey) { - max_seconds = std::max( - max_seconds, + if (property.first != kOptimalSecondsKey) { + optimal_seconds = std::max( + optimal_seconds, property.second / GetProperty(property.first, per_second_rates_, INFINITY)); } } - current_properties_[kSecondsKey] = max_seconds; + current_properties_[kOptimalSecondsKey] = optimal_seconds; } TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second); @@ -200,10 +201,11 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) { Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); const Shape& rhs_shape = dot->operand(1)->shape(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); // Count of elements along the reduction dimension (last dimension for the // rhs). - int64 reduction_width = lhs_shape.dimensions(ShapeUtil::Rank(lhs_shape) - 1); - + int64 reduction_width = + lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0)); // First divide by reduction width before multiplying by rhs elements to avoid // overflow. int64 fma_count; @@ -337,10 +339,18 @@ Status HloCostAnalysis::HandleSend(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleSendDone(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleRecv(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleReshape(const HloInstruction*) { return Status::OK(); } @@ -388,7 +398,14 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { // // TODO(b/33004697): Compute correct cost here, taking the actual number of // replicas into account. 
- current_properties_[kFlopsKey] = ShapeUtil::ElementsIn(crs->shape()); + double flops = 0.0; + ShapeUtil::ForEachSubshape( + crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); + current_properties_[kFlopsKey] = flops; return Status::OK(); } @@ -472,6 +489,25 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { return Status::OK(); } +Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { + // Compute the cost of the true and false computations and take the maximum + // from those for each property. + TF_ASSIGN_OR_RETURN(const Properties true_computation_properties, + ProcessSubcomputation(conditional->true_computation())); + TF_ASSIGN_OR_RETURN(const Properties false_computation_properties, + ProcessSubcomputation(conditional->false_computation())); + current_properties_ = true_computation_properties; + for (const auto& property : false_computation_properties) { + if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_, property)) { + current_properties_[property.first] = + std::max(current_properties_[property.first], property.second); + } + } + current_should_compute_bottleneck_time_ = false; + + return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } @@ -488,8 +524,8 @@ float HloCostAnalysis::bytes_accessed() const { return GetProperty(kBytesAccessedKey, properties_sum_); } -float HloCostAnalysis::seconds() const { - return GetProperty(kSecondsKey, properties_sum_); +float HloCostAnalysis::optimal_seconds() const { + return GetProperty(kOptimalSecondsKey, properties_sum_); } int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const { @@ -504,8 +540,8 @@ int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const { return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_); } -float HloCostAnalysis::seconds(const HloInstruction&
hlo) const { - return GetPropertyForHlo(hlo, kSecondsKey, hlo_properties_); +float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { + return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_); } StatusOr HloCostAnalysis::ProcessSubcomputation( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 8074868e375541e424dbe17de8a3038880e41927..fade19522cf0c30eab037aa355de1f9203f80014 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -42,7 +42,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { static constexpr char kFlopsKey[] = "flops"; static constexpr char kTranscendentalsKey[] = "transcendentals"; static constexpr char kBytesAccessedKey[] = "bytes accessed"; - static constexpr char kSecondsKey[] = "seconds"; + static constexpr char kOptimalSecondsKey[] = "optimal_seconds"; // shape_size is a function which returns the size in bytes of the top-level // buffer of a shape. 
@@ -60,7 +60,9 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleReducePrecision(const HloInstruction* hlo) override; Status HandleConcatenate(const HloInstruction* concatenate) override; Status HandleSend(const HloInstruction* send) override; + Status HandleSendDone(const HloInstruction* send_done) override; Status HandleRecv(const HloInstruction* recv) override; + Status HandleRecvDone(const HloInstruction* recv_done) override; Status HandleConvert(const HloInstruction* convert) override; Status HandleCopy(const HloInstruction* copy) override; Status HandleDot(const HloInstruction* dot) override; @@ -95,6 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleReshape(const HloInstruction* reshape) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; + Status HandleConditional(const HloInstruction* conditional) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; @@ -116,14 +119,14 @@ class HloCostAnalysis : public ConstDfsHloVisitor { float flop_count() const; float transcendental_count() const; float bytes_accessed() const; - float seconds() const; + float optimal_seconds() const; // Returns the respective cost computed for a particular HLO instruction, or 0 // if the HLO was not found to have a cost in the analysis. 
int64 flop_count(const HloInstruction& hlo) const; int64 transcendental_count(const HloInstruction& hlo) const; int64 bytes_accessed(const HloInstruction& hlo) const; - float seconds(const HloInstruction& hlo) const; + float optimal_seconds(const HloInstruction& hlo) const; const Properties& properties() const { return properties_sum_; } const float property(const string& key) const { diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 0eaa21ef254e3461baaaca57503ab24ce35ac929..3b289c240a45e8f3df8156ed89e879da2132d01a 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -389,7 +389,7 @@ TEST_F(FusionCostAnalysis, LoopFusion) { static_assert(bytes_accessed == 64, ""); EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed); - EXPECT_EQ(fusion_analysis.seconds(), 1 << i); + EXPECT_EQ(fusion_analysis.optimal_seconds(), 1 << i); } } diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 7c4626e78a3e84c9723a9f8e39d56614c4fa25ce..3601a790c4428ee39c264b217a4b9a991ad8456c 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -79,12 +79,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { // Test that two identical constants with different layouts are commoned if // the pass is not layout sensitive. 
auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - /*minor_to_major=*/{0, 1}))); - auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - /*minor_to_major=*/{1, 0}))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -111,12 +111,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { // Test that two identical constants with different layouts are *not* commoned // if the pass is layout sensitive. 
auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - /*minor_to_major=*/{0, 1}))); - auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - /*minor_to_major=*/{1, 0}))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 92261bce6270e3c37165c10ed804d036d2abb984..2a335843f507e2071807245d4dd256e1ec6f08c8 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -75,11 +75,43 @@ HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction, std::forward_as_tuple(value_id, instruction, index, is_phi)); CHECK(emplaced.second); + VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString(); + return &emplaced.first->second; } -void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) { - values_.erase(value_id); +void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { + HloValue& value = values_.at(value_id); + VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")"; + + value_ids_to_delete_.push_back(value_id); +} + +void HloDataflowAnalysis::DeleteMarkedValues() { +#ifndef NDEBUG + // Verify that no marked-for-deletion values are in any of the value sets. 
+ tensorflow::gtl::FlatSet id_set(value_ids_to_delete_.begin(), + value_ids_to_delete_.end()); + for (const auto& pair : value_sets_) { + const HloInstruction* instruction = pair.first; + const InstructionValueSet& instruction_value_set = pair.second; + for (const auto& index_value_set : instruction_value_set) { + const HloValueSet& value_set = index_value_set.second; + for (const HloValue* value : value_set.values()) { + DCHECK(!ContainsKey(id_set, value->id())) + << "Value " << value->ToShortString() + << " marked for deletion, but still exists in value set for " + "instruction " + << instruction->name(); + } + } + } +#endif + + for (HloValue::Id value_id : value_ids_to_delete_) { + values_.erase(value_id); + } + value_ids_to_delete_.clear(); } string HloDataflowAnalysis::ToString() const { @@ -121,6 +153,7 @@ bool HloDataflowAnalysis::Phi( HloInstruction* instruction, tensorflow::gtl::ArraySlice inputs) { CHECK(ssa_form_); + VLOG(4) << "Phi(" << instruction->name() << ")"; for (const InstructionValueSet* input : inputs) { DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape())); @@ -183,7 +216,7 @@ bool HloDataflowAnalysis::Phi( } else if (current_value != &new_value) { if (current_value_defined_here) { // Remove the existing phi. - DeleteHloValue(current_value->id()); + MarkValueForDeletion(current_value->id()); } value_set.Clear(); value_set.AddValue(&new_value); @@ -193,7 +226,8 @@ bool HloDataflowAnalysis::Phi( // Multiple distinct values reach this point. A phi value is // necessary. 
CHECK_GT(input_value_ids.size(), 1); - if (current_value == nullptr || !current_value->is_phi()) { + if (current_value == nullptr || + !(current_value->is_phi() && current_value_defined_here)) { value_set.Clear(); value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true)); changed = true; @@ -242,6 +276,51 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) { return false; } +bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { + CHECK_EQ(send->opcode(), HloOpcode::kSend); + bool changed = false; + // Send forwards the operand value to the output tuple at {0}. + for (auto& pair : GetInstructionValueSet(send->operand(0))) { + const ShapeIndex& operand_index = pair.first; + const HloValueSet& operand_value_set = pair.second; + + ShapeIndex index = {0}; + for (int64 i : operand_index) { + index.push_back(i); + } + + HloValueSet& value_set = GetValueSet(send, index); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + +bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) { + CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone); + bool changed = false; + // RecvDone forwards the operand value at {0} to the output. 
+ for (auto& pair : GetInstructionValueSet(recv_done)) { + ShapeIndex& index = pair.first; + HloValueSet& value_set = pair.second; + + ShapeIndex operand_index = {0}; + for (int64 i : index) { + operand_index.push_back(i); + } + + const HloValueSet& operand_value_set = + GetValueSet(recv_done->operand(0), operand_index); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { CHECK_EQ(call->opcode(), HloOpcode::kCall); InstructionValueSet& value_set = GetInstructionValueSet(call); @@ -254,6 +333,21 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { return false; } +bool HloDataflowAnalysis::UpdateConditionalValueSet( + HloInstruction* conditional) { + CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); + std::vector inputs = { + &GetInstructionValueSet( + conditional->true_computation()->root_instruction()), + &GetInstructionValueSet( + conditional->false_computation()->root_instruction())}; + // A phi-node is not defined for a kConditional instruction even though it + // represents a join point. This is because the current approach is to define + // a phi-node only for kWhile to account for the dataflow through back-edges + // and deal with the ambiguity in other cases. 
+ return GetInstructionValueSet(conditional).AssignUnionOf(inputs); +} + bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { CHECK_EQ(copy->opcode(), HloOpcode::kCopy); bool changed = false; @@ -315,7 +409,7 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { CHECK_EQ(call_graph_node.context(), CallContext::kSequential); std::vector inputs; - bool called_from_while = false; + bool need_phi = false; for (const CallSite& callsite : call_graph_node.caller_callsites()) { if (callsite.instruction()->opcode() == HloOpcode::kCall) { // The operand values of a call instruction are forwarded to the @@ -337,14 +431,32 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { inputs.push_back(&GetInstructionValueSet( callsite.instruction()->while_body()->root_instruction())); } - called_from_while = true; + need_phi = true; + } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) { + CHECK_EQ(parameter->parameter_number(), 0); + auto conditional = callsite.instruction(); + // Conditional has 3 operands. Operand 0 is the predicate, operand 1 is + // the argument to the true computation and operand 2 is the argument to + // the false computation. + // + // If the parameter belongs to conditional's true computation, then + // operand 1 is forwarded to this parameter instruction. If the parameter + // belongs to conditional's false computation, then operand 2 is forwarded + // to this parameter instruction. 
+ if (parameter->parent() == conditional->true_computation()) { + inputs.push_back(&GetInstructionValueSet(conditional->operand(1))); + } else { + CHECK_EQ(parameter->parent(), conditional->false_computation()); + inputs.push_back(&GetInstructionValueSet(conditional->operand(2))); + } + need_phi = true; } else { LOG(FATAL) << "CallContext::kSequential computations should only be " - "called from call or while instructions"; + "called from call, while, or conditional instructions"; } } - if (ssa_form_ && called_from_while) { + if (ssa_form_ && need_phi) { return Phi(parameter, inputs); } else { return GetInstructionValueSet(parameter).AssignUnionOf(inputs); @@ -429,6 +541,12 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateCallValueSet(instruction); case HloOpcode::kWhile: return UpdateWhileValueSet(instruction); + case HloOpcode::kSend: + return UpdateSendValueSet(instruction); + case HloOpcode::kRecvDone: + return UpdateRecvDoneValueSet(instruction); + case HloOpcode::kConditional: + return UpdateConditionalValueSet(instruction); default: // Instruction does not forward HloValues (it defines all values in its // output). No update is necessary. @@ -436,11 +554,13 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( } } -void HloDataflowAnalysis::UpdateInstructionsAndPropagate( - tensorflow::gtl::ArraySlice instructions) { +void HloDataflowAnalysis::Propagate() { std::queue worklist; - for (HloInstruction* instruction : instructions) { - worklist.push(instruction); + + for (HloComputation* computation : module_->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + worklist.push(instruction); + } } while (!worklist.empty()) { @@ -465,13 +585,31 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate( // If user sequentially calls a computation, then the respective // parameter(s) of the computation need to be updated. 
- for (HloComputation* called_computation : user->called_computations()) { - const CallGraphNode& call_graph_node = - call_graph_->GetNode(called_computation); - if (call_graph_node.context() == CallContext::kSequential) { - for (int64 operand_number : user->OperandIndices(instruction)) { - worklist.push( - called_computation->parameter_instruction(operand_number)); + if (user->opcode() == HloOpcode::kConditional) { + // If operand 0 is the use of instruction, then no parameters need to be + // updated, since that is the predicate of the conditional. + // If operand 1 is the use of instruction, then the true_computation's + // parameter need to be updated. + // If operand 2 is the use of instruction, then the false_computation's + // parameter need to be updated. + // + // Note that the same instruction can be used in both operand 1 and + // operand 2. + if (user->operand(1) == instruction) { + worklist.push(user->true_computation()->parameter_instruction(0)); + } + if (user->operand(2) == instruction) { + worklist.push(user->false_computation()->parameter_instruction(0)); + } + } else { + for (HloComputation* called_computation : user->called_computations()) { + const CallGraphNode& call_graph_node = + call_graph_->GetNode(called_computation); + if (call_graph_node.context() == CallContext::kSequential) { + for (int64 operand_number : user->OperandIndices(instruction)) { + worklist.push( + called_computation->parameter_instruction(operand_number)); + } } } } @@ -483,7 +621,8 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate( const CallGraphNode& call_graph_node = call_graph_->GetNode(instruction->parent()); for (const CallSite& callsite : call_graph_node.caller_callsites()) { - if (callsite.instruction()->opcode() == HloOpcode::kCall) { + if ((callsite.instruction()->opcode() == HloOpcode::kCall) || + (callsite.instruction()->opcode() == HloOpcode::kConditional)) { worklist.push(callsite.instruction()); } else if (callsite.instruction()->opcode() == 
HloOpcode::kWhile) { // Add the while itself, and the body and condition parameters. @@ -537,6 +676,12 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { GetValueSet(instruction, /*index=*/{}).AddValue(value); }; + // Lambda to set the value set at the given index of the output. + auto define_value_at = [this, &instruction](const ShapeIndex& index) { + HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false); + GetValueSet(instruction, index).AddValue(value); + }; + switch (instruction->opcode()) { case HloOpcode::kBitcast: if (bitcast_defines_value_) { @@ -545,6 +690,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { break; case HloOpcode::kWhile: case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kGetTupleElement: // These instructions define no values. The values in their output // flow from their operands or from cross computation dataflow. @@ -577,6 +723,16 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { // values flow from their operands. define_top_level_only(); break; + case HloOpcode::kRecvDone: + // RecvDone aliases its input tuple element {0}, therefore does not + // define any values. + break; + case HloOpcode::kSend: + // Send produces a tuple of {aliased operand, U32 context}, therefore + // only defines the top-level tuple and the tuple element at {1}. + define_value_at(/*index=*/{}); + define_value_at(/*index=*/{1}); + break; default: define_all_values(); break; @@ -597,20 +753,17 @@ StatusOr> HloDataflowAnalysis::Run( new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); + dataflow_analysis->Propagate(); - // Construct list of all instructions to initialize the worklist to propagate - // the data flow. For efficiency sort the instruction in post order so - // producers appear before consumers. 
- std::vector all_instructions; - for (const HloComputation* computation : module->MakeComputationPostOrder()) { - for (HloInstruction* instruction : - computation->MakeInstructionPostOrder()) { - all_instructions.push_back(instruction); - } - } - dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions); + // Delete all values marked for deletion. + dataflow_analysis->DeleteMarkedValues(); - // Add in positions to all values. + // Gather and set all non-definition positions of all values. Value deletion + // is rare, so just use a vector indexed by Value::Id rather than a map from + // Value::Id to positions. There should be very few holes in the vector, and + // lookup is faster. + std::vector> value_positions( + dataflow_analysis->next_value_id_); for (const HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { for (const auto& pair : @@ -619,13 +772,18 @@ StatusOr> HloDataflowAnalysis::Run( const HloValueSet& value_set = pair.second; for (const HloValue* value : value_set.values()) { if (value->defining_instruction() != instruction) { - dataflow_analysis->GetValue(value->id()) - .AddPosition(instruction, index); + value_positions[value->id()].push_back( + HloPosition{instruction, index}); } } } } } + for (auto& pair : dataflow_analysis->values_) { + HloValue::Id value_id = pair.first; + HloValue& value = pair.second; + value.SetPositionsAndComputeUses(value_positions[value_id]); + } // Construct vector of values. 
dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size()); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 207e553bf7fb62e19b9fa89eaf6bfb3234592c11..469620d01295f90e0c36a48cac9be47c12473a68 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -126,13 +126,16 @@ class HloDataflowAnalysis { HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); - // Delete the HloValue with the given ID. - void DeleteHloValue(HloValue::Id value_id); + // Mark the HloValue with the given ID for deletion. + void MarkValueForDeletion(HloValue::Id value_id); + + // Delete all HloValues marked for deletion. Should be called after + // propagation is complete. + void DeleteMarkedValues(); // Constructs and initializes the InstructionValueSets of all instructions to // contain exactly the HloValues defined by each instruction. These values can - // then propagated throughout the HLO graph by calling - // UpdateInstructionsAndPropagate. + // then propagated throughout the HLO graph by calling Propagate. Status InitializeInstructionValueSets(); // Updates the value set of the given instruction based on the values flowing @@ -143,17 +146,18 @@ class HloDataflowAnalysis { // the instruction value set changed. 
bool UpdateBitcastValueSet(HloInstruction* bitcast); bool UpdateCallValueSet(HloInstruction* call); + bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); + bool UpdateRecvDoneValueSet(HloInstruction* recv_done); bool UpdateSelectValueSet(HloInstruction* select); + bool UpdateSendValueSet(HloInstruction* send); bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateWhileValueSet(HloInstruction* xla_while); - // Update the value sets of the given instructions and propagate the - // changes to fixed point. - void UpdateInstructionsAndPropagate( - tensorflow::gtl::ArraySlice instructions); + // Propagate the dataflow through the module. + void Propagate(); // Return the result of the SSA Phi function applied to the given inputs at // the given instruction. If skip_top_level is true, then the top level of the @@ -189,6 +193,11 @@ class HloDataflowAnalysis { // A map from instruction to InstructionValueSet. std::unordered_map value_sets_; + // Values marked for deletion during construction. We don't delete them + // immediately because references to them may remain in ValueSets temporarily + // during propagation. After construction, these values are deleted. + std::vector value_ids_to_delete_; + // A vector containing all HloValues sorted by HloValue::Id. std::vector values_vector_; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 4b8eb237a6712804657bb7b67cdde9a2d331bd11..e714b2567fd1b3eab607a19f0bb7e3288150dc64 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -34,6 +34,7 @@ limitations under the License. 
namespace xla { namespace { +using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; // Test is parameterized on a bool which is whether the dataflow analysis is @@ -77,11 +78,23 @@ class HloDataflowAnalysisTest : public HloTestBase, analysis_->GetValueDefinedAt(b), *analysis_); } + std::unique_ptr CreateR0F32UnaryOpComputation( + HloOpcode opcode) { + HloComputation::Builder builder(TestName() + "." + HloOpcodeString(opcode)); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, opcode, param0)); + return builder.Build(); + } + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42}); + const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}); }; TEST_P(HloDataflowAnalysisTest, BinaryOperation) { @@ -211,10 +224,10 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { HloPosition{nested_tuple, {0, 0}}, HloPosition{nested_tuple, {1, 0}}, HloPosition{nested_tuple, {2}}, HloPosition{gte_tuple, {0}}, HloPosition{gte_out, {}})); - // Constant values should have no uses though one is live out. The positions - // where they appear as operands are on instructions which do not use the - // values (eg, Tuple). - EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); + // Constant values should have only a single use, which is the root of the + // computation. + EXPECT_THAT(analysis.GetValueDefinedAt(constant1, /*index=*/{}).uses(), + UnorderedElementsAre(HloUse{gte_out, 0, {0}})); EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); // The top-level tuple values are used in GTE instructions. 
@@ -274,12 +287,11 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { EXPECT_EQ(analysis.GetUniqueValueAt(call), analysis.GetValueDefinedAt(add)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}})); + UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}})); + UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{add, 1, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { @@ -323,18 +335,17 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}})); + UnorderedElementsAre(HloUse{call1, 0, {}}, HloUse{call2, 0, {}}, + HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}})); + UnorderedElementsAre(HloUse{call1, 1, {}}, HloUse{call2, 1, {}}, + HloUse{add, 1, {}})); // The Add from the subcomputation is used as both operands of the Subtract. EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(), UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}})); EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); - EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { @@ -408,7 +419,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto outer_param1 = outer_builder.AddInstruction( HloInstruction::CreateParameter(1, scalar_shape_, "param1")); // Swizzle parameters. 
- outer_builder.AddInstruction(HloInstruction::CreateCall( + auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {outer_param1, outer_param0}, inner_computation)); HloComputation* outer_computation = module_->AddEmbeddedComputation(outer_builder.Build()); @@ -418,7 +429,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(2.0))); - builder.AddInstruction(HloInstruction::CreateCall( + auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, outer_computation)); module_->AddEntryComputation(builder.Build()); @@ -431,10 +442,14 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { // Verify that the uses of the constants are properly swizzled by parameter // permutation in nested_call. - EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 1, {}})); - EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 0, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}}, + HloUse{add, 1, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}}, + HloUse{add, 0, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); } @@ -469,7 +484,7 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); - body_builder.AddInstruction( + auto body_root = body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); @@ 
-496,8 +511,6 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_TRUE( - analysis.GetValueDefinedAt(cond_constant).live_out_of_computation()); EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module()); if (ssa_form) { @@ -517,14 +530,14 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_THAT( analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}})); + UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{body_root, 0, {}}, + HloUse{xla_while, 0, {0}})); // Constant1 passes through the body and out of the module. EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) .live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); } else { // While instruction and subcomputation parameters should not define values @@ -538,7 +551,6 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } } @@ -915,9 +927,11 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { HloUse{select12, 1, {}})); // The two constant values just pass through the Selects and are not - // used. They are live out however. - EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); - EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); + // used except at the root. They are live out however. 
+ EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{select1234, 1, {0}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{select1234, 1, {0}})); EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); } @@ -1139,6 +1153,54 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) { analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module()); } +TEST_P(HloDataflowAnalysisTest, SendAndSendDone) { + // Test that a Send forwards its operand to the output tuple at {0}. + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + auto send = builder.AddInstruction( + HloInstruction::CreateSend(param, /*channel_id=*/0)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + module_->AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + EXPECT_EQ(analysis.values().size(), 4); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(param)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done)); + EXPECT_THAT(HloValuesAt(send, /*index=*/{0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(param))); +} + +TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) { + // Test that a RecvDone forwards its operand tuple element at {0} to the + // output. 
+ auto builder = HloComputation::Builder(TestName()); + auto recv = builder.AddInstruction( + HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + module_->AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + EXPECT_EQ(analysis.values().size(), 3); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done)); + EXPECT_THAT(HloValuesAt(recv_done), + UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0}))); + EXPECT_TRUE( + analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module()); +} + TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) { // A simple chain of elementwise operations. No values should interfere. // @@ -1270,7 +1332,7 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { auto entry = module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); - const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + RunAnalysis(ssa_form); SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param, xla_while}}); @@ -1281,12 +1343,6 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { SequentialHloOrdering ordering(module_.get(), sequence); - // 'add' is the body root even though later instructions follow in the order - // like 'dead_negate'. Only 'add' should be live out of the computation. - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); - EXPECT_FALSE( - analysis.GetValueDefinedAt(dead_negate).live_out_of_computation()); - // 'add' is live out of the body and will interfere with an later instructions // such as 'dead_constant' and 'dead_negate'. 
EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant)); @@ -1485,6 +1541,315 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) { EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log)); } +TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { + // Test conditional with identity computations in both true and false cases. + // + // true_computation(F32[] %true_param): + // return %true_param + // + // false_computation(F32[] %false_param): + // return %false_param + // + // entry: + // %pred = Constant(true) + // %constant1 = Constant(56.0) + // %constant2 = Constant(12.0) + // return Conditional(%pred, %constant1, true_computation, + // %constant2, false_computation) + + auto true_builder = HloComputation::Builder(TestName() + "_true"); + auto true_param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "true_param")); + HloComputation* true_computation = + module_->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_false"); + auto false_param = false_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "false_param")); + HloComputation* false_computation = + module_->AddEmbeddedComputation(false_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred, constant1, true_computation, constant2, + false_computation)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred)); + 
EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param)); + + EXPECT_EQ(analysis.GetUniqueValueAt(true_param), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_param), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(), + ElementsAre(HloUse{conditional, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + ElementsAre(HloUse{conditional, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + ElementsAre(HloUse{conditional, 2, {}})); + + EXPECT_EQ(analysis.values().size(), 3); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2))); +} + +TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { + // Test conditional with true and false computations taking a tuple operand. 
+ // + // true_computation((F32[], F32[]) %true_param): + // %true_x = GetTupleElement(%true_param, 0) + // %true_y = GetTupleElement(%true_param, 1) + // return Add(%true_x, %true_y) + // + // false_computation((F32[], F32[]) %false_param): + // %false_x = GetTupleElement(%false_param, 0) + // %false_y = GetTupleElement(%false_param, 1) + // return Subtract(%false_x, %false_y) + // + // entry: + // %pred = Constant(true) + // %constant1 = Constant(56.0) + // %constant2 = Constant(12.0) + // %tuple_operand = Tuple(%constant1, %constant2) + // return Conditional(%pred, %tuple_operand, true_computation, + // %tuple_operand, false_computation) + + auto true_builder = HloComputation::Builder(TestName() + "_true"); + auto true_param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "true_param")); + auto true_x = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0)); + auto true_y = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1)); + auto add = true_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, true_x, true_y)); + HloComputation* true_computation = + module_->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_false"); + auto false_param = false_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "false_param")); + auto false_x = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0)); + auto false_y = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1)); + auto sub = false_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kSubtract, false_x, false_y)); + HloComputation* false_computation = + module_->AddEmbeddedComputation(false_builder.Build()); + + auto builder = 
HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + auto tuple_operand = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred, tuple_operand, true_computation, tuple_operand, + false_computation)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y)); + + EXPECT_EQ(analysis.GetUniqueValueAt(true_param), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_param), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_x), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_y), + analysis.GetValueDefinedAt(constant2)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_x), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_y), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(), + 
ElementsAre(HloUse{conditional, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{conditional, 1, {0}}, + HloUse{conditional, 2, {0}}, + HloUse{add, 0, {}}, HloUse{sub, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{conditional, 1, {1}}, + HloUse{conditional, 2, {1}}, + HloUse{add, 1, {}}, HloUse{sub, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(tuple_operand).uses(), + UnorderedElementsAre( + HloUse{conditional, 1, {}}, HloUse{conditional, 2, {}}, + HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}}, + HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}})); + + EXPECT_EQ(analysis.values().size(), 6); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(add), + analysis.GetValueDefinedAt(sub))); +} + +TEST_P(HloDataflowAnalysisTest, NestedConditionals) { + // computation1(F32[] %param1): + // %ceil = Ceil(%param1) + // return %ceil + // + // computation2(F32[] %param2): + // %floor = Floor(%param2) + // return %floor + // + // computation3(F32[] %param3): + // %negate = Negate(%param3) + // return %negate + // + // inner_conditional((PRED, F32[], F32[]) %param_cond): + // %pred_cond = GetTupleElement(%param_cond, 0) + // %true_operand_cond = GetTupleElement(%param_cond, 1) + // %false_operand_cond = GetTupleElement(%param_cond, 2) + // return Conditional(%pred_cond, %true_operand_cond, computation1, + // %false_operand_cond, computation2) + // + // entry: + // %pred1 = Constant(true) + // %pred2 = Constant(false) + // %constant1 = Constant(1.1); + // %constant2 = Constant(2.2); + // %constant3 = Constant(3.3); + // return Conditional(%pred1, (%pred2, %constant1, %constant2), + // inner_conditional, %constant3, computation3) + + auto computation1 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kCeil)); + auto computation2 = 
module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kFloor)); + auto computation3 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kNegate)); + + // Build inner_conditional computation. + const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {}); + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {scalar_bool_shape, scalar_shape_, scalar_shape_}); + auto inner_builder = + HloComputation::Builder(TestName() + "_inner_conditional"); + auto param_cond = inner_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond")); + auto pred_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0)); + auto true_operand_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1)); + auto false_operand_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2)); + auto inner_conditional = + inner_builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred_cond, true_operand_cond, computation1, + false_operand_cond, computation2)); + auto inner_conditional_computation = + module_->AddEmbeddedComputation(inner_builder.Build()); + + // Build entry computation. 
+ auto builder = HloComputation::Builder(TestName()); + auto pred1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto pred2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.2f))); + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(3.3f))); + auto tuple_operand = builder.AddInstruction( + HloInstruction::CreateTuple({pred2, constant1, constant2})); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred1, tuple_operand, inner_conditional_computation, + constant3, computation3)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction())); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction())); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction())); + + auto computation1_param = computation1->parameter_instruction(0); + auto computation2_param = computation2->parameter_instruction(0); + auto computation3_param = computation3->parameter_instruction(0); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param), + 
analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param), + analysis.GetValueDefinedAt(constant2)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param), + analysis.GetValueDefinedAt(constant3)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond)); + EXPECT_EQ(analysis.GetUniqueValueAt(param_cond), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond), + analysis.GetValueDefinedAt(pred2)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_EQ(analysis.values().size(), 9); + EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT( + HloValuesAt(inner_conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()))); + EXPECT_THAT( + HloValuesAt(conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()), + analysis.GetValueDefinedAt(computation3->root_instruction()))); +} + INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index a4921232f5848dbe1789c4c641e2b0ba3c1848bb..1e5f0f797a13fd7e7ce1cc934387a274a74153bc 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -37,6 +37,9 @@ namespace xla { StatusOr HloDCE::Run(HloModule* module) { bool changed = 
false; + VLOG(2) << "Before dce:"; + XLA_VLOG_LINES(2, module->ToString()); + for (auto* computation : module->MakeNonfusionComputations()) { std::unordered_set live_instructions; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( @@ -52,12 +55,15 @@ StatusOr HloDCE::Run(HloModule* module) { for (auto* instruction : computation->instructions()) { if (instruction->user_count() == 0 && live_instructions.count(instruction) == 0 && - computation->IsRemovable(instruction)) { + computation->IsRemovable(instruction) && + !instruction->HasSideEffect()) { dead_roots.push_back(instruction); } } for (HloInstruction* dead_root : dead_roots) { + VLOG(1) << "Removing dead root " << dead_root->ToString() + << " and it's unused operands"; TF_RETURN_IF_ERROR( computation->RemoveInstructionAndUnusedOperands(dead_root)); changed = true; @@ -87,6 +93,9 @@ StatusOr HloDCE::Run(HloModule* module) { } } + VLOG(2) << "After dce:"; + XLA_VLOG_LINES(2, module->ToString()); + return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index d54b9a27087a42fd23eab0bd06e8deaca567312b..5a56607a665c4cbeb7b2572f182b88e890602968 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -70,6 +70,26 @@ TEST_F(HloDceTest, NoDeadCode) { EXPECT_EQ(3, computation->instruction_count()); } +TEST_F(HloDceTest, InstructionsWithSideEffect) { + // Verify that side-effect instructions (Send in this test) are not removed. 
+ auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateSend(constant, /*channel_id=*/0)); + builder.AddInstruction(HloInstruction::CreateTuple({})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(3, computation->instruction_count()); + + HloDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(3, computation->instruction_count()); +} + TEST_F(HloDceTest, DeadParameters) { // Verify that dead parameters are not removed, but use of the dead parameters // are. diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc new file mode 100644 index 0000000000000000000000000000000000000000..1773bb401d380031f6c860d295e76d2f62c9e5ff --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -0,0 +1,137 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace { + +HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) { + if (hlo->shape().element_type() != type) { + Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type); + hlo = hlo->parent()->AddInstruction( + HloInstruction::CreateConvert(shape, hlo)); + } + CHECK_EQ(hlo->shape().element_type(), type); + return hlo; +} + +bool HasOperandType(HloInstruction* hlo, PrimitiveType type) { + for (HloInstruction* operand : hlo->operands()) { + if (operand->shape().element_type() == type) { + return true; + } + } + return false; +} + +} // namespace + +HloElementTypeConverter::HloElementTypeConverter( + PrimitiveType eliminate_type, PrimitiveType replace_with_type) + : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {} + +StatusOr HloElementTypeConverter::Run(HloModule* module) { + XLA_VLOG_LINES( + 3, "HloElementTypeConverter::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* computation : module->computations()) { + for (auto* hlo : computation->MakeInstructionPostOrder()) { + // These are ops where it does not make sense to convert them. 
+ if (hlo->opcode() == HloOpcode::kParameter || + hlo->opcode() == HloOpcode::kConstant || + hlo->opcode() == HloOpcode::kTuple || + hlo->opcode() == HloOpcode::kConvert || + hlo->opcode() == HloOpcode::kGetTupleElement || + hlo->opcode() == HloOpcode::kInfeed || + hlo->opcode() == HloOpcode::kOutfeed) { + continue; + } + + // We cannot change a CustomCall since we have no way of adjusting the + // called binary to expect the updated type. + if (hlo->opcode() == HloOpcode::kCustomCall) { + continue; + } + + // These are ops with embedded computations where it suffices to convert + // the embedded computations instead of converting the ops themselves. + if (hlo->opcode() == HloOpcode::kWhile || + hlo->opcode() == HloOpcode::kCall || + hlo->opcode() == HloOpcode::kFusion || + hlo->opcode() == HloOpcode::kMap || + hlo->opcode() == HloOpcode::kReduce || + hlo->opcode() == HloOpcode::kReduceWindow || + hlo->opcode() == HloOpcode::kSelectAndScatter || + hlo->opcode() == HloOpcode::kConditional) { + continue; + } + TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); + + if (!HasOperandType(hlo, eliminate_type_)) { + // If this CHECK fires, then this was an instruction that does not take + // the elimination type as an operand but it does return it. This pass + // does not have a feature to change the output type in that case, so + // instead of silently failing to eliminate the type, it fails loudly. 
+ TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_); + continue; + } + + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + if (operand->shape().element_type() == eliminate_type_) { + operand = ToElementType(operand, replace_with_type_); + } + new_operands.push_back(operand); + } + + HloInstruction* new_hlo; + if (hlo->shape().element_type() == eliminate_type_) { + Shape shape = + ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); + new_hlo = ToElementType(new_hlo, eliminate_type_); + } else { + new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( + hlo->shape(), new_operands, hlo->GetModule())); + } + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo)); + changed = true; + } + } + XLA_VLOG_LINES( + 2, "HloElementTypeConverter::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h new file mode 100644 index 0000000000000000000000000000000000000000..2b109225d0b192e5c9e4f6d841377ffad8078dc2 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass that eliminates certain element types as the input or output of ops by +// inserting Convert ops. This allows a backend to support an element type while +// only actually implementing the Convert op for that element type. This is +// generally not the fastest approach, but it works. +class HloElementTypeConverter : public HloPassInterface { + public: + // eliminate_type is the type to eliminate as the input or output of ops, + // using Convert ops to replace it with replace_with_type. + HloElementTypeConverter(PrimitiveType eliminate_type, + PrimitiveType replace_with_type); + + tensorflow::StringPiece name() const override { + return "element_type_converter"; + } + + // Runs the pass on the module and returns whether the module was modified. 
+ StatusOr Run(HloModule* module) override; + + private: + PrimitiveType eliminate_type_; + PrimitiveType replace_with_type_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 88b77ccdd03eb129f81cfa1da430e882ea569df4..e693d167a1f96f65b894d07fb2c8f33e61ff8c49 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -335,9 +335,31 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + return ~elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + return !elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = + nullptr> Status HandleNot(HloInstruction* not_) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { @@ -357,7 +379,24 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleNot(not_); } - Status HandleNegate(HloInstruction* negate) override { + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { + using type = typename std::make_unsigned::type; + TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate], + ElementWiseUnaryOp(negate, [](ReturnT elem_operand) { + return NativeT(-type(elem_operand)); + })); + return Status::OK(); + } + + template ::value || + 
std::is_floating_point::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate], ElementWiseUnaryOp(negate, [](ReturnT elem_operand) { return -elem_operand; @@ -365,6 +404,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleNegate(HloInstruction* negate) override { + return HandleNegate(negate); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -402,7 +445,26 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleMultiply(HloInstruction* multiply) override { + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + using type = typename std::make_unsigned::type; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) { + return NativeT(type(lhs_elem) * type(rhs_elem)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + std::is_floating_point::value || + is_complex_t::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[multiply], ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) { @@ -411,6 +473,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleMultiply(HloInstruction* multiply) override { + return HandleMultiply(multiply); + } + Status HandleSubtract(HloInstruction* subtract) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[subtract], @@ -516,9 +582,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleRemainder(remainder); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> + Status 
HandleAnd(HloInstruction* and_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el & rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[and_], @@ -539,9 +616,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleAnd(and_); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> + Status HandleOr(HloInstruction* or_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el | rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> Status HandleOr(HloInstruction* or_) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[or_], @@ -645,7 +733,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value>::type* = nullptr> Status HandleClamp(HloInstruction* clamp) { std::function clamp_op = - [](ReturnT low, ReturnT high, ReturnT value) { + [](ReturnT low, ReturnT value, ReturnT high) { return std::fmax(low, std::fmin(value, high)); }; TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp], @@ -724,7 +812,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); const auto& dnums = conv->convolution_dimension_numbers(); - const int64 num_spatial_dims = dnums.spatial_dimensions_size(); + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); CHECK_GE(num_spatial_dims, 0); CHECK_EQ(window.dimensions_size(), num_spatial_dims); @@ -789,13 +878,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { // Find 
corresponding spatial dimension index for input (lhs). for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { // Spatial dimension number for input (lhs) and output. - const int64 spatial_dim = dnums.spatial_dimensions(ki); + const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); + const int64 output_spatial_dim = + dnums.output_spatial_dimensions(ki); // Calculate lhs (input) index without taking base dilation into // account. const auto& window_dim = window.dimensions(ki); const int64 undilated_index = - out_index[spatial_dim] * window_dim.stride() - + out_index[output_spatial_dim] * window_dim.stride() - window_dim.padding_low() + rhs_spatial_index[ki] * window_dim.window_dilation(); // Skip if the lhs (input) index is to be dilated. @@ -804,23 +895,26 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } // Calculate the actual lhs (input) index after dilation. - lhs_index[spatial_dim] = + lhs_index[input_spatial_dim] = undilated_index / window_dim.base_dilation(); // Skip if input index is not in bound. - if (!(lhs_index[spatial_dim] >= 0 && - lhs_index[spatial_dim] < lhs_shape.dimensions(spatial_dim))) { + if (!(lhs_index[input_spatial_dim] >= 0 && + lhs_index[input_spatial_dim] < + lhs_shape.dimensions(input_spatial_dim))) { goto cnt; } rhs_index[dnums.kernel_spatial_dimensions(ki)] = - rhs_spatial_index[ki]; + window_dim.window_reversal() + ? 
((window_dim.size() - 1) - rhs_spatial_index[ki]) + : rhs_spatial_index[ki]; } result_val += lhs_literal.Get(lhs_index) * rhs_literal.Get(rhs_index); } - cnt:; + cnt : {} } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); return result_val; @@ -1287,6 +1381,50 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], + ElementWiseUnaryOp(sin, [](ReturnT elem_operand) { + return std::sin(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + return InvalidArgument("Unsupported type for Sin"); + } + + Status HandleSin(HloInstruction* sin) override { + return HandleSin(sin); + } + + template ::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], + ElementWiseUnaryOp(cos, [](ReturnT elem_operand) { + return std::cos(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + return InvalidArgument("Unsupported type for Cos"); + } + + Status HandleCos(HloInstruction* cos) override { + return HandleCos(cos); + } + private: template StatusOr> DynamicSlice( @@ -1397,8 +1535,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const auto* rhs = instruction->operand(1); const auto* ehs = instruction->operand(2); - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. + // TODO(b/35950897, b/27796129): add DCHECK back once implicit + // broadcast is removed. 
if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { @@ -1450,6 +1588,10 @@ HloEvaluator::HloEvaluator() { typed_visitors_[F32] = MakeUnique>(this); typed_visitors_[F64] = MakeUnique>(this); typed_visitors_[C64] = MakeUnique>(this); + + typed_visitors_[BF16] = MakeUnique([](HloInstruction*) { + return Unimplemented("HloEvaluator: unhandled primitive type: BF16."); + }); typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); }); @@ -1561,6 +1703,7 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( } std::vector operands; + operands.reserve(owned_operands.size()); for (auto& operand : owned_operands) { operands.push_back(operand.get()); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 67b6e215fcb23598f1a8ab6212d6e7e58a64e976..7557aaa2484d184555411a79d8dce2c9241427b0 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -39,16 +39,18 @@ class HloEvaluator : public DfsHloVisitorWithDefault { HloEvaluator(); // Evaluates an HLO module and an array of pointers to literals. // Returns the evaluated result as a literal if successful. - // Precondition: argument literals correspond to each input computation's - // parameters in their post-ordering. See comment below for example. + // Precondition: The indices of arg_literals correspond to the parameter + // numbers of the HLO parameters in the computation. See comment below for an + // example. StatusOr> Evaluate( const HloModule& module, tensorflow::gtl::ArraySlice arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. 
- // Precondition: argument literals correspond to the input computation's - // parameters in their post-ordering. For e.g., consider the following graph: + // Precondition: The indices of arg_literals correspond to the parameter + // numbers of the HLO parameters in the computation. For e.g., consider the + // following graph: // // * // / \ @@ -57,8 +59,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // / \ // Parameter0 Constant // - // The input literals array will have its first literal map to Parameter0 and - // the second map to Parameter1. + // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number + // 1 in this computation. The input literals array will then have its first + // literal map to Parameter0 and the second map to Parameter1. StatusOr> Evaluate( const HloComputation& computation, tensorflow::gtl::ArraySlice arg_literals); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 85477af6fe26f53504c07204348566c16a24392c..a5d39fe08699f1ec17462f3ac5600fbe2191f307 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -46,20 +46,57 @@ class HloEvaluatorTest : public HloVerifiedTestBase { HloEvaluatorTest() { evaluator_ = MakeUnique(); } std::unique_ptr evaluator_; + + void TestUnaryOp(HloOpcode opcode, std::unique_ptr expected, + std::unique_ptr input, float aabs = 0) { + HloComputation::Builder b(TestName()); + auto c1 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); + auto instruction = b.AddInstruction( + HloInstruction::CreateUnary(expected->shape(), opcode, c1)); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + + auto element_type = expected->shape().element_type(); + if (element_type == F32 || element_type == F64) { + ErrorSpec error(aabs); + 
LiteralTestUtil::ExpectNear(*expected, *result, error); + } else { + LiteralTestUtil::ExpectEqual(*expected, *result); + } + } + + void TestBinaryOp(HloOpcode opcode, std::unique_ptr expected, + std::unique_ptr lhs, + std::unique_ptr rhs) { + HloComputation::Builder b(TestName()); + auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); + auto instruction = b.AddInstruction( + HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2)); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + + LiteralTestUtil::ExpectEqual(*expected, *result); + } }; // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. TEST_F(HloEvaluatorTest, DoesClamp) { auto low = Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); - auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto value = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); Shape shape = low->shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); - auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); - auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); auto instruction = b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); @@ -72,6 +109,28 @@ TEST_F(HloEvaluatorTest, DoesClamp) { LiteralTestUtil::ExpectEqual(*expected, *result); } +TEST_F(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { + auto low = Literal::CreateR0(0.f); + auto value = Literal::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); + auto high 
= Literal::CreateR0(1.f); + + Shape shape = value->shape(); + HloComputation::Builder b(TestName()); + auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); + auto instruction = b.AddInstruction( + HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + + auto expected = Literal::CreateR2({{0, 0}, {1, 1}}); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. TEST_F(HloEvaluatorTest, DoesSelect) { @@ -103,120 +162,101 @@ TEST_F(HloEvaluatorTest, DoesSelect) { TEST_F(HloEvaluatorTest, DoesAdd) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - - Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); - auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); - auto instruction = b.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1, c2)); - module().AddEntryComputation(b.Build()); - - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - auto expected = Literal::CreateR2({{3, 4}, {-96, 8}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs), + std::move(rhs)); +} +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise and with 2 operands. 
+TEST_F(HloEvaluatorTest, DoesAnd) { + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); + auto expected = Literal::CreateR2({{0, 0}, {4, 4}}); + TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs), + std::move(rhs)); +} +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise or with 2 operands. +TEST_F(HloEvaluatorTest, DoesOr) { + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); + auto expected = Literal::CreateR2({{3, 4}, {-100, 4}}); + TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs), + std::move(rhs)); +} +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise multiply with 2 operands. +TEST_F(HloEvaluatorTest, DoesMultiply) { + auto lhs = Literal::CreateR2({{-1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2( + {{std::numeric_limits::min(), 4}, {4, 4}}); + auto expected = Literal::CreateR2( + {{std::numeric_limits::min(), 0}, {-400, 16}}); + TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs), + std::move(rhs)); } - // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. 
TEST_F(HloEvaluatorTest, DoesDivideInt64) { - auto lhs_s64 = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs_s64 = Literal::CreateR2({{2, 4}, {4, 4}}); - - Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1_s64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_s64))); - auto c2_s64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_s64))); - auto instruction = b.AddInstruction(HloInstruction::CreateBinary( - shape_s64, HloOpcode::kDivide, c1_s64, c2_s64)); - module().AddEntryComputation(b.Build()); - - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), + std::move(rhs)); } TEST_F(HloEvaluatorTest, DoesDivideDouble) { - auto lhs_f64 = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); - auto rhs_f64 = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); - - Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1_f64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_f64))); - auto c2_f64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_f64))); - auto instruction = b.AddInstruction(HloInstruction::CreateBinary( - shape_f64, HloOpcode::kDivide, c1_f64, c2_f64)); - module().AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - + auto lhs = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); + auto rhs = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); auto expected = Literal::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), + 
std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. TEST_F(HloEvaluatorTest, DoesAbsR2) { auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); - const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = - b.AddInstruction(HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1)); - module().AddEntryComputation(b.Build()); - - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } TEST_F(HloEvaluatorTest, DoesAbsR0) { - // For R0 literal. - const Shape& r0 = ShapeUtil::MakeShape(F32, {}); auto operand = Literal::CreateR0(-1.0f); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = - b.AddInstruction(HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1)); - module().AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); auto expected = Literal::CreateR0(1.0f); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } TEST_F(HloEvaluatorTest, DoesAbsR1WithZeroSize) { - // For R1 literal with dimension of size 0. 
- Shape empty_r1 = ShapeUtil::MakeShape(F32, {0}); auto operand = Literal::CreateR1({}); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = b.AddInstruction( - HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1)); - module().AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); auto expected = Literal::CreateR1({}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); +} +TEST_F(HloEvaluatorTest, DoesNegateR2) { + auto operand = Literal::CreateR2( + {{0, std::numeric_limits::min()}, {-1, 4}}); + auto expected = + Literal::CreateR2({{0, std::numeric_limits::min()}, {1, -4}}); + TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand)); +} +TEST_F(HloEvaluatorTest, DoesCosR2) { + auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); + auto expected = Literal::CreateR2({{1, -1}, {-1, 1}}); + TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand)); +} +TEST_F(HloEvaluatorTest, DoesSinR2) { + auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); + auto expected = Literal::CreateR2({{0, 0}, {0, 0}}); + TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), + 0x1.0P-20); +} +TEST_F(HloEvaluatorTest, DoesNotR2) { + auto operand = + Literal::CreateR2({{0, std::numeric_limits::min()}, + {-1, std::numeric_limits::max()}}); + auto expected = + Literal::CreateR2({{-1, std::numeric_limits::max()}, + {0, std::numeric_limits::min()}}); + TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand)); } - // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. 
TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { @@ -581,8 +621,11 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = @@ -624,8 +667,11 @@ TEST_F(HloEvaluatorTest, DotRank1AndRank2) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = @@ -665,8 +711,11 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = @@ -711,7 +760,8 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { dnums.set_output_batch_dimension(0); 
dnums.set_input_feature_dimension(1); dnums.set_output_feature_dimension(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_kernel_output_feature_dimension(0); dnums.set_kernel_input_feature_dimension(1); @@ -794,6 +844,85 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { LiteralTestUtil::ExpectEqual(*expected, *result); } +TEST_F(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { + HloComputation::Builder b(TestName()); + + // clang-format off + // Input dimensions: [feature=2, height=3, batch=1, width=4] + Array4D input({ + {{{1, 2, 3, 4}}, + {{5, 6, 7, 8}}, + {{9, 10, 11, 12}}}, + {{{13, 14, 15, 16}}, + {{17, 18, 19, 20}}, + {{21, 22, 23, 24}}} + }); + // Weight dimensions: + // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3] + Array4D weight({{ + {{1, 7, 13}, + {4, 10, 16}}, + {{2, 8, 14}, + {5, 11, 17}}, + {{3, 9, 15}, + {6, 12, 18}} + }}); + // clang-format on + + auto lhs_literal = Literal::CreateR4FromArray4D(input); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + auto rhs_literal = Literal::CreateR4FromArray4D(weight); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse( + rhs_instruction->shape(), rhs_instruction, {3, 1})); + + Window window; + WindowDimension dim; + dim.set_size(3); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + dim.set_window_reversal(true); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(2); + dnums.set_output_batch_dimension(2); + dnums.set_input_feature_dimension(0); + dnums.set_output_feature_dimension(0); + dnums.add_input_spatial_dimensions(1); + 
dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); + + dnums.set_kernel_output_feature_dimension(0); + dnums.set_kernel_input_feature_dimension(2); + dnums.add_kernel_spatial_dimensions(3); + dnums.add_kernel_spatial_dimensions(1); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + auto computation = module().AddEntryComputation(b.Build()); + + std::unique_ptr result = + evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + + // clang-format off + // Result dimensions: [feature=1, height=1, batch=1, width=2] + Array4D expected_array({{{{2514, 2685}}}}); + // clang-format on + auto expected = Literal::CreateR4FromArray4D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { HloComputation::Builder b(TestName()); @@ -843,8 +972,10 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { dnums.set_output_batch_dimension(2); dnums.set_input_feature_dimension(0); dnums.set_output_feature_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); dnums.set_kernel_output_feature_dimension(0); dnums.set_kernel_input_feature_dimension(2); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index bf19bc9309b95f09fc5a36daf3e150f5191d1b8e..0809fe780d21baf366b63bdab118653630c33872 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -26,45 +26,110 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" namespace xla { - -void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, - uint64 cycles_taken) { - hlo_to_cycles_taken_[hlo] = cycles_taken; - profiled_computations_.insert(hlo->parent()); +HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) { + size_t current_profile_index = 0; + for (xla::HloComputation* computation : module.MakeComputationPostOrder()) { + InsertOrDie(&computation_to_profile_idx_, computation, + current_profile_index++); + for (const HloInstruction* instruction : computation->instructions()) { + // For simplicity we track all instructions here, but we could skip + // non-executing instructions like constants and parameters. + InsertOrDie(&instruction_to_profile_idx_, instruction, + current_profile_index++); + } + } } -uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const { - auto iter = hlo_to_cycles_taken_.find(&hlo); - if (iter == hlo_to_cycles_taken_.end()) { - return 0; +std::unique_ptr CreateHloProfilePrinter( + const HloProfileIndexMap& hlo_profile_index_map, + const HloCostAnalysis& cost_analysis) { + using HloComputationInfo = HloProfilePrinter::HloComputationInfo; + using HloInstructionInfo = HloProfilePrinter::HloInstructionInfo; + + HloComputationInfo* computation_infos = + new HloComputationInfo[hlo_profile_index_map.computation_count()]; + + // There are two "indices" in play here. The first one is the index of the + // HloComputationInfo or HloInstructionInfo in the array that contains said + // HloComputationInfo or HloInstructionInfo. The second index is the index of + // the HloComputationInfo or HloInstructionInfo in the profile counters array, + // as decided by hlo_profile_index_map. The latter index is always referred + // to as "profile_index".
+ + size_t computation_index_in_static_data = 0; + size_t max_profile_index = hlo_profile_index_map.total_count(); + for (const auto& pair : hlo_profile_index_map.computation_to_profile_idx()) { + CHECK_LT(pair.second, max_profile_index); + const HloComputation* computation = pair.first; + size_t current_computation_index = computation_index_in_static_data++; + HloComputationInfo* computation_info = + &computation_infos[current_computation_index]; + + computation_info->name = strdup(computation->name().c_str()); + computation_info->profile_index = pair.second; + computation_info->instructions = + new HloInstructionInfo[computation->instruction_count()]; + computation_info->instructions_size = computation->instruction_count(); + + size_t instruction_index_in_static_data = 0; + for (const HloInstruction* hlo : computation->instructions()) { + HloProfilePrinter::HloInstructionInfo* instruction_info = + &computation_info->instructions[instruction_index_in_static_data++]; + instruction_info->long_name = strdup(hlo->ToString().c_str()); + instruction_info->short_name = + strdup(hlo->ToString(/*compact_operands=*/true).c_str()); + instruction_info->category = strdup(hlo->ToCategory().c_str()); + instruction_info->flop_count = cost_analysis.flop_count(*hlo); + instruction_info->transcendental_count = + cost_analysis.transcendental_count(*hlo); + instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo); + instruction_info->optimal_seconds = cost_analysis.optimal_seconds(*hlo); + instruction_info->profile_index = + hlo_profile_index_map.GetProfileIndexFor(*hlo); + CHECK_LT(instruction_info->profile_index, max_profile_index); + } } - return iter->second; + + auto deleter = [](HloProfilePrinter::HloComputationInfo* computation_infos, + int64 computation_infos_size) { + for (int64 i = 0; i < computation_infos_size; i++) { + HloInstructionInfo* instruction_infos = computation_infos[i].instructions; + for (int64 j = 0; j < computation_infos[i].instructions_size; j++) 
{ + // We can't make instruction_infos[j].long_name etc. non-const pointers + // since they may point into static storage, so we have a const_cast + // here. + free(const_cast(instruction_infos[j].long_name)); + free(const_cast(instruction_infos[j].short_name)); + free(const_cast(instruction_infos[j].category)); + } + delete[] instruction_infos; + free(const_cast(computation_infos[i].name)); + } + delete[] computation_infos; + }; + + return MakeUnique( + computation_infos, hlo_profile_index_map.computation_count(), + /*profile_counters_size=*/max_profile_index, deleter); } -string HloExecutionProfile::ToString( - const HloComputation& computation, - const DeviceDescription& device_description, - HloCostAnalysis* cost_analysis) const { - tensorflow::Status analysis_status = computation.Accept(cost_analysis); - if (!analysis_status.ok()) { - return ""; - } +HloExecutionProfile::HloExecutionProfile( + const HloProfilePrinter* hlo_profile_printer, + const HloProfileIndexMap* hlo_profile_index_map) + : hlo_profile_printer_(*hlo_profile_printer), + hlo_profile_index_map_(*hlo_profile_index_map), + profile_counters_( + /*count*/ hlo_profile_index_map_.total_count(), + /*value*/ 0) {} - HumanReadableProfileBuilder builder(computation.name(), - total_cycles_executed(computation), - device_description.clock_rate_ghz()); - for (const auto& item : hlo_to_cycles_taken_) { - const HloInstruction* hlo = item.first; - int64 cycles = item.second; - - builder.AddOp(/*op_name=*/hlo->ToString(), - /*short_name=*/hlo->ToString(/*compact_operands=*/true), - hlo->ToCategory(), cycles, cost_analysis->flop_count(*hlo), - cost_analysis->transcendental_count(*hlo), - cost_analysis->bytes_accessed(*hlo), - cost_analysis->seconds(*hlo)); - } - return builder.ToString(); +void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, + uint64 cycles_taken) { + profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(*hlo)] = + cycles_taken; +} + +uint64 
HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const { + return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(hlo)]; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index cdce77cff427da376109db77c65ec70364e36140..470fd4ce3c205d84152238f4b18daad77e403f68 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -18,7 +18,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -27,6 +29,59 @@ namespace xla { class HloInstruction; +// Maps all HloInstructions and HloComputations in an HloModule to integers. +// These integers form the contiguous range [0, total_count()). +class HloProfileIndexMap { + public: + // Scans `module` to populate this instance of HloProfileIndexMap. 
+ explicit HloProfileIndexMap(const HloModule& module); + + HloProfileIndexMap(const HloProfileIndexMap&) = default; + HloProfileIndexMap(HloProfileIndexMap&&) = default; + + HloProfileIndexMap& operator=(const HloProfileIndexMap&) = default; + HloProfileIndexMap& operator=(HloProfileIndexMap&&) = default; + + size_t GetProfileIndexFor(const HloInstruction& instruction) const { + return FindOrDie(instruction_to_profile_idx(), &instruction); + } + + size_t GetProfileIndexFor(const HloComputation& computation) const { + return FindOrDie(computation_to_profile_idx(), &computation); + } + + size_t instruction_count() const { + return instruction_to_profile_idx().size(); + } + + size_t computation_count() const { + return computation_to_profile_idx().size(); + } + + size_t total_count() const { + return instruction_count() + computation_count(); + } + + const std::unordered_map& + instruction_to_profile_idx() const { + return instruction_to_profile_idx_; + } + + const std::unordered_map& + computation_to_profile_idx() const { + return computation_to_profile_idx_; + } + + private: + std::unordered_map instruction_to_profile_idx_; + std::unordered_map computation_to_profile_idx_; +}; + +// Create an instance of `HloProfilePrinter` that owns its memory. +std::unique_ptr CreateHloProfilePrinter( + const HloProfileIndexMap& hlo_profile_index_map, + const HloCostAnalysis& cost_analysis); + // Describes how much time each HLO operation took. // // Each HloComputation takes a certain number of cycles. This class helps break @@ -35,6 +90,9 @@ class HloExecutionProfile { public: using DeviceDescription = perftools::gputools::DeviceDescription; + HloExecutionProfile(const HloProfilePrinter* hlo_profile_printer, + const HloProfileIndexMap* hlo_profile_index_map); + // Record how many cycles this HLO took to execute. 
void SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken); @@ -44,17 +102,15 @@ class HloExecutionProfile { // Return the number of cycles this computation took to execute. uint64 total_cycles_executed(const HloComputation& computation) const { - auto it = total_cycles_executed_.find(&computation); - if (it != total_cycles_executed_.end()) { - return it->second; - } - return 0; + return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor( + computation)]; } // Record how many cycles a computation took to execute. void set_total_cycles_executed(const HloComputation& computation, uint64 total_cycles_executed) { - total_cycles_executed_[&computation] = total_cycles_executed; + profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(computation)] = + total_cycles_executed; } // Returns a version of the execution profile suitable for performance @@ -63,25 +119,20 @@ class HloExecutionProfile { // for the operations in a given computation. Returns an empty string if it // wasn't possible to generate a printable version. cost_analysis should be a // clean analysis that can be used to visit the computation. - string ToString(const HloComputation& computation, - const DeviceDescription& device_description, - HloCostAnalysis* cost_analysis) const; - - // Returns the computations we have profiled. - std::unordered_set profiled_computations() const { - return profiled_computations_; + string ToString(const DeviceDescription& device_description) const { + return hlo_profile_printer_.ToString(profile_counters_.data(), + device_description.clock_rate_ghz()); } - private: - // Contains a mapping from HLO to the number of cycles it took to execute it. - std::unordered_map hlo_to_cycles_taken_; + std::vector* mutable_profile_counters() { return &profile_counters_; } - // If non-empty, contains the total number of cycles a computation took to - // execute. 
- std::unordered_map total_cycles_executed_; + private: + const HloProfilePrinter& hlo_profile_printer_; + const HloProfileIndexMap& hlo_profile_index_map_; - // The computations we have profiled. - std::unordered_set profiled_computations_; + // Stores per-Hlo profile counters. This is the only thing that changes when + // we execute an XLA computation. + std::vector profile_counters_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b1e6729e2bccad4bdbe075a635d8a9b1ede6fecb --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +class HloExecutionProfileTest : public HloTestBase { + protected: + static constexpr int64 kInstructionCyclesIndex = 0; + static constexpr int64 kInstructionNameIndex = 19; +}; + +// Splits `lines` into a sequence of lines delimited by newlines and then split +// each of those lines into a sequence of words delimited by spaces. 
Filter out +// empty words. +std::vector> SplitIntoLinesAndWords( + tensorflow::StringPiece lines) { + std::vector> result; + for (const string& line : tensorflow::str_util::Split(lines, '\n')) { + std::vector words; + for (const string& word : tensorflow::str_util::Split(line, ' ')) { + if (!word.empty()) { + words.push_back(word); + } + } + result.push_back(std::move(words)); + } + + return result; +} + +TEST_F(HloExecutionProfileTest, Basic) { + std::unique_ptr hlo_module = CreateNewModule(); + + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {30, 30}); + HloInstruction* param_lhs = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); + HloInstruction* param_rhs = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs")); + HloInstruction* add_instruction = + builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, param_lhs, param_rhs)); + HloInstruction* dot_instruction = + builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, param_lhs, add_instruction)); + + hlo_module->AddEntryComputation(builder.Build()); + + auto shape_size_function = [&](const Shape& shape) { + const int64 pointer_size = 8; + if (ShapeUtil::IsOpaque(shape)) { + return pointer_size; + } + return ShapeUtil::ByteSizeOf(shape, pointer_size); + }; + + HloCostAnalysis cost_analysis(shape_size_function); + HloProfileIndexMap profile_index_map(*hlo_module); + std::unique_ptr profile_printer = + CreateHloProfilePrinter(profile_index_map, cost_analysis); + HloExecutionProfile execution_profile(profile_printer.get(), + &profile_index_map); + + const int64 add_cycles = 1000; + const int64 dot_cycles = 4000; + + execution_profile.SetCyclesTakenBy(add_instruction, add_cycles); + execution_profile.SetCyclesTakenBy(dot_instruction, dot_cycles); + + string rendered_profile = execution_profile.ToString( + backend().default_stream_executor()->GetDeviceDescription()); + std::vector> 
lines_and_words = + SplitIntoLinesAndWords(rendered_profile); + ASSERT_EQ(lines_and_words.size(), 8); + + const std::vector& line_2 = lines_and_words[2]; + const std::vector& line_3 = lines_and_words[3]; + + EXPECT_EQ(line_2[kInstructionCyclesIndex], std::to_string(dot_cycles)); + EXPECT_EQ(line_2[kInstructionNameIndex], '%' + dot_instruction->name()); + + EXPECT_EQ(line_3[kInstructionCyclesIndex], std::to_string(add_cycles)); + EXPECT_EQ(line_3[kInstructionNameIndex], '%' + add_instruction->name()); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index fd162622ce2a56bcfbcd4fa1c56d5afc56249a8f..84187d578346eafd5e32727a15f5eab9cc79feef 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -312,11 +312,11 @@ optional MatchTrivialComputation(const HloComputation* computation) { class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, - bool show_addresses, bool show_metadata, + const DebugOptions& debug_options, bool show_metadata, const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), label_(label.ToString()), - show_addresses_(show_addresses), + debug_options_(debug_options), show_metadata_(show_metadata), profile_(profile), filter_(std::move(filter)) {} @@ -382,7 +382,7 @@ class HloDotDumper { const HloComputation* computation_; // never null const string label_; // overall name for the graph - const bool show_addresses_; + const DebugOptions& debug_options_; const bool show_metadata_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -414,6 +414,11 @@ class HloDotDumper { // appears before both the inner computation and the destination node are // defined. 
std::vector edges_; + + // When coloring by sharding information, we track the sharding string + // representation to color association, by round-robin the color schemes. + std::unordered_map sharding_colors_; + int64 next_shard_color_ = 0; }; string HloDotDumper::Dump() { @@ -734,15 +739,16 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); AddInstructionIncomingEdges(instr); - // Override the node's styling if it should be (de-)emphasized. - if (filter_.Deemphasized(instr)) { - color = kDashedBorder; - } - if (filter_.Highlight(instr)) { - node_shape = "diamond"; - color = kDarkRed; + if (!debug_options_.xla_hlo_graph_sharding_color()) { + // Override the node's styling if it should be (de-)emphasized. + if (filter_.Deemphasized(instr)) { + color = kDashedBorder; + } + if (filter_.Highlight(instr)) { + node_shape = "diamond"; + color = kDarkRed; + } } - // Build the text that will be displayed inside the node. string node_body = node_label; for (const string& s : @@ -761,12 +767,22 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { auto stringify_constant = [](const HloInstruction* constant) { - if (ShapeUtil::IsEffectiveScalar(constant->shape())) { - auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( - constant->shape(), /*linear_index=*/0); - return Printf("%s (%s)", constant->literal().GetAsString(elem_idx), + const auto& shape = constant->shape(); + + // Print the literal value of constants with <= K elements. 
+ optional elem_count; + if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) { + elem_count = 1; + for (int64 dim : shape.dimensions()) { + *elem_count *= dim; + } + } + if (elem_count.has_value() && *elem_count <= 8) { + return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } + + // Otherwise, print e.g. "%constant.42 (s32[100])". string constant_name; if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) { constant_name = constant->name(); @@ -817,6 +833,20 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( } ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { + if (debug_options_.xla_hlo_graph_sharding_color()) { + if (!instr->has_sharding()) { + return kDashedBorder; + } + string shard_str = instr->sharding().ToString(); + auto it = sharding_colors_.find(shard_str); + if (it != sharding_colors_.end()) { + return it->second; + } + ColorScheme color = static_cast( + kBlue + (next_shard_color_++ % (kDashedBorder - kBlue))); + sharding_colors_.emplace(shard_str, color); + return color; + } const auto kParameterColor = kOrange; // Special case: If this instruction has a parameter merged into it, paint it @@ -834,9 +864,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { // (eg, parameter). 
switch (instr->opcode()) { case HloOpcode::kAbs: - case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: + case HloOpcode::kAnd: case HloOpcode::kAtan2: + case HloOpcode::kBitcastConvert: case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kComplex: @@ -852,18 +883,19 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: - case HloOpcode::kAnd: - case HloOpcode::kNot: - case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: case HloOpcode::kNe: case HloOpcode::kNegate: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: + case HloOpcode::kRng: + case HloOpcode::kRoundNearestAfz: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -873,7 +905,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kSort: case HloOpcode::kSubtract: case HloOpcode::kTanh: - case HloOpcode::kRng: // De-emphasize scalar-shaped elementwise ops -- they're generally // uninteresting. 
if (ShapeUtil::IsEffectiveScalar(instr->shape())) { @@ -881,9 +912,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { } return kYellow; case HloOpcode::kBitcast: - case HloOpcode::kTuple: - case HloOpcode::kTrace: case HloOpcode::kGetTupleElement: + case HloOpcode::kTrace: + case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: // De-emphasize nodes which broadcast a scalar within a fusion node -- @@ -922,25 +953,28 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kRed; case HloOpcode::kParameter: return kParameterColor; - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: case HloOpcode::kReduce: - case HloOpcode::kSelectAndScatter: case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: return kPurple; - case HloOpcode::kMap: case HloOpcode::kFusion: + case HloOpcode::kMap: return kGray; - case HloOpcode::kSend: - case HloOpcode::kRecv: + case HloOpcode::kCrossReplicaSum: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: return kBrown; + case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kCustomCall: case HloOpcode::kWhile: - case HloOpcode::kCall: return kDarkGreen; case HloOpcode::kConstant: LOG(FATAL) << "Constants don't get their own nodes in the graph."; @@ -969,10 +1003,13 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { .starts_with(StrCat("%", HloOpcodeString(instr->opcode())))) { return Printf("%s", HtmlLikeStringSanitize(instr->name())); } - + string extended_opcode = + StrCat(HloOpcodeString(instr->opcode()), + instr->opcode() != HloOpcode::kFusion + ?
"" + : StrCat(":", xla::ToString(instr->fusion_kind()))); // If the name does not contain the opcode, render both. - return Printf("<b>%s</b><br/>%s", - HtmlLikeStringSanitize(instr->ExtendedOpcodeStr()), + return Printf("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode), HtmlLikeStringSanitize(instr->name())); } @@ -1027,7 +1064,9 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { ? "" : StrCat("stride=", VectorString(instr->slice_strides())); case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: return StrCat("channel_id=", instr->channel_id()); default: return ""; @@ -1065,8 +1104,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { } lines.push_back(instr_shape); } - - if (show_addresses_) { + if (debug_options_.xla_hlo_graph_addresses()) { lines.push_back(Printf("[%p]", instr)); } if (profile_ != nullptr) { @@ -1163,70 +1201,36 @@ const HloInstruction* HloDotDumper::GetNodeForEdge( return instr; } -tensorflow::mutex& RendererMutex() { - static tensorflow::mutex* mu = new tensorflow::mutex; - return *mu; -} +class GraphRendererRegistry { + public: + void AddRenderer(GraphRendererInterface* graph_renderer) { + tensorflow::mutex_lock lock(mu_); + graph_renderer_ = graph_renderer; + } -std::map* GraphRenderers() { - static auto* graph_renderers = new std::map(); - return graph_renderers; -} + GraphRendererInterface* GetDefaultRenderer() { + tensorflow::mutex_lock lock(mu_); + return graph_renderer_; + } -GraphRendererInterface* GetGraphRenderer() { - tensorflow::mutex_lock lock(RendererMutex()); - auto* graph_renderers = GraphRenderers(); - auto it = graph_renderers->rbegin(); - CHECK(it != graph_renderers->rend()) << "No registered graph dumpers"; - return it->second; -} + static GraphRendererRegistry* Default() { + static GraphRendererRegistry* registry = new GraphRendererRegistry(); + return registry; + } + + private: + tensorflow::mutex mu_; + GraphRendererInterface* graph_renderer_ = nullptr; +}; } // namespace -Registrar::Registrar(GraphRendererInterface* dumper, int priority) { - tensorflow::mutex_lock lock(RendererMutex()); - auto* graph_renderers = GraphRenderers();
- graph_renderers->emplace(priority, dumper); +Registrar::Registrar(GraphRendererInterface* dumper) { + GraphRendererRegistry::Default()->AddRenderer(dumper); } namespace { -class FileGraphRenderer : public GraphRendererInterface { - public: - string RenderGraph(const string& graph, GraphKind graph_kind, - const DebugOptions& debug_options) override { - static std::atomic output_num(0); - string file_extension; - switch (graph_kind) { - case DOT_GRAPH: - file_extension = ".dot"; - break; - case TF_GRAPHDEF: - file_extension = ".pbtxt"; - break; - } - string path = - JoinPath(debug_options.xla_hlo_graph_path(), - StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); - auto status = Status::OK(); - int fd = mkstemps(&path[0], file_extension.length()); - if (fd < 0) { - status = - Status(tensorflow::error::Code::UNKNOWN, - StrCat("Failed to create temporary file to dump HLO graph: ", - strerror(errno))); - } else { - status = tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, - graph); - close(fd); - } - if (!status.ok()) { - LOG(WARNING) << "Saving HLO graph failed: " << status; - } - return path; - } -}; - // Gets a NodeFilter that includes roughly all instructions whose distance from // root is <= radius. NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { @@ -1289,7 +1293,9 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { auto is_displayed = [&](const HloInstruction* instr) { // Constants are displayed inline with their users; they're never omitted. - return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant; + // Nodes in subcomputations are always shown. 
+ return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant || + instr->parent() != root->parent(); }; // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we @@ -1334,7 +1340,54 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { }); } -XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0); +string SaveGraph(const string& graph, + GraphRendererInterface::GraphKind graph_kind, + const string& dest_path) { + static std::atomic output_num(0); + string file_extension; + switch (graph_kind) { + case GraphRendererInterface::DOT_GRAPH: + file_extension = ".dot"; + break; + case GraphRendererInterface::TF_GRAPHDEF: + file_extension = ".pbtxt"; + break; + } + string path = JoinPath( + dest_path, StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); + auto status = Status::OK(); + int fd = mkstemps(&path[0], file_extension.length()); + if (fd < 0) { + status = + Status(tensorflow::error::Code::UNKNOWN, + StrCat("Failed to create temporary file to dump HLO graph: ", + strerror(errno))); + } else { + status = + tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, graph); + close(fd); + } + if (!status.ok()) { + LOG(WARNING) << "Saving HLO graph failed: " << status; + } + return path; +} + +string ExportGraph(const string& graph, + GraphRendererInterface::GraphKind graph_kind, + const DebugOptions& debug_options) { + string path = debug_options.xla_hlo_graph_path(); + if (!path.empty()) { + return SaveGraph(graph, graph_kind, path); + } else { + auto graph_renderer = + GraphRendererRegistry::Default()->GetDefaultRenderer(); + CHECK(graph_renderer != nullptr) + << "No registered renderer for the HLO graph. 
" + "Use --xla_hlo_graph_path=PATH to export to local file system"; + return graph_renderer->RenderGraph(graph, graph_kind, debug_options); + } +} } // namespace @@ -1342,27 +1395,22 @@ string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, bool show_metadata) { + GraphRendererInterface::GraphKind graph_kind; string graph; - string graph_url; if (debug_options.xla_hlo_dump_as_graphdef()) { - HloTfGraphBuilder builder; + HloTfGraphBuilder builder(debug_options); TF_CHECK_OK(builder.AddComputation(computation)); CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), &graph)); - // TODO(b/37198616): Use the default registered renderers when all - // renderers support rendering GraphDefs. Always dump GraphDefs to files - // for now. - graph_url = FileGraphRenderer().RenderGraph( - graph, GraphRendererInterface::TF_GRAPHDEF, debug_options); + graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { - graph = - HloDotDumper(&computation, label, - /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), - show_metadata, hlo_execution_profile, NodeFilter()) - .Dump(); - graph_url = GetGraphRenderer()->RenderGraph( - graph, GraphRendererInterface::DOT_GRAPH, debug_options); + graph = HloDotDumper(&computation, label, debug_options, show_metadata, + hlo_execution_profile, NodeFilter()) + .Dump(); + graph_kind = GraphRendererInterface::DOT_GRAPH; } + + string graph_url = ExportGraph(graph, graph_kind, debug_options); LOG(INFO) << "computation " << computation.name() << " [" << label << "]: " << graph_url; return graph_url; @@ -1375,12 +1423,10 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius, StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); string graph = - HloDotDumper(node.parent(), label, - /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), - 
show_metadata, /*profile=*/nullptr, filter) + HloDotDumper(node.parent(), label, debug_options, show_metadata, + /*profile=*/nullptr, filter) .Dump(); - return GetGraphRenderer()->RenderGraph( - graph, GraphRendererInterface::DOT_GRAPH, debug_options); + return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } void DumpText(const HloModule& module, const string& label, diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index dd304ec76cd903a6175337551fc50808b1797104..2704aae1e3ba7fb131bfcb1287d807d785fd9774 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -84,11 +84,10 @@ void DumpText(const HloModule& module, const string& label, // Internal implementation details below this point. -// Class that registers a graph renderer. Higher-priority renders are chosen -// first. +// Class that registers a graph renderer. class Registrar { public: - Registrar(GraphRendererInterface* dumper, int priority); + Registrar(GraphRendererInterface* dumper); }; #define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) 
\ diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 7b0f937f383a416f805a799bd6787afe15b324b0..8e1531c87f9c6e133e2d6763b046b1d5dcbcd09f 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -45,7 +45,7 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface { string last_graph_; }; -XLA_REGISTER_GRAPH_RENDERER(DotRenderer, std::numeric_limits::max()); +XLA_REGISTER_GRAPH_RENDERER(DotRenderer); TEST(HloGraphDumperTest, NestedFusion) { HloComputation::Builder b("b"); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 5107ac782d7c93dfa17969338bf97c9fd9bb1516..220d5044a29a8ab724cf56394a9fbf7c6e4010e4 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -43,6 +43,7 @@ limitations under the License. 
namespace xla { +using tensorflow::str_util::CEscape; using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; @@ -51,7 +52,9 @@ using ::tensorflow::strings::StrCat; StatusOr> HloInstruction::CreateFromProto( HloModule* module, const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, - tensorflow::gtl::FlatMap* computation_map) { + const tensorflow::gtl::FlatMap& computation_map, + const std::function)>& + add_fused_computation) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); @@ -77,19 +80,19 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(!proto.fusion_kind().empty()); TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, StringToFusionKind(proto.fusion_kind())); - TF_ASSIGN_OR_RETURN( - std::unique_ptr fused_computation, - HloComputation::CreateFromProto( - module, proto.fused_instructions_computation(), computation_map, - /*fusion_instruction=*/instruction.get())); - instruction->called_computations_.push_back( - module->AddEmbeddedComputation(std::move(fused_computation))); + TF_ASSIGN_OR_RETURN(std::unique_ptr fused_computation, + HloComputation::CreateFromProto( + module, proto.fused_instructions_computation(), + computation_map, add_fused_computation, + /*fusion_instruction=*/instruction.get())); + instruction->called_computations_.push_back(fused_computation.get()); + add_fused_computation(std::move(fused_computation)); } else { for (const string& computation_name : proto.called_computation_names()) { - TF_RET_CHECK(ContainsKey(*computation_map, computation_name)) + TF_RET_CHECK(ContainsKey(computation_map, computation_name)) << "No computation named " << computation_name; instruction->called_computations_.push_back( - computation_map->at(computation_name)); + computation_map.at(computation_name)); } } @@ -115,6 +118,10 @@ StatusOr> HloInstruction::CreateFromProto( 
MakeUnique( proto.convolution_dimension_numbers()); } + if (proto.has_dot_dimension_numbers()) { + instruction->dot_dimension_numbers_ = + MakeUnique(proto.dot_dimension_numbers()); + } for (const HloInstructionProto::SliceDimensions& slice_dimensions : proto.slice_dimensions()) { instruction->slice_starts_.push_back(slice_dimensions.start()); @@ -148,7 +155,7 @@ StatusOr> HloInstruction::CreateFromProto( WrapUnique(new HloInstruction(HloOpcode::kParameter, shape)); instruction->parameter_number_ = parameter_number; instruction->parameter_name_ = name; - instruction->name_ = "%" + name; + instruction->name_ = name; return instruction; } @@ -329,6 +336,31 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + instruction->AppendOperand(lhs); + instruction->AppendOperand(rhs); + instruction->dot_dimension_numbers_ = + MakeUnique(dimension_numbers); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateCanonicalDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { + CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); + CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + instruction->AppendOperand(lhs); + instruction->AppendOperand(rhs); + instruction->dot_dimension_numbers_ = MakeUnique(); + instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1); + instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, @@ -343,12 +375,9 @@ HloInstruction::CreateReducePrecision(const Shape& shape, } /* static */ std::unique_ptr 
-HloInstruction::CreateCrossReplicaSum(const Shape& shape, - HloInstruction* operand) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape)); - instruction->AppendOperand(operand); - return instruction; +HloInstruction::CreateCrossReplicaSum( + const Shape& shape, tensorflow::gtl::ArraySlice operands) { + return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( @@ -371,20 +400,50 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateSend( HloInstruction* operand, int64 channel_id) { + // Send instruction produces a tuple of {aliased operand, U32 context}. + Shape output_shape = ShapeUtil::MakeTupleShape( + {operand->shape(), ShapeUtil::MakeShape(U32, {})}); auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSend, ShapeUtil::MakeNil())); + WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape)); instruction->AppendOperand(operand); instruction->channel_id_ = channel_id; return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateSendDone( + HloInstruction* operand) { + CHECK(operand->opcode() == HloOpcode::kSend) + << "SendDone must take the context operand from Send"; + auto instruction = WrapUnique( + new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil())); + instruction->AppendOperand(operand); + instruction->channel_id_ = operand->channel_id(); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateRecv( const Shape& shape, int64 channel_id) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRecv, shape)); + // Recv instruction produces a tuple of {receive buffer, U32 context}. 
+ Shape output_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape)); instruction->channel_id_ = channel_id; return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateRecvDone( + HloInstruction* operand) { + CHECK(operand->opcode() == HloOpcode::kRecv) + << "RecvDone must take the context operand from Recv"; + Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0); + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape)); + instruction->AppendOperand(operand); + instruction->channel_id_ = operand->channel_id(); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { @@ -405,6 +464,23 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateConditional( + const Shape& shape, HloInstruction* pred, + HloInstruction* true_computation_arg, HloComputation* true_computation, + HloInstruction* false_computation_arg, HloComputation* false_computation) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); + instruction->AppendOperand(pred); + instruction->AppendOperand(true_computation_arg); + instruction->AppendOperand(false_computation_arg); + // In called_computations_, the index of true_computation must be 0 and that + // of false computation must be 1, as defined by kTrueComputationIndex and + // kFalseComputationIndex. 
+ instruction->called_computations_.push_back(true_computation); + instruction->called_computations_.push_back(false_computation); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateSlice( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice start_indices, @@ -468,6 +544,15 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, return instruction; } +/* static */ std::unique_ptr +HloInstruction::CreateBitcastConvert(const Shape& shape, + HloInstruction* operand) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape)); + instruction->AppendOperand(operand); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateReduce( const Shape& shape, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, @@ -600,7 +685,10 @@ HloInstruction::CreateSelectAndScatter( CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); CHECK(std::equal(operand->shape().dimensions().begin(), operand->shape().dimensions().end(), - Permute(dimensions, shape.dimensions()).begin())); + Permute(dimensions, shape.dimensions()).begin())) + << "shape: " << ShapeUtil::HumanString(shape) + << ", operand->shape(): " << ShapeUtil::HumanString(shape) + << ", dimensions: {" << Join(dimensions, ", ") << "}"; auto instruction = WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape)); instruction->AppendOperand(operand); @@ -618,6 +706,20 @@ HloInstruction::CreateSelectAndScatter( return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateFusion( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->fusion_kind_ = fusion_kind; + 
instruction->called_computations_.push_back(fusion_computation); + fusion_computation->SetFusionInstruction(instruction.get()); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateFusionForBackwardConvolution( const Shape& shape, FusionKind fusion_kind, const Window& window, @@ -746,7 +848,7 @@ HloInstruction* HloInstruction::FuseInstructionInternal( HloInstruction* HloInstruction::CloneAndFuseInternal( HloInstruction* instruction_to_fuse, bool add_output) { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(instruction_to_fuse->IsFusable()); + CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); HloInstruction* clone = nullptr; if (called_computations_.empty()) { @@ -824,10 +926,8 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // parameter instruction. int64 param_no = fused_parameters.size(); // Name the parameter after the instruction it represents in the outer - // (non-fusion) computation. Strip the leading "%" from the operand name - // to avoid a double %%. - string param_name = - StrCat(operand->name().substr(1), ".param_", param_no); + // (non-fusion) computation. 
+ string param_name = StrCat(operand->name(), ".param_", param_no); fused_param = fused_instructions_computation()->AddParameter( CreateParameter(param_no, operand->shape(), param_name)); AppendOperand(operand); @@ -908,7 +1008,10 @@ RandomDistribution HloInstruction::random_distribution() const { bool HloInstruction::HasSideEffect() const { switch (opcode_) { case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kRng: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: @@ -966,7 +1069,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { - VLOG(3) << " " << new_operand->name(); + VLOG(3) << " %" << new_operand->name(); } std::unique_ptr clone; @@ -1010,7 +1113,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kLe: case HloOpcode::kLt: case HloOpcode::kNe: - case HloOpcode::kDot: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPower: @@ -1048,6 +1150,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateConvert(shape, new_operands[0]); break; + case HloOpcode::kBitcastConvert: + CHECK_EQ(new_operands.size(), 1); + clone = CreateBitcastConvert(shape, new_operands[0]); + break; case HloOpcode::kReducePrecision: CHECK_EQ(new_operands.size(), 1); clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_, @@ -1058,9 +1164,13 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, *convolution_dimension_numbers_); break; + case HloOpcode::kDot: + CHECK_EQ(new_operands.size(), 2); + clone = CreateDot(shape, new_operands[0], new_operands[1], + *dot_dimension_numbers_); + break; case HloOpcode::kCrossReplicaSum: - CHECK_EQ(new_operands.size(), 1); - clone = 
CreateCrossReplicaSum(shape, new_operands[0]); + clone = CreateCrossReplicaSum(shape, new_operands); break; case HloOpcode::kGetTupleElement: CHECK_EQ(new_operands.size(), 1); @@ -1163,8 +1273,11 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); break; + case HloOpcode::kConditional: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } @@ -1353,7 +1466,7 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { return i; } } - LOG(FATAL) << "target was not an operand"; + LOG(FATAL) << "target was not an operand: " << target->ToString(); } Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { @@ -1426,7 +1539,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: - case HloOpcode::kDot: case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kFloor: @@ -1485,6 +1597,7 @@ bool HloInstruction::IdenticalSlowPath( // A convert result is determined by the primitive type that the operand is // converted into. case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: return shape().element_type() == other.shape().element_type(); // A reduce-precision operation is determined by the bit sizes. @@ -1498,6 +1611,10 @@ bool HloInstruction::IdenticalSlowPath( protobuf_util::ProtobufEquals( convolution_dimension_numbers(), other.convolution_dimension_numbers()); + // Check dot dimension numbers. + case HloOpcode::kDot: + return protobuf_util::ProtobufEquals(dot_dimension_numbers(), + other.dot_dimension_numbers()); // Reduction results are determined by the reduction dimension and the // reduction computation. 
@@ -1554,11 +1671,14 @@ bool HloInstruction::IdenticalSlowPath( return dimensions() == other.dimensions(); // These opcodes are not yet supported. + case HloOpcode::kConditional: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kSort: - case HloOpcode::kSend: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: return false; } } @@ -1761,6 +1881,32 @@ void HloInstruction::set_scatter(HloComputation* computation) { called_computations_[kScatterComputationIndex] = computation; } +HloComputation* HloInstruction::true_computation() const { + CHECK_EQ(HloOpcode::kConditional, opcode_); + return called_computations_[kTrueComputationIndex]; +} + +HloComputation* HloInstruction::false_computation() const { + CHECK_EQ(HloOpcode::kConditional, opcode_); + return called_computations_[kFalseComputationIndex]; +} + +void HloInstruction::set_true_computation(HloComputation* true_computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); + CHECK_EQ(HloOpcode::kConditional, opcode_); + called_computations_[kTrueComputationIndex] = true_computation; +} + +void HloInstruction::set_false_computation(HloComputation* false_computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. 
+ CHECK(!IsFused()); + CHECK_EQ(HloOpcode::kConditional, opcode_); + called_computations_[kFalseComputationIndex] = false_computation; +} + string HloInstruction::SignatureString() const { string operands = Join(operands_, ", ", [](string* out, HloInstruction* operand) { @@ -1769,20 +1915,11 @@ string HloInstruction::SignatureString() const { return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); } -string HloInstruction::ExtendedOpcodeStr() const { - string opc_name = HloOpcodeString(opcode()); - HloOpcode opc = opcode(); - if (HloOpcode::kFusion == opc) { - opc_name += ":" + xla::ToString(fusion_kind()); - } - return opc_name; -} - string HloInstruction::ToString(bool compact_operands, bool include_metadata, bool include_large_constants) const { string result = - StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ", - ExtendedOpcodeStr(), "(", + StrCat("%", name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ", + HloOpcodeString(opcode()), "(", OperandsToString(compact_operands, include_large_constants), ")"); for (const string& extra : ExtraAttributesToString()) { StrAppend(&result, ", ", extra); @@ -1790,7 +1927,7 @@ string HloInstruction::ToString(bool compact_operands, bool include_metadata, if (include_metadata && (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty())) { - StrAppend(&result, " # metadata=", metadata_.ShortDebugString()); + StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } return result; } @@ -1833,7 +1970,7 @@ string HloInstruction::OperandsToString(bool compact, operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { *out += ShapeUtil::HumanStringWithLayout(operand->shape()); if (!compact) { - StrAppend(out, " ", operand->name()); + StrAppend(out, " %", operand->name()); } }); const int64 remaining = operands_.size() - slice.size(); @@ -1846,16 +1983,20 @@ string HloInstruction::OperandsToString(bool compact, 
std::vector HloInstruction::ExtraAttributesToString() const { std::vector extra; + if (opcode() == HloOpcode::kFusion) { + extra.push_back(StrCat("kind=", xla::ToString(fusion_kind()))); + } if (CanHaveDimensionsField()) { extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); } - if (window_ != nullptr) { - extra.push_back(window_util::ToString(*window_)); + if (window_ != nullptr && window_->dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); } if (padding_config_ != nullptr) { - extra.push_back(StrCat("padding=", padding_config_->ShortDebugString())); + extra.push_back( + StrCat("padding=", xla::PaddingConfigToString(*padding_config_))); } - if (!slice_starts_.empty() && !slice_limits_.empty()) { + if (opcode() == HloOpcode::kSlice) { std::vector bounds; bounds.reserve(slice_starts_.size()); const bool omit_stride = @@ -1868,10 +2009,23 @@ std::vector HloInstruction::ExtraAttributesToString() const { } extra.push_back(StrCat("slice={", Join(bounds, ", "), "}")); } + if (opcode() == HloOpcode::kDynamicSlice) { + extra.push_back( + StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")); + } + if (opcode() == HloOpcode::kBatchNormTraining || + opcode() == HloOpcode::kBatchNormInference || + opcode() == HloOpcode::kBatchNormGrad) { + extra.push_back(StrCat("epsilon=", epsilon())); + extra.push_back(StrCat("feature_index=", feature_index())); + } if (convolution_dimension_numbers_ != nullptr) { extra.push_back(ConvolutionDimensionNumbersToString()); } + if (dot_dimension_numbers_ != nullptr) { + extra.push_back(DotDimensionNumbersToString()); + } if (opcode() == HloOpcode::kWhile) { extra.push_back(StrCat("condition=%", while_condition()->name())); @@ -1879,6 +2033,9 @@ std::vector HloInstruction::ExtraAttributesToString() const { } else if (opcode() == HloOpcode::kSelectAndScatter) { extra.push_back(StrCat("select=%", select()->name())); extra.push_back(StrCat("scatter=%", 
scatter()->name())); + } else if (opcode() == HloOpcode::kConditional) { + extra.push_back(StrCat("true_computation=%", true_computation()->name())); + extra.push_back(StrCat("false_computation=%", false_computation()->name())); } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || opcode() == HloOpcode::kReduce) { @@ -1891,7 +2048,8 @@ std::vector HloInstruction::ExtraAttributesToString() const { }))); } - if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv) { + if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || + opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) { extra.push_back(StrCat("channel_id=", channel_id_)); } @@ -1905,18 +2063,37 @@ std::vector HloInstruction::ExtraAttributesToString() const { extra.push_back(StrCat("control-predecessors={", Join(control_predecessors_, ", ", [](string* out, HloInstruction* pre) { - StrAppend(out, pre->name()); + StrAppend(out, "%", pre->name()); }), "}")); } + if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) { + extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")); + } + if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) { + extra.push_back( + StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); + } + if (opcode() == HloOpcode::kRng) { + extra.push_back( + StrCat("distribution=", RandomDistributionToString(distribution_))); + } + if (opcode() == HloOpcode::kReducePrecision) { + extra.push_back(StrCat("exponent_bits=", exponent_bits_)); + extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); + } + if (opcode() == HloOpcode::kCustomCall) { + extra.push_back( + StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + } return extra; } string HloInstruction::ToShortString() const { - return StrCat(name(), " = ", HloOpcodeString(opcode()), "(", + return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(", Join(operands_, ", ", 
[](string* out, HloInstruction* operand) { - StrAppend(out, operand->name()); + StrAppend(out, "%", operand->name()); }), ")"); } @@ -1960,6 +2137,9 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = *convolution_dimension_numbers_; } + if (dot_dimension_numbers_ != nullptr) { + *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; + } for (int i = 0; i < slice_starts_.size(); ++i) { auto* slice_dimension = proto.add_slice_dimensions(); slice_dimension->set_start(slice_starts_[i]); @@ -2010,8 +2190,10 @@ string HloInstruction::ToCategory() const { bool saw_rank_1 = false; bool saw_higher_rank = false; for (const auto* operand : operands()) { - saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1; - saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1; + if (!ShapeUtil::IsTuple(operand->shape())) { + saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1; + saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1; + } } if (saw_rank_1 && saw_higher_rank) { return "rank-1-broadcast binary fusion"; @@ -2064,23 +2246,13 @@ bool HloInstruction::IsFusable() const { if (tracing()) { return false; } - // Some kinds of instructions don't make sense to fuse. switch (opcode_) { - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: case HloOpcode::kParameter: - case HloOpcode::kTrace: - case HloOpcode::kSend: - case HloOpcode::kRecv: return false; - // Only fuse Rng if it is used once, otherwise the random numbers generated - // will be different in each fusion. If it is the root (user count = 0) - // then it is the equivalent of having one user. - case HloOpcode::kRng: - return users_.size() <= 1; + // Side effecting instrutions cannot be fused. 
default: - return true; + return !HasSideEffect(); } } @@ -2131,7 +2303,7 @@ HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) : unique_id_(-1), opcode_(opcode), shape_(shape), - name_("%" + HloOpcodeString(opcode)) { + name_(HloOpcodeString(opcode)) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); } @@ -2191,6 +2363,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleConcatenate(this); case HloOpcode::kConvert: return visitor->HandleConvert(this); + case HloOpcode::kBitcastConvert: + return visitor->HandleBitcastConvert(this); case HloOpcode::kCopy: return visitor->HandleCopy(this); case HloOpcode::kMultiply: @@ -2277,12 +2451,18 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleFusion(this); case HloOpcode::kCall: return visitor->HandleCall(this); + case HloOpcode::kConditional: + return visitor->HandleConditional(this); case HloOpcode::kCustomCall: return visitor->HandleCustomCall(this); - case HloOpcode::kSend: - return visitor->HandleSend(this); case HloOpcode::kRecv: return visitor->HandleRecv(this); + case HloOpcode::kRecvDone: + return visitor->HandleRecvDone(this); + case HloOpcode::kSend: + return visitor->HandleSend(this); + case HloOpcode::kSendDone: + return visitor->HandleSendDone(this); // These opcodes are not handled here. 
case HloOpcode::kTrace: @@ -2350,7 +2530,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, visitor->GetVisitState(current_id); if (visit_state == Visitor::kVisited) { dfs_stack.pop_back(); - VLOG(3) << "Not visiting HLO " << current_node->name() + VLOG(3) << "Not visiting HLO %" << current_node->name() << " as it was already visited."; continue; } @@ -2359,7 +2539,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, dfs_stack.pop_back(); TF_RETURN_IF_ERROR(visitor->Preprocess(current_node)); - VLOG(2) << "Visiting HLO " << current_node->name(); + VLOG(2) << "Visiting HLO %" << current_node->name(); TF_RETURN_IF_ERROR(current_node->Visit(visitor)); visitor->SetVisitState(current_id, Visitor::kVisited); TF_RETURN_IF_ERROR(visitor->Postprocess(current_node)); @@ -2404,7 +2584,7 @@ template Status HloInstruction::Accept(DfsHloVisitorBase* visitor, bool call_finish_visit, bool ignore_control_predecessors) { - VLOG(3) << "HloInstruction::Accept(" << name() << ")"; + VLOG(3) << "HloInstruction::Accept(%" << name() << ")"; TF_RETURN_IF_ERROR( PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors)); if (call_finish_visit) { @@ -2420,7 +2600,7 @@ template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool); Status HloInstruction::AcceptWithOperandOrder( DfsHloVisitor* visitor, const CompareFunction& operand_order, bool call_finish_visit) { - VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")"; + VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")"; InternalCompareFunction func = [&operand_order]( std::pair a, std::pair b) { @@ -2483,7 +2663,7 @@ Status HloInstruction::Accept( Status HloInstruction::AcceptOrdered( DfsHloVisitor* visitor, const std::vector& order) { - VLOG(2) << "HloInstruction::AcceptOrdered(" << name() << ")"; + VLOG(2) << "HloInstruction::AcceptOrdered(%" << name() << ")"; TF_RET_CHECK(OrderIsTopologicalSort(order)); // Compute the predecessors of this 
instruction. @@ -2502,7 +2682,7 @@ Status HloInstruction::AcceptOrdered( // The visitor can mark instructions as visited to skip particular // instructions. if (visitor->DidVisit(*const_instruction)) { - VLOG(3) << "Not visiting HLO " << const_instruction->name() + VLOG(3) << "Not visiting HLO %" << const_instruction->name() << " as it was already visited."; continue; } @@ -2511,7 +2691,7 @@ Status HloInstruction::AcceptOrdered( const_cast(const_instruction); TF_RETURN_IF_ERROR(visitor->Preprocess(instruction)); - VLOG(2) << "Visiting HLO " << instruction->name(); + VLOG(2) << "Visiting HLO %" << instruction->name(); TF_RETURN_IF_ERROR(instruction->Visit(visitor)); visitor->SetVisited(*instruction); TF_RETURN_IF_ERROR(visitor->Postprocess(instruction)); @@ -2557,6 +2737,7 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kRoundNearestAfz: case HloOpcode::kCeil: case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: @@ -2841,6 +3022,61 @@ StatusOr StringToFusionKind( return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str()); } +string PaddingConfigToString(const PaddingConfig& padding) { + bool has_interior_padding = + std::any_of(padding.dimensions().begin(), padding.dimensions().end(), + [](const PaddingConfig::PaddingConfigDimension& dim) { + return dim.interior_padding() != 0; + }); + return Join( + padding.dimensions(), "x", + [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) { + StrAppend( + out, dim.edge_padding_low(), "_", dim.edge_padding_high(), + has_interior_padding ? 
StrCat("_", dim.interior_padding()) : ""); + }); +} + +string OpMetadataToString(const OpMetadata& metadata) { + std::vector result; + if (!metadata.op_type().empty()) { + result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\"")); + } + if (!metadata.op_name().empty()) { + result.push_back(StrCat("op_name=\"", CEscape(metadata.op_name()), "\"")); + } + if (!metadata.source_file().empty()) { + result.push_back( + StrCat("source_file=\"", CEscape(metadata.source_file()), "\"")); + } + if (metadata.source_line() != 0) { + result.push_back(StrCat("source_line=", metadata.source_line())); + } + return Join(result, " "); +} + +string RandomDistributionToString(const RandomDistribution& distribution) { + return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); +} + +StatusOr StringToRandomDistribution(const string& name) { + static std::unordered_map* map = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { + if (RandomDistribution_IsValid(i)) { + auto value = static_cast(i); + (*map)[RandomDistributionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } @@ -2856,36 +3092,30 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { const auto append_dims = [&](const std::vector& dims, const Shape& shape) { CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); - for (int64 logical = 0; logical < dims.size(); ++logical) { - int64 physical = logical; - if (!shape.layout().minor_to_major().empty()) { - physical = LayoutUtil::Major(shape.layout(), logical); - } - result += dims[physical]; - } + StrAppend(&result, Join(dims, "")); }; // lhs_dims[i] is the symbol of the logical dimension i for 
the lhs // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". - std::vector lhs_dims(2 + dnums.spatial_dimensions().size()); + std::vector lhs_dims(2 + dnums.input_spatial_dimensions().size()); lhs_dims[dnums.input_batch_dimension()] = 'b'; lhs_dims[dnums.input_feature_dimension()] = 'f'; - for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { - lhs_dims[dnums.spatial_dimensions(i)] = StrCat(i); + for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) { + lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i); } std::vector rhs_dims(2 + dnums.kernel_spatial_dimensions().size()); rhs_dims[dnums.kernel_input_feature_dimension()] = "i"; rhs_dims[dnums.kernel_output_feature_dimension()] = "o"; - for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { + for (int64 i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) { rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i); } - std::vector output_dims(2 + dnums.spatial_dimensions().size()); + std::vector output_dims(2 + dnums.output_spatial_dimensions().size()); output_dims[dnums.output_batch_dimension()] = 'b'; output_dims[dnums.output_feature_dimension()] = 'f'; - for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { - output_dims[dnums.spatial_dimensions(i)] = StrCat(i); + for (int64 i = 0; i < dnums.output_spatial_dimensions().size(); ++i) { + output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } result += "dim_labels="; @@ -2897,6 +3127,30 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { return result; } +string HloInstruction::DotDimensionNumbersToString() const { + string result; + if (dot_dimension_numbers_ == nullptr) { + return result; + } + const DotDimensionNumbers& dnums = *dot_dimension_numbers_; + if (!dnums.lhs_batch_dimensions().empty()) { + result += "lhs_batch_dims="; + StrAppend(&result, Join(dnums.lhs_batch_dimensions(), ",")); + } + result += "lhs_contracting_dims="; + StrAppend(&result, 
Join(dnums.lhs_contracting_dimensions(), ",")); + + result += ","; + if (!dnums.rhs_batch_dimensions().empty()) { + result += "rhs_batch_dims="; + StrAppend(&result, Join(dnums.rhs_batch_dimensions(), ",")); + } + result += "rhs_contracting_dims="; + StrAppend(&result, Join(dnums.rhs_contracting_dimensions(), ",")); + + return result; +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 5ff04a48882497ef546aa095c346f4318a61f02b..092105582e09889091b90eae522489b3732f199c 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" @@ -83,12 +84,16 @@ class HloInstruction { // must contain all operands of the newly constructed instruction. // computation_map: a map from computation name to HloComputation*. This map // must contain all computations which the newly constructed instruction - // calls. If the instruction is a fusion instruction, then the fusion - // computation is added to this map and the module. + // calls. + // add_fused_computation: A function to call to add a fused + // computation. Used (clearly) when the instruction is a fusion + // instruction. 
static StatusOr> CreateFromProto( HloModule* module, const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, - tensorflow::gtl::FlatMap* computation_map); + const tensorflow::gtl::FlatMap& computation_map, + const std::function)>& + add_fused_computation); // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, @@ -155,6 +160,18 @@ class HloInstruction { const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch + // dimensions specified in 'dimension_numbers'. + static std::unique_ptr CreateDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers); + + // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 + // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS + // and the RHS must be of rank 2. + static std::unique_ptr CreateCanonicalDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs); + // Creates a reduce-precision op, where operand is the data to reduce in // precision, and exponent_bits and mantissa_bits describe the precision to // reduce it to. @@ -164,13 +181,19 @@ class HloInstruction { // Creates a cross replica sum op. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, HloInstruction* operand); + const Shape& shape, + tensorflow::gtl::ArraySlice operands); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. static std::unique_ptr CreateConvert(const Shape& shape, HloInstruction* operand); + // Creates a bitcast conversion instruction, where operand is the data to + // convert and shape is the target shape for the conversion. 
+ static std::unique_ptr CreateBitcastConvert( + const Shape& shape, HloInstruction* operand); + // Creates an infeed instruction, which reads data of the given shape from the // Infeed interface of the device. static std::unique_ptr CreateInfeed(const Shape& shape, @@ -181,18 +204,28 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config); - // Creates a send instruction with the given channel id, which sends the - // operand data to a unique receive instruction in another computation that - // has the same channel id. + // Creates an asynchronous send instruction with the given channel id, which + // initiates sending the operand data to a unique receive instruction in + // another computation that has the same channel id. static std::unique_ptr CreateSend(HloInstruction* operand, int64 channel_id); - // Creates a receive instruction with the given channel id, which receives - // data of the given shape from a unique send instruction in another - // computation that has the same channel id. + // Blocks until data transfer for the Send instruction (operand) is complete. + // The operand must be kSend. + static std::unique_ptr CreateSendDone( + HloInstruction* operand); + + // Creates an asynchronous receive instruction with the given channel id, + // which allocates resources to receive data of the given shape from a unique + // send instruction in another computation that has the same channel id. static std::unique_ptr CreateRecv(const Shape& shape, int64 channel_id); + // Blocks until data transfer for the Recv instruction (operand) is complete + // and returns the receive buffer. The operand must be kRecv. + static std::unique_ptr CreateRecvDone( + HloInstruction* operand); + // Creates a slice instruction, where the operand is sliced by the given // start/limit indices. 
static std::unique_ptr CreateSlice( @@ -295,6 +328,11 @@ class HloInstruction { HloComputation* body, HloInstruction* init); + static std::unique_ptr CreateConditional( + const Shape& shape, HloInstruction* pred, + HloInstruction* true_computation_arg, HloComputation* true_computation, + HloInstruction* false_computation_arg, HloComputation* false_computation); + // Creates a fusion instruction. A fusion instruction contains one or more // fused instructions forming an expression with a single root // "fused_root". Additional instructions can be added to the fusion @@ -302,6 +340,11 @@ class HloInstruction { static std::unique_ptr CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root); + static std::unique_ptr CreateFusion( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation); + // Creates a fusion instruction that represents backward convolution. This is // similar to CreateFusion, but with extra arguments indicating the window and // dimemsion mapping of the backward convolution. @@ -391,7 +434,7 @@ class HloInstruction { Status RemoveControlDependencyTo(HloInstruction* instruction); // Returns the set of control predecessors (successors) of this - // instruction. Control predecessors (sucessors) must execute before (after) + // instruction. Control predecessors (successors) must execute before (after) // the current instruction. const std::vector& control_predecessors() const { return control_predecessors_; @@ -593,6 +636,15 @@ class HloInstruction { void set_select(HloComputation* select); void set_scatter(HloComputation* scatter); + // Gets/sets the true and false HloComputation for Conditional. The setters + // should only be called by HloModule or HloComputation methods. + // + // Precondition: The instruction is a Conditional instruction. 
+ HloComputation* true_computation() const; + HloComputation* false_computation() const; + void set_true_computation(HloComputation* true_computation); + void set_false_computation(HloComputation* false_computation); + // Returns a string for the signature of this instruction if considered as a // function, e.g. the signature of an F32 add is (F32, F32) -> F32. string SignatureString() const; @@ -853,6 +905,11 @@ class HloInstruction { return *window_; } + // Sets the window data in a windowed operation such as convolution. + void set_window(const Window& window) { + window_ = MakeUnique(window); + } + // Returns the padding configuration for a pad node. // // Precondition: opcode() == HloOpcode::kPad @@ -871,6 +928,15 @@ class HloInstruction { // Returns the dump string of the convolution dimension numbers. string ConvolutionDimensionNumbersToString() const; + // Returns data on the dimension numbers used for a dot operation. + const DotDimensionNumbers& dot_dimension_numbers() const { + CHECK(dot_dimension_numbers_ != nullptr); + return *dot_dimension_numbers_; + } + + // Returns the dump string of the dot dimension numbers. + string DotDimensionNumbersToString() const; + // Returns the random distribution for this rng node. // // Precondition: opcode() == HloOpcode::kRng @@ -962,11 +1028,6 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Returns the opcode string for this instruction. This is the result from - // HloOpcodeString plus, for fusion nodes, the fusion kind, separated by a - // ':'. - string ExtendedOpcodeStr() const; - // Returns a string identifier for this instruction. If no string identifier // has been explicitly set, then the identifier is the serialized pointer to // this instruction. @@ -1134,6 +1195,9 @@ class HloInstruction { // Describes the dimension numbers used for a convolution. 
std::unique_ptr convolution_dimension_numbers_; + // Describes the dimension numbers used for a dot. + std::unique_ptr dot_dimension_numbers_; + // Describes the [begin, end) index range for a slice. std::vector slice_starts_; std::vector slice_limits_; @@ -1177,6 +1241,10 @@ class HloInstruction { // kSelectAndScatter computations. kSelectComputationIndex = 0, kScatterComputationIndex = 1, + + // kConditional computations. + kTrueComputationIndex = 0, + kFalseComputationIndex = 1, }; // Outfeed configuration information, only present for kOutfeed. @@ -1224,6 +1292,13 @@ string ToString(HloInstruction::FusionKind kind); StatusOr StringToFusionKind( const string& kind_name); +// Custom (de)stringification functions for protos that live inside +// HloInstruction. +string PaddingConfigToString(const PaddingConfig& padding); +string OpMetadataToString(const OpMetadata& metadata); +string RandomDistributionToString(const RandomDistribution& distribution); +StatusOr StringToRandomDistribution(const string& name); + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // Map classes that guarantee a deterministic iteration order when the key is @@ -1231,6 +1306,9 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // To make the iteration order over the map deterministic, the comparator // should not be using the pointer values, but rather an intrinsic property of // the hlo. +// +// Note that this cannot be used for HLO instructions across multiple modules +// since the id of HLO instructions are only unique within each HLO module. 
struct HloPtrComparator { bool operator()(const HloInstruction* const& lhs, const HloInstruction* const& rhs) const { diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index ddb623332c905fe406473e0c1a7adcea9782fdd0..54788fa2daa428d71616858128d0f7269e617df1 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1068,8 +1068,11 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); HloModule module(TestName()); auto* computation = module.AddEntryComputation(builder.Build()); @@ -1088,48 +1091,6 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { root2->operand(1)->operand(0)->shape())); } -TEST_F(HloInstructionTest, IsRandomFusable) { - auto shape = ShapeUtil::MakeShape(F32, {2, 2}); - { - auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(0.0))); - auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(1.0))); - auto rng = builder.AddInstruction(HloInstruction::CreateRng( - shape, RandomDistribution::RNG_NORMAL, {const0, const1})); - - auto* computation = hlo_module->AddEntryComputation(builder.Build()); - computation->CreateFusionInstruction({rng, const0, const1}, - HloInstruction::FusionKind::kLoop); - - auto* root = computation->root_instruction(); - - 
EXPECT_EQ(HloOpcode::kFusion, root->opcode()); - } - { - auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(0.0))); - auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(1.0))); - auto rng = builder.AddInstruction(HloInstruction::CreateRng( - shape, RandomDistribution::RNG_NORMAL, {const0, const1})); - builder.AddInstruction(HloInstruction::CreateUnary( - shape, HloOpcode::kNegate, rng)); - auto* computation = hlo_module->AddEntryComputation(builder.Build()); - computation->CreateFusionInstruction({rng, const0, const1}, - HloInstruction::FusionKind::kLoop); - - auto* root = computation->root_instruction(); - - EXPECT_EQ(HloOpcode::kFusion, root->operand(0)->opcode()); - } -} - - TEST_F(HloInstructionTest, CloneSuffixNames) { // Test that the suffix string added to cloned instructions is not // duplicated. Rather a numeric incrementing value should be appended. That @@ -1138,39 +1099,38 @@ TEST_F(HloInstructionTest, CloneSuffixNames) { // Test cloning the same instruction multiple times. auto foo = HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "foo"); - EXPECT_EQ(foo->Clone()->name(), "%foo.clone"); - EXPECT_EQ(foo->Clone()->Clone()->name(), "%foo.clone2"); - EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "%foo.clone3"); + EXPECT_EQ(foo->Clone()->name(), "foo.clone"); + EXPECT_EQ(foo->Clone()->Clone()->name(), "foo.clone2"); + EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "foo.clone3"); // Test custom suffixes. 
- EXPECT_EQ(foo->Clone("bar")->name(), "%foo.bar"); - EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "%foo.bar2"); - EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), - "%foo.bar2.clone"); + EXPECT_EQ(foo->Clone("bar")->name(), "foo.bar"); + EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "foo.bar2"); + EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), "foo.bar2.clone"); // Test instruction name with a dot. auto foo_baz = HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "foo.baz"); - EXPECT_EQ(foo_baz->Clone()->name(), "%foo.baz.clone"); + EXPECT_EQ(foo_baz->Clone()->name(), "foo.baz.clone"); // Test incrementing a large number after the suffix. auto foo_clone234 = HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "foo.clone234"); - EXPECT_EQ(foo_clone234->Clone()->name(), "%foo.clone235"); + EXPECT_EQ(foo_clone234->Clone()->name(), "foo.clone235"); // Test a non-numeric string after the cloning suffix. auto foo_clonexyz = HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "foo.clonexyz"); - EXPECT_EQ(foo_clonexyz->Clone()->name(), "%foo.clonexyz.clone"); + EXPECT_EQ(foo_clonexyz->Clone()->name(), "foo.clonexyz.clone"); // Test a name with multiple appearances of the suffix. auto foo_clone_clone3 = HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "foo.clone.clone3"); - EXPECT_EQ(foo_clone_clone3->Clone()->name(), "%foo.clone.clone4"); + EXPECT_EQ(foo_clone_clone3->Clone()->name(), "foo.clone.clone4"); } TEST_F(HloInstructionTest, Stringification) { - // Tests stringification of a simple op, fusion, and while. + // Tests stringification of a simple op, fusion, while, and conditional. 
const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); @@ -1183,27 +1143,41 @@ TEST_F(HloInstructionTest, Stringification) { builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); EXPECT_EQ(dot->ToString(false, false), "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " - "%transpose)"); + "%transpose), lhs_contracting_dims=1,rhs_contracting_dims=0"); HloModule module(TestName()); auto* computation = module.AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); - EXPECT_EQ(fusion->ToString(false, false), - "%fusion = f32[5,20]{1,0} fusion:kTransposeDot(f32[5,10]{1,0} %x, " - "f32[20,10]{1,0} %y), calls=%fused_computation"); + EXPECT_EQ( + fusion->ToString(false, false), + "%fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " + "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation"); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); EXPECT_EQ(loop->ToString(false, false), "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), " "condition=%TransposeDot, body=%TransposeDot"); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction* conditional = + builder.AddInstruction(HloInstruction::CreateConditional( + sout, pred, x, computation, x, computation)); + EXPECT_EQ(conditional->ToString(false, false), + 
"%conditional = f32[5,20]{1,0} conditional(pred[] %constant, " + "f32[5,10]{1,0} %x, f32[5,10]{1,0} %x), " + "true_computation=%TransposeDot, false_computation=%TransposeDot"); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 4d4010b0253c57eec3587776308f0a5fbaa31304..992f55788b4900949f4994ba5b7be015bcd0d3de 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -87,6 +87,7 @@ HLO_MATCHER(Call); HLO_MATCHER(Ceil); HLO_MATCHER(Clamp); HLO_MATCHER(Concatenate); +HLO_MATCHER(Conditional); HLO_MATCHER(Constant); HLO_MATCHER(Convert); HLO_MATCHER(Convolution); @@ -121,6 +122,7 @@ HLO_MATCHER(Outfeed); HLO_MATCHER(Pad); HLO_MATCHER(Power); HLO_MATCHER(Recv); +HLO_MATCHER(RecvDone); HLO_MATCHER(Reduce); HLO_MATCHER(ReducePrecision); HLO_MATCHER(ReduceWindow); @@ -131,6 +133,7 @@ HLO_MATCHER(Rng); HLO_MATCHER(Select); HLO_MATCHER(SelectAndScatter); HLO_MATCHER(Send); +HLO_MATCHER(SendDone); HLO_MATCHER(ShiftLeft); HLO_MATCHER(ShiftRightLogical); HLO_MATCHER(ShiftRightArithmetic); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 659f3d8c26be97a45e5a219b5081334e4f5dcdab..6fe2134466ffaf1402e5ecbc81aea9aafe2a468b 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -35,14 +35,15 @@ namespace xla { HloModule::HloModule(const string& name, const VersionedComputationHandle& entry_computation_handle, const HloModuleConfig& config) - : name_(name), + : name_(NameUniquer::GetSanitizedName(name)), config_(config), has_entry_computation_handle_(true), entry_computation_handle_(entry_computation_handle) {} -HloModule::HloModule(const string& name) : name_(name) {} +HloModule::HloModule(const string& name) + : name_(NameUniquer::GetSanitizedName(name)) {} HloModule::HloModule(const string& name, const HloModuleConfig& 
config) - : name_(name), config_(config) {} + : name_(NameUniquer::GetSanitizedName(name)), config_(config) {} HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -174,12 +175,6 @@ string HloModule::ToString(bool include_large_constants) const { std::ostringstream s; s << "HloModule " << name() << ":\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { - // Fusion computations are emitted with their fusion instruction and - // therefore don't need to be emitted as a separate comptutation in the - // module. - if (computation->IsFusionComputation()) { - continue; - } if (computation == entry_computation()) { s << "ENTRY "; } @@ -296,9 +291,16 @@ StatusOr> HloModule::CreateFromProto( tensorflow::gtl::FlatMap computation_map; for (const HloComputationProto& computation_proto : proto.computations()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, &computation_map)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr computation, + HloComputation::CreateFromProto( + module.get(), computation_proto, computation_map, + /*add_fused_computation=*/ + [&module](std::unique_ptr fused_computation) { + module->AddComputationInternal(std::move(fused_computation), + /*is_entry=*/false, + /*uniquify_names=*/false); + })); CHECK_NE(computation.get(), nullptr); TF_RET_CHECK(!ContainsKey(computation_map, computation->name())); string computation_name = computation->name(); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 6469851791ddb66c6fb17aa8d7c80b04c879a67b..5141e7bc8d4cf0ef4cd83310772e0c5d66b5da12 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -85,7 +85,11 @@ class HloModule { std::unique_ptr Clone(const string& suffix = "clone") const; // Return a pointer to the entry computation of the module.. 
- HloComputation* entry_computation() const { + const HloComputation* entry_computation() const { + CHECK_NE(nullptr, entry_computation_); + return entry_computation_; + } + HloComputation* entry_computation() { CHECK_NE(nullptr, entry_computation_); return entry_computation_; } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index 8974deb530c2e4561b5ab57f43c65fd525db3617..822e2f1f53e5ee460b88c2241ecf7f6b91ef608b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -39,8 +39,8 @@ void HloModuleConfig::SetDefaultComputationLayout( } string HloModuleConfig::compilation_cache_key() const { - string key = tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_, - "::hybrid=", has_hybrid_result_); + string key = + tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_); StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 4a7ead9c104d2ed50d5c895b3cdf2d3767ae16e8..a5ee895e48448fbb8fa3879dc1b6764c1f9f6966 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -104,16 +104,6 @@ class HloModuleConfig { // Whether to enable HLO-level profiling. bool hlo_profiling_enabled_ = false; - // If this flag is true, the generated executable will return a ShapedBuffer - // holding the result of the computation. In a ShapedBuffer, tuples have their - // structure held in host memory and the element arrays (leaves of the tuple - // structure) stored in device memory. The ShapedBuffer is considered "hybrid" - // because its leaves are on device but its structure is stored on - // host. 
Otherwise, if this flag is false, the generated executable will - // return a DeviceMemoryBase where the result is held entirely in device - // memory. - bool has_hybrid_result_ = false; - // Module/graph-level seed handle. uint64 seed_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index d68fc20321152f6a2ede1234180bee0db110f503..f3f79357582ac7661a532e94031acdbca0b86784 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -52,12 +52,14 @@ namespace xla { V(kBatchNormInference, "batch-norm-inference") \ V(kBatchNormTraining, "batch-norm-training") \ V(kBitcast, "bitcast") \ + V(kBitcastConvert, "bitcast-convert") \ V(kBroadcast, "broadcast") \ V(kCall, "call", kHloOpcodeIsVariadic) \ V(kCeil, "ceil") \ V(kClamp, "clamp") \ V(kComplex, "complex") \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ + V(kConditional, "conditional") \ V(kConstant, "constant") \ V(kConvert, "convert") \ V(kConvolution, "convolution") \ @@ -97,6 +99,7 @@ namespace xla { V(kPower, "power") \ V(kReal, "real") \ V(kRecv, "recv") \ + V(kRecvDone, "recv-done") \ V(kReduce, "reduce") \ V(kReducePrecision, "reduce-precision") \ V(kReduceWindow, "reduce-window") \ @@ -108,6 +111,7 @@ namespace xla { V(kSelect, "select") \ V(kSelectAndScatter, "select-and-scatter") \ V(kSend, "send") \ + V(kSendDone, "send-done") \ V(kShiftLeft, "shift-left") \ V(kShiftRightArithmetic, "shift-right-arithmetic") \ V(kShiftRightLogical, "shift-right-logical") \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 37009369797693dcd06647fad845bb0c004cec67..6f6e679a21870e46da85963c3b2998465ac43420 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -173,6 +173,19 @@ bool HloOrdering::UseIsBeforeValueDefinition( return true; } } + + // The use at a call occurs before 
values that are defined in the called + // computation. + if (use.instruction->opcode() == HloOpcode::kCall) { + const HloInstruction* call = use.instruction; + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + call->to_apply())) { + VLOG(4) << " use is call " << use.instruction->name() + << " and def is in called computation"; + return true; + } + } + VLOG(4) << " use is not before value"; return false; } @@ -187,23 +200,6 @@ bool HloOrdering::LiveRangeStrictlyBefore( return false; } - // Live-out values from the module can never have ranges strictly before any - // other value. - if (a.live_out_of_module()) { - VLOG(4) << "a is live out of module"; - return false; - } - - // Live-out values of computations can never have ranges strictly before any - // other value in the computation (including values nested in - // subcomputations). - if (a.live_out_of_computation() && - call_graph_->InstructionIsNestedIn(b.defining_instruction(), - a.defining_instruction()->parent())) { - VLOG(4) << "a is live out of computation containing b"; - return false; - } - // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { if (!UseIsBeforeValueDefinition(use, b, dataflow)) { diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.cc b/tensorflow/compiler/xla/service/hlo_profile_printer.cc new file mode 100644 index 0000000000000000000000000000000000000000..e944ad15139af0d2f98e8e68d3d48303f47ecf1c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.cc @@ -0,0 +1,67 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_profile_printer.h" + +#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" + +namespace xla { +string HloProfilePrinter::ToString(const int64* counters, + double clock_rate_ghz) const { + string result; + + for (int computation_idx = 0; computation_idx < computation_infos_size_; + computation_idx++) { + const HloComputationInfo& computation = computation_infos_[computation_idx]; + const HloInstructionInfo* instructions_begin = computation.instructions; + const HloInstructionInfo* instructions_end = + computation.instructions + computation.instructions_size; + bool any_instruction_profiled = + std::any_of(instructions_begin, instructions_end, + [&](const HloInstructionInfo& instruction_info) { + return counters[instruction_info.profile_index] != 0; + }); + + if (!any_instruction_profiled) { + continue; + } + + // Once we start using this in AOT for real, we will probably need a more + // minimal version of HumanReadableProfileBuilder. 
+ HumanReadableProfileBuilder builder( + computation.name, counters[computation.profile_index], clock_rate_ghz); + + for (const auto* instruction = instructions_begin; + instruction != instructions_end; instruction++) { + builder.AddOp( + /*op_name=*/instruction->long_name, + /*short_name=*/instruction->short_name, instruction->category, + counters[instruction->profile_index], instruction->flop_count, + instruction->transcendental_count, instruction->bytes_accessed, + instruction->optimal_seconds); + } + + result += builder.ToString(); + } + + return result; +} + +HloProfilePrinter::~HloProfilePrinter() { + if (deleter_) { + deleter_(computation_infos_, computation_infos_size_); + } +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.h b/tensorflow/compiler/xla/service/hlo_profile_printer.h new file mode 100644 index 0000000000000000000000000000000000000000..2f056490ae027872570f7a0821ee63114f49fab8 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.h @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +// Instances of this class can pretty-print profile counters gathered from +// running an XLA computation without having access to the backing module. +class HloProfilePrinter { + public: + // Holds meta information about an HloInstruction. + // + // The pointer-typed fields can be owning or non-owning -- this decision is + // manifested as the deleter_ function in the containing HloProfilePrinter. + struct HloInstructionInfo { + // Textual information for pretty printing. + const char* long_name; + const char* short_name; + const char* category; + + // Metrics computed by HloCostAnalysis. + float flop_count; + float transcendental_count; + float bytes_accessed; + float optimal_seconds; + + // The index into the profile counters array for the HloInstruction + // corresponding to this HloInstructionInfo. + int64 profile_index; + }; + + // Holds meta information about an HloComputation. + // + // The pointer-typed fields can be owning or non-owning -- this decision is + // manifested as the deleter_ function in the containing HloProfilePrinter. + struct HloComputationInfo { + const char* name; + + // The index into the profile counters array for the HloInstruction + // corresponding to this HloComputationInfo. 
+ int64 profile_index; + + HloInstructionInfo* instructions; + int64 instructions_size; + }; + + HloProfilePrinter( + HloComputationInfo* computation_infos, int64 computation_infos_size, + int64 profile_counters_size, + std::function deleter = nullptr) + : computation_infos_(computation_infos), + computation_infos_size_(computation_infos_size), + profile_counters_size_(profile_counters_size), + deleter_(std::move(deleter)) {} + + HloProfilePrinter(HloProfilePrinter&& other) { + std::swap(other.computation_infos_, computation_infos_); + std::swap(other.computation_infos_size_, computation_infos_size_); + std::swap(other.deleter_, deleter_); + } + + HloProfilePrinter(const HloProfilePrinter&) = delete; + HloProfilePrinter& operator=(const HloProfilePrinter&) = delete; + + // Converts the profile counter sequence `counters` to a human readable string + // representation. + string ToString(const int64* counters, double clock_rate_ghz) const; + + // Returns the size of the profile buffer expected by this printer. + int64 profile_counters_size() const { return profile_counters_size_; } + + ~HloProfilePrinter(); + + private: + // The `computation_infos_` field can be owning or non-owning -- this decision + // is manifested as the deleter_ function. + HloComputationInfo* computation_infos_ = nullptr; + int64 computation_infos_size_ = 0; + int64 profile_counters_size_ = 0; + std::function deleter_; +}; +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index d7bdac9c86579f19afbba133772c2c50894853d1..553ec11f6f9a2997ab7113f9b8241e04c7fe20d5 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -30,11 +30,17 @@ namespace xla { class HloInstruction; -// A class for computing and representing reachability between HloInstructions. 
+// A class for representing reachability between HloInstructions. +// +// !!! THIS CLASS DOES NOT COMPUTE REACHABILITY !!! It has an adjacency matrix +// and it is up to the user of the class to set the adjacency matrix such that +// it represents reachability, i.e. such that it is transitive. That the graph +// be transitive is thus not an invariant of this class, but it is required for +// the name of the class and its methods to make sense. class HloReachabilityMap { public: - // Sets up an empty reachable matrix for the full set of instructions - // specified in 'instructions'. + // Sets up a graph with no edges and where the nodes correspond to the given + // instructions. explicit HloReachabilityMap(const std::list& instructions); // Set the reachability set of 'instruction' to the union of the reachability @@ -42,17 +48,33 @@ class HloReachabilityMap { // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from // itself. Returns whether the reachability set of 'instruction' changed. + // + // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency + // vector in the internal graph of this HloReachabilityMap for the given + // instruction and does not transitively update any other part of the + // adjacency matrix. bool SetReachabilityToUnion( tensorflow::gtl::ArraySlice inputs, const HloInstruction* instruction); // Sets entry so that IsReachable(a, b) will return true + // + // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency + // matrix in the internal graph of this HloReachabilityMap to have an edge + // from a to b and does not transitively update any other part of the + // adjacency matrix. 
void SetReachable(const HloInstruction* a, const HloInstruction* b); // Returns true if "b" is reachable from "a" + // + // Note that this function only correctly answers queries about reachability + // if the set of edges that have been provided to this class are transitive. bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; // Returns true if "b" is reachable from "a" or "a" is reachable from "b" + // + // Note that this function only correctly answers queries about reachability + // if the set of edges that have been provided to this class are transitive. bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; private: diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index c96df50e79a3c6d4ca5f8e7e0abec33cdfca1c70..1747790e63c6af997eea096b68e5525fdd9d131a 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -62,16 +62,11 @@ bool IsRematerializable(const HloInstruction* instruction) { case HloOpcode::kConstant: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: - case HloOpcode::kOutfeed: - case HloOpcode::kInfeed: case HloOpcode::kParameter: - case HloOpcode::kRecv: - case HloOpcode::kSend: - case HloOpcode::kTrace: case HloOpcode::kWhile: return false; default: - return true; + return !instruction->HasSideEffect(); } } @@ -571,7 +566,9 @@ Status MemoryUsageTracker::BeginInstruction(Item* item) { VLOG(3) << " memory usage = " << memory_usage_; VLOG(10) << ToString(); - DCHECK(Check()); + if (VLOG_IS_ON(1)) { + DCHECK(Check()); + } return Status::OK(); } @@ -608,8 +605,9 @@ Status MemoryUsageTracker::EndInstruction() { VLOG(3) << " memory usage = " << memory_usage_; VLOG(10) << ToString(); - DCHECK(Check()); - + if (VLOG_IS_ON(1)) { + DCHECK(Check()); + } return Status::OK(); } @@ -1026,7 +1024,9 @@ StatusOr 
HloRematerialization::RematerializeComputation( HloInstruction* best = best_item->instruction; VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " - << memory_tracker.MemoryReducedIfRematerialized(best_item) << ")"; + << HumanReadableNumBytes( + memory_tracker.MemoryReducedIfRematerialized(best_item)) + << ")"; changed = true; remat_count++; @@ -1106,8 +1106,8 @@ StatusOr HloRematerialization::RematerializeComputation( net_instructions_added++; } - VLOG(3) << "memory_usage after rematerialization = " - << memory_tracker.memory_usage(); + VLOG(1) << "memory_usage after rematerialization = " + << HumanReadableNumBytes(memory_tracker.memory_usage()); } const CallSite* callsite = call_graph_node.GetCallSite(instruction); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index d88aa4bb567c6c5f6eab54f12239bf7040339c39..c9b57166af438ef19ae4f079b8ecc8ddd5aede00 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -323,6 +323,76 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { EXPECT_EQ(inner_computation->instruction_count(), 8); } +TEST_F(HloRematerializationTest, RngNotRematerialized) { + // Test that a single rng is not rematerialized: + // + // Entry computation: + // F32[] %param = {...} + // F32[1024] rng = rng(param) + // F32[1024] tanh = tanh(rng) + // F32[1024] exp = exp(rng) + // F32[1024] add_0 = add(rng, tanh) // LIVE: add_0 + rng + + // // tanh + exp + // + // F32[1024] add_1 = add(rng, add(exp, add_0)) // LIVE: add_1 + add_0 + + // // rng + tanh + exp + // + // F32[1024] add_2 = add(rng, add(tanh, add_1)) // LIVE: add_2 + add_1 + + // // rng + tanh + exp + auto module = CreateNewModule(); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, 
"param")); + auto rng = builder.AddInstruction(HloInstruction::CreateRng( + vec1024_shape_, RandomDistribution::RNG_BERNOULLI, {param})); + auto tanh = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kTanh, rng)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kExp, rng)); + auto add_0 = builder.AddInstruction( + HloInstruction::CreateBinary(vec1024_shape_, HloOpcode::kAdd, rng, tanh)); + auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, rng, + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, exp, add_0)))); + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, rng, + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, tanh, add_1)))); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto count_rngs = [](const HloComputation* computation) { + int64 rng_count = 0; + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kRng) { + ++rng_count; + } + } + return rng_count; + }; + // Before rematerialization there should be a single broadcast rng in + // the graph. + ASSERT_EQ(count_rngs(entry_computation), 1); + const int64 original_instruction_count = + entry_computation->instruction_count(); + SequentialHloOrdering::HloModuleSequence sequence; + // Pick a memory limit some where between 24KB (initial peak memory including + // parameter and output) and 20KB (peak memory possible with + // rematerialization). + TF_ASSERT_OK_AND_ASSIGN( + bool changed, HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), + module.get(), &sequence)); + EXPECT_TRUE(changed); + // The rng should not have been rematerialized. 
+ EXPECT_EQ(count_rngs(entry_computation), 1); + // There should have been rematerialization. + EXPECT_GT(entry_computation->instruction_count(), original_instruction_count); +} + TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Test that a single instruction is rematerialized several times. Module: // diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index f463e57d995c0f0549872a1a0bf20a3ead626dc8..a6101bbe6075d62d7a9872c3d9005dce2865453e 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS #include "tensorflow/compiler/xla/service/hlo_runner.h" @@ -19,8 +20,6 @@ limitations under the License. 
#include #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -40,6 +39,14 @@ namespace se = ::perftools::gputools; namespace xla { +/*static*/ StatusOr> +HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, + const DebugOptions& debug_options) { + HloModuleConfig config; + config.set_debug_options(debug_options); + return tools::Parse(hlo_string, config); +} + /*static*/ StatusOr> HloRunner::ReadModuleFromHloProtoFile(const std::string& filename, const DebugOptions& debug_options) { @@ -115,11 +122,16 @@ HloRunner::~HloRunner() { StatusOr HloRunner::Execute( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments, - Shape* result_shape) { + Shape* result_shape, bool run_hlo_passes) { + if (run_hlo_passes) { + TF_ASSIGN_OR_RETURN( + module, backend().compiler()->RunHloPasses( + std::move(module), backend().default_stream_executor())); + } TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - backend().compiler()->Compile(std::move(module), - backend().default_stream_executor())); + backend().compiler()->RunBackend(std::move(module), + backend().default_stream_executor())); se::Stream stream(backend().default_stream_executor()); stream.Init(); @@ -131,15 +143,14 @@ StatusOr HloRunner::Execute( run_options.set_intra_op_thread_pool( backend().eigen_intra_op_thread_pool_device()); - HloExecutionProfile hlo_execution_profile; ServiceExecutableRunOptions service_run_options( run_options, backend().StreamBorrower(), backend().inter_op_thread_pool()); TF_ASSIGN_OR_RETURN( se::DeviceMemoryBase result, executable->ExecuteOnStream(&service_run_options, arguments, - &hlo_execution_profile)); - TF_RET_CHECK(stream.BlockHostUntilDone()); + /*hlo_execution_profile=*/nullptr)); + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); allocations_.push_back(result); @@ -195,10 +206,12 @@ StatusOr> 
HloRunner::TransferFromDevice( StatusOr> HloRunner::ExecuteAndTransfer( std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes) { Shape result_shape; - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase device_base, - Execute(std::move(module), arguments, &result_shape)); + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase device_base, + Execute(std::move(module), arguments, &result_shape, run_hlo_passes)); return TransferFromDevice(result_shape, device_base); } diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index a5732848c6b4191faf8d7b07c749132ca8b14413..a65c66fd4b6db858a532096a5ee466aa9bf0d844 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -35,7 +35,8 @@ namespace xla { // A base class for running an HloModule. This executes the given HloModule on a // certain backend directly without using the client interface. HloModule can be -// explicitly built, or loaded from a serialization file (e.g., hlo proto file). +// explicitly built, or loaded from a serialization file (e.g., hlo proto +// file), or parsed from a hlo textual IR string. class HloRunner { public: HloRunner(); @@ -44,6 +45,12 @@ class HloRunner { ~HloRunner(); + // Converts an HloModule from the given hlo textual IR string (in + // HloModule::ToString format). + static StatusOr> CreateModuleFromString( + const tensorflow::StringPiece hlo_string, + const DebugOptions& debug_options); + // Reads the proto file in xla.HloProto format, creates and returns the // HloModule. Will try to parse the filename as binary proto, then try as // text proto if that fails. @@ -65,17 +72,21 @@ class HloRunner { // Executes the given module with given literals as input and returns the // result as a Literal. The LiteralPtr type accepts Literal* or // std::unique_ptr. 
+ // + // If run_hlo_passes is false, the module will be executed without Hlo + // optimization. template StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice literals); + const tensorflow::gtl::ArraySlice literals, + bool run_hlo_passes = true); // Executes the given module and returns a global data handle. StatusOr Execute( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments, - Shape* result_shape); + Shape* result_shape, bool run_hlo_passes = true); // Transfers the given literal to the device and returns the data handle. StatusOr TransferToDevice( @@ -90,7 +101,8 @@ class HloRunner { StatusOr> ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice - arguments); + arguments, + bool run_hlo_passes = true); // If backend is not created in the constructor, creates and returns the // default backend. If creation fails, crashes the program. @@ -112,14 +124,15 @@ class HloRunner { template StatusOr> HloRunner::Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice literals) { + const tensorflow::gtl::ArraySlice literals, + bool run_hlo_passes) { std::vector arguments; for (const auto& literal : literals) { TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase argument, TransferToDevice(*literal)); arguments.push_back(argument); } - return ExecuteAndTransfer(std::move(module), arguments); + return ExecuteAndTransfer(std::move(module), arguments, run_hlo_passes); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 8ccbcaeee4a9c9e94b344231953e20ac8f4b2053..0dc17392f1f520a415083c92b51db9d9abb321c0 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -31,6 +31,8 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +using ::tensorflow::strings::HumanReadableNumBytes; + namespace xla { StatusOr MinimumMemoryForSequence( @@ -375,6 +377,7 @@ StatusOr> CreateMemoryMinimizingSequence( // Note that this is just a heuristic. One obvious inaccuracy is that the // memory required for sub-computations might be different when considered // within the caller's context. But it's good enough for now. + VLOG(2) << "Computation: " << computation.name(); TF_ASSIGN_OR_RETURN( std::vector list_sequence, ListScheduler::Run(computation, points_to_analysis, size_function)); @@ -382,7 +385,7 @@ StatusOr> CreateMemoryMinimizingSequence( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, points_to_analysis, size_function)); - VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; + VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); TF_ASSIGN_OR_RETURN( std::vector dfs_sequence, @@ -391,13 +394,15 @@ StatusOr> CreateMemoryMinimizingSequence( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, size_function)); - VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; + VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); if (list_memory <= dfs_memory) { - VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes"; + VLOG(2) << "Chose min-memory list sequence: " + << HumanReadableNumBytes(list_memory); return list_sequence; } else { - VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes"; + VLOG(2) << "Chose min-memory dfs sequence: " + << HumanReadableNumBytes(dfs_memory); return dfs_sequence; } } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 0d019d22f5d4cd401c0fc5572f99636dec4f7383..447c2446668253c932b44b51b2db22bfd47f9957 100644 --- 
a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace xla { @@ -38,6 +39,15 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { } string HloSharding::ToString() const { + if (IsTuple()) { + std::vector parts; + parts.reserve(tuple_elements_.size()); + for (const HloSharding& element : tuple_elements_) { + parts.push_back(element.ToString()); + } + return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); + } + string result = StrCat("{", (replicated_ ? " replicated" : ""), (maximal_ ? " maximal" : "")); @@ -53,6 +63,11 @@ string HloSharding::ToString() const { } bool HloSharding::UsesDevice(int64 device) const { + if (IsTuple()) { + return std::any_of( + tuple_elements_.begin(), tuple_elements_.end(), + [&](const HloSharding& s) { return s.UsesDevice(device); }); + } const auto& devices = tile_assignment_; return replicated_ || std::find(devices.begin(), devices.end(), device) != devices.end(); @@ -61,6 +76,7 @@ bool HloSharding::UsesDevice(int64 device) const { std::vector HloSharding::TileIndexForDevice(int64 device) const { CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!maximal_); + CHECK(!IsTuple()); std::vector ret_index; tile_assignment_.Each([&](tensorflow::gtl::ArraySlice index, int64 d) { if (d == device) { @@ -74,6 +90,7 @@ std::vector HloSharding::TileIndexForDevice(int64 device) const { int64 HloSharding::DeviceForTileIndex( tensorflow::gtl::ArraySlice index) const { CHECK(!replicated_); + CHECK(!IsTuple()); if (maximal_) { return *tile_assignment_.begin(); } @@ -82,7 +99,7 @@ int64 HloSharding::DeviceForTileIndex( } std::vector HloSharding::TileOffsetForDevice(int64 device) const { - CHECK(!ShapeUtil::IsTuple(tile_shape_)); + CHECK(!IsTuple()); std::vector 
index = TileIndexForDevice(device); if (maximal_) { @@ -97,7 +114,7 @@ std::vector HloSharding::TileOffsetForDevice(int64 device) const { } std::vector HloSharding::TileLimitForDevice(int64 device) const { - CHECK(!ShapeUtil::IsTuple(tile_shape_)); + CHECK(!IsTuple()); CHECK(!maximal_); // Maximal shardings do not have a valid tile shape. std::vector index = TileIndexForDevice(device); @@ -108,14 +125,94 @@ std::vector HloSharding::TileLimitForDevice(int64 device) const { } StatusOr HloSharding::UniqueDevice() const { - if (!replicated_ && maximal_) { + if (IsTuple()) { + if (tuple_elements_.empty()) { + return tensorflow::errors::InvalidArgument( + "UniqueDevice() called on empty tuple"); + } + std::vector> results; + std::transform(tuple_elements_.begin(), tuple_elements_.end(), + std::back_inserter(results), + [](const HloSharding& s) { return s.UniqueDevice(); }); + if (std::all_of(results.begin(), results.end(), + [&](const StatusOr& s) { + return s.ok() && results[0].ok() && + s.ValueOrDie() == results[0].ValueOrDie(); + })) { + return results[0]; + } else { + return tensorflow::errors::InvalidArgument( + "Tuple did not contain a unique device"); + } + } + if (!replicated_ && maximal_ && !IsTuple()) { return static_cast(*tile_assignment_.begin()); } return tensorflow::errors::InvalidArgument( "UniqueDevice() called on sharding that executes on multiple devices"); } +bool HloSharding::HasUniqueDevice() const { + if (IsTuple()) { + return UniqueDevice().status().ok(); + } else { + return !IsReplicated() && IsTileMaximal(); + } +} + +Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { + if (!ShapeUtil::IsTuple(shape)) { + return tensorflow::errors::InvalidArgument( + StrCat("Sharding is tuple-shaped but validation shape is not.")); + } + // The easiest way to get the number of elements in a nested tuple is just to + // create a shape tree. 
We could call GetAsShapeTree, but that will try and + // apply our tuple_shardings_ to the shape tree, and that might cause a crash + // at this point as we haven't validated them. + ShapeTree bool_shape_tree(shape, false); + int64 num_leaves = + std::distance(bool_shape_tree.leaf_begin(), bool_shape_tree.leaf_end()); + if (num_leaves != tuple_elements_.size()) { + return tensorflow::errors::InvalidArgument( + StrCat("Validation tuple shape has ", num_leaves, + " leaf elements, but this sharding contains ", + tuple_elements_.size(), " elements.")); + } + + // Now we've validated the number of tuple elements, it's safe to request a + // shape tree. + ShapeTree shape_tree = GetAsShapeTree(shape); + for (const auto& index_to_sharding : shape_tree.leaves()) { + Status status = index_to_sharding.second.ValidateNonTuple( + ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices); + if (!status.ok()) { + tensorflow::errors::AppendToMessage( + &status, StrCat("Note: While validating sharding tuple element ", + index_to_sharding.first.ToString(), " which is ", + index_to_sharding.second.ToString())); + return status; + } + } + return Status::OK(); +} + Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { + Status status = IsTuple() ? 
ValidateTuple(shape, num_devices) + : ValidateNonTuple(shape, num_devices); + if (!status.ok()) { + tensorflow::errors::AppendToMessage( + &status, StrCat("Note: While validating sharding ", ToString(), + " against shape ", ShapeUtil::HumanString(shape))); + } + return status; +} + +Status HloSharding::ValidateNonTuple(const Shape& shape, + int64 num_devices) const { + if (ShapeUtil::IsTuple(shape)) { + return tensorflow::errors::InvalidArgument( + StrCat("Validation shape is a tuple but sharding is not.")); + } if (replicated_) { return Status::OK(); } @@ -129,13 +226,11 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { // Don't overwrite a bad status, so we report the first error. if (status.ok()) { if (core >= num_devices) { - status = - tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( - "core ", core, " > ", num_devices, " in tile assignment")); + status = tensorflow::errors::InvalidArgument(StrCat( + "core ", core, " > ", num_devices, " in tile assignment")); } else if (seen_cores.count(core) != 0) { - status = - tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( - "core ", core, " is not unique in tile assignment")); + status = tensorflow::errors::InvalidArgument( + StrCat("core ", core, " is not unique in tile assignment")); } } seen_cores.insert(core); @@ -151,7 +246,8 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { // The tile rank must be the same as the input rank. if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) { return tensorflow::errors::InvalidArgument( - "Tile rank is different to the input rank"); + "Tile rank is different to the input rank. 
sharding=", ToString(), + ", input_shape=", ShapeUtil::HumanString(shape)); } // The tile shape must not be the same as the input shape without maximal_ @@ -169,9 +265,9 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { auto tile_dim = tile_shape_.dimensions(i); auto shape_dim = shape.dimensions(i); if (tile_dim > shape_dim) { - return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( - "Tile is larger than input shape (dimension ", i, ", ", tile_dim, - " > ", shape_dim)); + return tensorflow::errors::InvalidArgument( + StrCat("Tile is larger than input shape (dimension ", i, ", ", + tile_dim, " > ", shape_dim)); } } @@ -181,10 +277,10 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { int64 expected_dim = CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i)); if (tile_assignment_.dimensions()[i] != expected_dim) { - return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( - "Tile assignment tensor has incorrect shape. Dimension ", i, - " expected ", expected_dim, " but got ", - tile_assignment_.dimensions()[i])); + return tensorflow::errors::InvalidArgument( + StrCat("Tile assignment tensor has incorrect shape. 
Dimension ", i, + " expected ", expected_dim, " but got ", + tile_assignment_.dimensions()[i])); } } @@ -193,9 +289,19 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { /*static*/ StatusOr HloSharding::FromProto( const OpSharding& proto) { - if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { + if (proto.type() == OpSharding::Type::OpSharding_Type_TUPLE) { + std::vector tuple_shardings; + tuple_shardings.reserve(proto.tuple_shardings().size()); + for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) { + TF_ASSIGN_OR_RETURN(HloSharding sharding, + HloSharding::FromProto(tuple_sharding_proto)); + tuple_shardings.push_back(sharding); + } + return HloSharding(tuple_shardings); + } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { return Replicate(); - } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) { + } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL || + proto.tile_assignment_devices().size() == 1) { return HloSharding(proto.tile_assignment_devices(0)); } // Some versions of gcc cannot infer the TileAssignment constructor from a @@ -212,6 +318,15 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { OpSharding HloSharding::ToProto() const { OpSharding result; + + if (IsTuple()) { + for (const HloSharding& element : tuple_elements_) { + *result.add_tuple_shardings() = element.ToProto(); + } + result.set_type(OpSharding::Type::OpSharding_Type_TUPLE); + return result; + } + *result.mutable_tile_shape() = tile_shape_; for (int64 dim : tile_assignment_.dimensions()) { result.add_tile_assignment_dimensions(dim); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index d7ada30c70bc3b41b3117375380eac2e883d9a9d..7263198385cf0c84b1dac1e15177dcac99adaafb 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ 
-24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/hash/hash.h" @@ -67,6 +68,29 @@ class HloSharding { // `num_tiles` tiles. static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles); + // Creates a new sharding for a tuple type. The given ShapeTree must have + // elements for every leaf shape contained in the tuple. + static HloSharding Tuple(const ShapeTree& sub_shardings) { + std::vector flattened_list; + flattened_list.reserve( + std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end())); + for (const auto& index_to_sharding : sub_shardings.leaves()) { + flattened_list.push_back(index_to_sharding.second); + } + return HloSharding(flattened_list); + } + + // Creates a new sharding for a tuple type. The requested tuple shape must not + // be nested. For nested tuples, use the ShapeTree overload. + static HloSharding Tuple(const Shape& tuple_shape, + tensorflow::gtl::ArraySlice shardings) { + CHECK(ShapeUtil::IsTuple(tuple_shape)); + CHECK(!ShapeUtil::IsNestedTuple(tuple_shape)); + std::vector flattened_list(shardings.begin(), shardings.end()); + CHECK_EQ(flattened_list.size(), ShapeUtil::TupleElementCount(tuple_shape)); + return HloSharding(flattened_list); + } + // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); @@ -76,47 +100,93 @@ class HloSharding { // Validate that this sharding can be applied to a tensor with shape `shape`. Status Validate(const Shape& shape, int64 num_devices) const; + // Returns true if the sharding has tuple type. + bool IsTuple() const { return tuple_; } + // Returns true if the sharding is trivial: replicate on all devices. 
- bool IsReplicated() const { return replicated_; } + bool IsReplicated() const { + if (!IsTuple()) { + return replicated_; + } + return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), + [](const HloSharding& s) { return s.IsReplicated(); }); + } // Returns true if the tile size is the same as the input size. - bool IsTileMaximal() const { return maximal_; } + bool IsTileMaximal() const { + if (!IsTuple()) { + return maximal_; + } + return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), + [](const HloSharding& s) { return s.IsTileMaximal(); }); + } // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; // Returns the tile that should be executed on the given device. + // REQUIRES: !IsTuple() std::vector TileIndexForDevice(int64 device) const; // Returns the device that should execute the given tile. // It is an error to call this if is_replicated() is true. + // REQUIRES: !IsTuple() int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice index) const; // Given a device ID, returns the offset within the input space of the // tile that should be executed on the given core. This returns the lower // extent of the tile in the input space. + // REQUIRES: !IsTuple() std::vector TileOffsetForDevice(int64 device) const; // Given a device ID, returns the limit within the input space of the // tile that should be executed on the given core. This returns the upper // extent of the tile in the input space. + // REQUIRES: !IsTuple() std::vector TileLimitForDevice(int64 device) const; // Returns the single device this op operates on. - // Requires !Replicated() && IsTileMaximal(). + // Returns an error unless this is a non-replicated tile-maximal sharding, or + // a non-empty tuple whose leaf shardings all use the same single device. StatusOr UniqueDevice() const; // Returns true if this op only uses a single device. 
- bool HasUniqueDevice() const { return !IsReplicated() && IsTileMaximal(); } + bool HasUniqueDevice() const; + + // Returns the ShapeTree containing the shardings for each element of this + // tuple, if IsTuple, or a ShapeTree with a single element containing this + // sharding. Only the leaf elements are populated. This creates a new + // ShapeTree object so is not cheap. + ShapeTree GetAsShapeTree(const Shape& shape) const { + if (IsTuple()) { + ShapeTree result(shape, HloSharding::Replicate()); + CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()), + tuple_elements_.size()); + auto it = tuple_elements_.begin(); + for (auto& index_to_sharding : result.leaves()) { + index_to_sharding.second = *it++; + } + return result; + } else { + return ShapeTree(shape, *this); + } + } bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) && - tile_assignment_ == other.tile_assignment_; + tile_assignment_ == other.tile_assignment_ && + tuple_elements_ == other.tuple_elements_; } bool operator!=(const HloSharding& other) const { return !(*this == other); } size_t Hash() const { + // A tuple sharding is hashed from its leaf shardings, in order. + if (tuple_) { + size_t h = 0; + for (const auto& element : tuple_elements_) { + h = tensorflow::Hash64Combine(h, element.Hash()); + } + return h; + } if (replicated_) { return 0; } @@ -131,33 +201,52 @@ class HloSharding { } // Gets the tile shape. - // It is an error to call this if IsTileMaximal() is true. + // REQUIRES: !IsTileMaximal() && !IsTuple() const Shape& tile_shape() const { return tile_shape_; } // Gets the tile assignment tensor. - // It is an error to call this if IsReplicated() is true. 
+ // REQUIRES: !IsReplicated() && !IsTuple() const Array& tile_assignment() const { return tile_assignment_; } private: HloSharding() : replicated_(true), maximal_(true), + tuple_(false), tile_shape_(), tile_assignment_({0}) {} explicit HloSharding(int64 device_id) : replicated_(false), maximal_(true), + tuple_(false), tile_shape_(), tile_assignment_({1}, device_id) {} HloSharding(const Shape& tile_shape, const Array& tile_assignment) : replicated_(false), maximal_(false), + tuple_(false), tile_shape_(tile_shape), tile_assignment_(tile_assignment) {} + HloSharding(const std::vector& tuple_shardings) + : replicated_(false), + maximal_(false), + tuple_(true), + tile_assignment_({0}), + tuple_elements_(tuple_shardings) {} + + // Internal helper to validate a tuple sharding. + Status ValidateTuple(const Shape& shape, int64 num_devices) const; + // Internal helper to validate a non-tuple (leaf) sharding. + Status ValidateNonTuple(const Shape& shape, int64 num_devices) const; bool replicated_; bool maximal_; + bool tuple_; Shape tile_shape_; Array tile_assignment_; + // Only non-empty when tuple_ is true, but because empty tuples are allowed + // may also be empty even then. This is a flattened list of all the leaf + // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order). 
+ std::vector tuple_elements_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index d0a20471a0f22a5fa414b71bb5160eed7cdc431b..0c7487b3ac77ff181d44dd55ebcf2608feaf02ea 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -70,6 +70,11 @@ TEST_F(HloShardingTest, DevicePlacement) { /*num_devices=*/6)); EXPECT_IS_NOT_OK( sharding.Validate(ShapeUtil::MakeShape(U32, {4}), /*num_devices=*/5)); + + ShapeTree shape_tree = + sharding.GetAsShapeTree(ShapeUtil::MakeShape(U32, {4})); + EXPECT_EQ(shape_tree.element({}), sharding); + EXPECT_TRUE(shape_tree.IsLeaf({})); } TEST_F(HloShardingTest, Tile) { @@ -132,6 +137,39 @@ TEST_F(HloShardingTest, Tile) { } } +TEST_F(HloShardingTest, NestedTuple) { + // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6]) + Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}), + ShapeUtil::MakeShape(F32, {4, 6}), + }); + + HloSharding tiled_sharding = HloSharding::Tile( + ShapeUtil::MakeShape(F32, {4, 3}), Array({{0, 1}})); + OpSharding proto; + proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE); + *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto(); + *proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto(); + *proto.add_tuple_shardings() = tiled_sharding.ToProto(); + HloSharding tuple_sharding = + HloSharding::FromProto(proto).ConsumeValueOrDie(); + + ShapeTree shape_tree = + tuple_sharding.GetAsShapeTree(nested_tuple_shape); + EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate()); + EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0)); + EXPECT_EQ(shape_tree.element({2}), tiled_sharding); + + EXPECT_IS_OK(tuple_sharding.Validate(nested_tuple_shape, /*num_devices=*/5)); + // Test should fail because tuple element count does not 
match. + EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeTupleShape({}), + /*num_devices=*/5)); + // Test should fail because the input type is not a tuple. + EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeShape(F32, {}), + /*num_devices=*/5)); +} + TEST_F(HloShardingTest, Hash) { auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) { if (a.Hash() != b.Hash()) { @@ -184,6 +222,51 @@ TEST_F(HloShardingTest, Hash) { MakeArray({2, 2}, {0, 3, 1, 2})); EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); } + + HloSharding default_sharding = HloSharding::Replicate(); + { + ShapeTree shape_tree(ShapeUtil::MakeTupleShape({}), + default_sharding); + HloSharding sharding1 = HloSharding::Replicate(); + HloSharding sharding2 = HloSharding::Tuple(shape_tree); + EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree shape_tree(ShapeUtil::MakeTupleShape({}), + default_sharding); + HloSharding sharding1 = HloSharding::Tuple(shape_tree); + HloSharding sharding2 = HloSharding::Tuple(shape_tree); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree shape_tree1( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree1.mutable_element({0}) = HloSharding::Replicate(); + ShapeTree shape_tree2( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0); + HloSharding sharding1 = HloSharding::Tuple(shape_tree1); + HloSharding sharding2 = HloSharding::Tuple(shape_tree2); + EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree shape_tree1( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree1.mutable_element({0}) = HloSharding::AssignDevice(0); + ShapeTree shape_tree2( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree2.mutable_element({0}) = 
HloSharding::AssignDevice(0); + HloSharding sharding1 = HloSharding::Tuple(shape_tree1); + HloSharding sharding2 = HloSharding::Tuple(shape_tree2); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 06abe007477dbcd00bcdc7f2656c4dece6d1cf74..101a710d1cad9401134fdfe1d0ec9df241bc01e1 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -58,8 +58,6 @@ TensorShapeProto GetTensorShape(const HloInstruction* instruction) { string GetDeviceName(int device) { return StrCat("/device/XLA:", device); } -} // namespace - void CleanNodeName(string* name) { name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); const string chars_to_replace = "<>[]"; @@ -70,6 +68,11 @@ void CleanNodeName(string* name) { std::replace_if(name->begin(), name->end(), pred, '_'); } +} // namespace + +HloTfGraphBuilder::HloTfGraphBuilder(const DebugOptions& debug_options) + : debug_options_(debug_options) {} + Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { VLOG(2) << "Adding computation " << computation.name(); for (auto embedded : computation.MakeEmbeddedComputationsList()) { @@ -90,24 +93,38 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction( if (ContainsKey(instruction_to_node_name_, instruction)) { return instruction_to_node_name_[instruction]; } + auto append = [](string* str, const string& other) { + if (str->empty()) { + *str = other; + } else if (!other.empty()) { + StrAppend(str, "/", other); + } + }; string node_name; + if (debug_options_.xla_hlo_tfgraph_device_scopes() && + instruction->has_sharding() && + instruction->sharding().HasUniqueDevice()) { + node_name = StrCat( + "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie()); + } // If an instruction is fused, put it in the subgraph of 
the fusion; // otherwise, put it in the computation subgraph. const HloComputation* computation = instruction->parent(); if (computation->IsFusionComputation()) { - node_name = GetNodeNameForInstruction(computation->FusionInstruction()); + append(&node_name, + GetNodeNameForInstruction(computation->FusionInstruction())); } else { - node_name = computation->name(); + append(&node_name, computation->name()); if (!instruction->metadata().op_name().empty()) { // Always make computations contain TF ops but not the other way around. - StrAppend(&node_name, "/", instruction->metadata().op_name()); + append(&node_name, instruction->metadata().op_name()); } } string instruction_name = instruction->name(); if (instruction->opcode() == HloOpcode::kParameter) { StrAppend(&instruction_name, ".", instruction->parameter_number()); } - StrAppend(&node_name, "/", instruction_name); + append(&node_name, instruction_name); CleanNodeName(&node_name); auto ret = instruction_to_node_name_.insert(std::make_pair(instruction, node_name)); diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h index b2c578af912ac0b777d1bc72a198504735a6b845..9aa3e501d5f85e3b61b20555e3d13c5687f33f2f 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h @@ -17,6 +17,7 @@ limitations under the License. #define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -26,6 +27,8 @@ namespace hlo_graph_dumper { // This constructs a tensorflow graph for HLO computations. class HloTfGraphBuilder { public: + HloTfGraphBuilder(const DebugOptions& debug_options = DebugOptions()); + // Adds a computation to the graph. 
Status AddComputation(const HloComputation& computation); @@ -42,6 +45,7 @@ class HloTfGraphBuilder { Status AddInstruction(const HloInstruction* instruction); + DebugOptions debug_options_; tensorflow::GraphDef graph_def_; // This records instructions that have been visited. std::unordered_set visited_instructions_; @@ -49,9 +53,6 @@ class HloTfGraphBuilder { std::unordered_map instruction_to_node_name_; }; -// Cleans the node name to make it a valid name in a tensorflow graph. -void CleanNodeName(string* name); - } // namespace hlo_graph_dumper } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index e6cf0d37b8a0f42dc04cfaad067a4741bc803705..05b7dce3d1ecf935b80ba1cb46ef089b7b3b6f33 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -71,7 +71,7 @@ HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi) : id_(id), is_phi_(is_phi) { // The defining position is always the first element in the positions_ vector. - AddPosition(instruction, index); + positions_.push_back(HloPosition{instruction, index}); } bool HloValue::operator==(const HloValue& other) const { @@ -130,18 +130,14 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, CHECK_LE(operand_number, 2); return operand_number == 0 || index.empty(); - case HloOpcode::kCall: case HloOpcode::kTuple: // These instructions always pass through their operands transparently. return false; + case HloOpcode::kCall: case HloOpcode::kWhile: - // Though the while instructions passes through its operands, we return - // true because in SSA form there may be a Phi at the parameter of the - // while which is considered a use of its incoming value because the Phi - // input values are not passed through into the body computation. 
Because - // this function is used in both SSA and non-SSA forms of the analysis - // conservatively return true. + // Although call and while instructions pass through their operands, they + // are considered uses. return true; default: @@ -151,103 +147,58 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, } // namespace -void HloValue::AddPosition(HloInstruction* instruction, - const ShapeIndex& index) { - HloPosition new_position{instruction, index}; - - // The new position must not already exist in positions_. - for (const HloPosition& position : positions_) { - DCHECK_NE(position, new_position); - } - - positions_.push_back(std::move(new_position)); - - // Update uses. - for (HloInstruction* user : instruction->users()) { - for (int64 operand_number : user->OperandIndices(instruction)) { - if (MayUseOperandValue(operand_number, index, user)) { - HloUse new_use{user, operand_number, index}; - - // The new use must not already exist in uses_. - for (const HloUse& use : uses_) { - DCHECK_NE(use, new_use); - } - - uses_.push_back(std::move(new_use)); +void HloValue::SetPositionsAndComputeUses( + tensorflow::gtl::ArraySlice positions) { + CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once."; + + // The positions must be unique and should not contain the defining position + // as this is added at construction time. + for (const HloPosition& position_a : positions) { + DCHECK_NE(position_a, defining_position()); + for (const HloPosition& position_b : positions) { + if (&position_a != &position_b) { + DCHECK_NE(position_a, position_b); } } } - // Update liveout status of this HloValue. 
- const HloModule& module = *instruction->parent()->parent(); - if (instruction == module.entry_computation()->root_instruction()) { - live_out_of_module_ = true; - } - - if (instruction == instruction->parent()->root_instruction()) { - live_out_of_computation_ = true; - } -} + positions_.insert(positions_.end(), positions.begin(), positions.end()); -void HloValue::RemovePosition(HloInstruction* instruction, - const ShapeIndex& index) { - // The defining position cannot be removed. - CHECK(!(instruction == defining_instruction() && index == defining_index())); - - int64 size_before = positions_.size(); - positions_.erase( - std::remove_if(positions_.begin(), positions_.end(), - [instruction, &index](const HloPosition& position) { - return position.instruction == instruction && - position.index == index; - }), - positions_.end()); - // Only a single position should have been removed. - CHECK_EQ(positions_.size(), size_before - 1); - - // Update uses which referred to this position. - uses_.erase(std::remove_if(uses_.begin(), uses_.end(), - [instruction, &index](const HloUse& use) { - return use.instruction->operand( - use.operand_number) == instruction && - use.operand_index == index; - }), - uses_.end()); - - // Returns whether this value is contained in the given instruction's output. - auto is_contained_in = [this](const HloInstruction* instruction) { - for (const HloPosition& position : positions()) { - if (position.instruction == instruction) { - return true; - } + // Gather the computation roots at which this value appears. 
+ tensorflow::gtl::FlatSet root_positions; + for (const HloPosition& position : positions_) { + if (position.instruction == + position.instruction->parent()->root_instruction()) { + root_positions.insert(position.instruction); } - return false; - }; - - const HloModule& module = *instruction->parent()->parent(); - if (instruction == module.entry_computation()->root_instruction()) { - // Value has been removed from a position in the entry root instruction. - live_out_of_module_ = - is_contained_in(module.entry_computation()->root_instruction()); - } - if (instruction == defining_instruction()->parent()->root_instruction()) { - // Value has been removed from the root of the computation the value has - // been defined in. - live_out_of_computation_ = - is_contained_in(defining_instruction()->parent()->root_instruction()); } -} -void HloValue::RecomputeUses() { - uses_.clear(); - for (const HloPosition& position : positions()) { + // Build vector of HloUses for the value. + for (const HloPosition& position : positions_) { for (HloInstruction* user : position.instruction->users()) { for (int64 operand_number : user->OperandIndices(position.instruction)) { - if (MayUseOperandValue(operand_number, position.index, user)) { - uses_.push_back(HloUse{user, operand_number, position.index}); + // Root instructions of computations are considered to be uses whether + // or not the root instruction itself actually uses the value. + if (MayUseOperandValue(operand_number, position.index, user) || + ContainsKey(root_positions, user)) { + HloUse new_use{user, operand_number, position.index}; + + // The new use must not already exist in uses_. + for (const HloUse& use : uses_) { + DCHECK_NE(use, new_use); + } + + uses_.push_back(std::move(new_use)); } } } + + // Update liveout status of this HloValue. 
+ const HloModule& module = *position.instruction->parent()->parent(); + if (position.instruction == + module.entry_computation()->root_instruction()) { + live_out_of_module_ = true; + } } } diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index 6872bc76a82253b916e826aa1afabc3d309c1d12..2a711e8b42590c29d0aaab95dcf110063ada3182 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -121,6 +121,12 @@ class HloValue { HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); + // Sets the positions in the module at which the HloValue appears. Updates + // uses. Should be called once and only once. The defining position should not + // be included in 'positions' as this is set at construction time. + void SetPositionsAndComputeUses( + tensorflow::gtl::ArraySlice positions); + // Return a unique identifier for this HloValue. This value is used for stable // sorting and iteration Id id() const { return id_; } @@ -143,28 +149,15 @@ class HloValue { // Return the shape of this HloValue. const Shape& shape() const { return defining_position().shape(); } - // Add or remove a position at which the HloValue appears. The definition - // position can not be removed. The uses of the HloValue are updated. - void AddPosition(HloInstruction* instruction, const ShapeIndex& index); - void RemovePosition(HloInstruction* instruction, const ShapeIndex& index); - - // Remove all positions except the defining position. Updates uses. - void ClearPositions(); - // Return all positions of the HloValue in the module. const std::vector& positions() const { return positions_; } // Return all uses of the HloValue. const std::vector& uses() const { return uses_; } - void RecomputeUses(); - // Get whether this HloValue is live out of the module. 
bool live_out_of_module() const { return live_out_of_module_; } - // Get whether this HloValue is live out of the computation it is defined in. - bool live_out_of_computation() const { return live_out_of_computation_; } - bool operator==(const HloValue& other) const; bool operator!=(const HloValue& other) const; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index c1aa655401a2be68af943e2ed29c4ab99d341383..b8fd7a89efd4d86630eed1f29db5b7b1b7876d23 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -59,21 +59,27 @@ class ShapeVerifier : public DfsHloVisitor { } Status HandleConvert(HloInstruction* convert) override { - if (ShapeUtil::ElementIsComplex(convert->operand(0)->shape())) { - TF_RET_CHECK(ShapeUtil::ElementIsComplex(convert->shape())) - << "Unsupported complex->real kConvert"; - } return CheckShape(convert, ShapeInference::InferConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); } + Status HandleBitcastConvert(HloInstruction* convert) override { + return CheckShape(convert, ShapeInference::InferBitcastConvertShape( + convert->operand(0)->shape(), + convert->shape().element_type())); + } + Status HandleCopy(HloInstruction* copy) override { return CheckUnaryShape(copy); } Status HandleDot(HloInstruction* dot) override { - return CheckBinaryShape(dot); + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferDotOpShape( + dot->operand(0)->shape(), dot->operand(1)->shape(), + dot->dot_dimension_numbers())); + return CheckShape(dot, expected); } Status HandleConvolution(HloInstruction* convolution) override { @@ -87,8 +93,12 @@ class ShapeVerifier : public DfsHloVisitor { } Status HandleCrossReplicaSum(HloInstruction* crs) override { - return CheckShape(crs, ShapeInference::InferCrossReplicaSumShape( - crs->operand(0)->shape())); + std::vector operand_shapes; + for (const HloInstruction* operand : 
crs->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape( + crs, ShapeInference::InferCrossReplicaSumShape(operand_shapes)); } Status HandleReducePrecision(HloInstruction* reduce_precision) override { @@ -141,9 +151,6 @@ class ShapeVerifier : public DfsHloVisitor { } Status HandleBitcast(HloInstruction* bitcast) override { - // Bitcasts can be any shape, as long as the size matches the operand size. - TF_RET_CHECK(shape_size_fn_(bitcast->shape()) == - shape_size_fn_(bitcast->operand(0)->shape())); return tensorflow::Status::OK(); } @@ -263,6 +270,15 @@ class ShapeVerifier : public DfsHloVisitor { xla_while->while_body()->ComputeProgramShape().result()); } + Status HandleConditional(HloInstruction* conditional) override { + TF_RETURN_IF_ERROR(CheckShape( + conditional, + conditional->true_computation()->ComputeProgramShape().result())); + return CheckShape( + conditional, + conditional->false_computation()->ComputeProgramShape().result()); + } + Status HandlePad(HloInstruction* pad) override { return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(), @@ -270,12 +286,40 @@ class ShapeVerifier : public DfsHloVisitor { pad->padding_config())); } - Status HandleSend(HloInstruction*) override { - return tensorflow::Status::OK(); + Status HandleSend(HloInstruction* send) override { + TF_RET_CHECK(send->users().size() == 1); + const HloInstruction* send_done = send->users().front(); + TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); + TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + return CheckShape( + send, ShapeUtil::MakeTupleShape( + {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); } - Status HandleRecv(HloInstruction*) override { - return tensorflow::Status::OK(); + Status HandleSendDone(HloInstruction* send_done) override { + TF_RET_CHECK(send_done->operands().size() == 1); + const HloInstruction* send = send_done->operand(0); + TF_RET_CHECK(send->opcode() == HloOpcode::kSend); + 
TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + return CheckShape(send_done, ShapeUtil::MakeNil()); + } + + Status HandleRecv(HloInstruction* recv) override { + TF_RET_CHECK(recv->users().size() == 1); + const HloInstruction* recv_done = recv->users().front(); + TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); + return CheckShape(recv, + ShapeUtil::MakeTupleShape( + {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); + } + + Status HandleRecvDone(HloInstruction* recv_done) override { + TF_RET_CHECK(recv_done->operands().size() == 1); + const HloInstruction* recv = recv_done->operand(0); + TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv); + TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); + return CheckShape(recv_done, recv->shape().tuple_shapes(0)); } Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override { @@ -365,6 +409,19 @@ class ShapeVerifier : public DfsHloVisitor { instruction->opcode(), instruction->operands())); } + // Checks if the given two instructions shares the same channel id. + Status CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2) { + if (instr1->channel_id() != instr2->channel_id()) { + return FailedPrecondition( + "Expected to have the same channel id, actual channel ids are: %s " + "(%lld), %s (%lld)", + instr1->ToString().c_str(), instr1->channel_id(), + instr2->ToString().c_str(), instr2->channel_id()); + } + return tensorflow::Status::OK(); + } + // Returns the size of a Shape in bytes. 
const std::function shape_size_fn_; }; @@ -530,7 +587,7 @@ StatusOr HloVerifier::Run(HloModule* module) { // or ComputationLowerer::Visit() TF_RET_CHECK(instruction->dimensions().size() == ShapeUtil::Rank(instruction->operand(0)->shape())) - << "Broadcast HLO has invalid number of dimensions."; + << "Broadcast HLO has invalid number of dimensions."; } else if (instruction->opcode() == HloOpcode::kWhile) { auto* while_cond = instruction->while_condition(); auto* while_body = instruction->while_body(); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 0d1b7bc109c56bc4290ede09284c6d20142bdb08..ba901b99e4f3c72c84c1ecdf4e19e58ad9ab6506 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -33,7 +33,9 @@ namespace xla { switch (instruction.opcode()) { // Cheap instructions. case HloOpcode::kAdd: + case HloOpcode::kAnd: case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: case HloOpcode::kBroadcast: case HloOpcode::kCeil: case HloOpcode::kClamp: @@ -53,15 +55,14 @@ namespace xla { case HloOpcode::kInfeed: case HloOpcode::kIsFinite: case HloOpcode::kLe: - case HloOpcode::kAnd: - case HloOpcode::kNot: - case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: case HloOpcode::kNe: case HloOpcode::kNegate: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kOutfeed: case HloOpcode::kPad: case HloOpcode::kReal: @@ -88,10 +89,11 @@ namespace xla { // Expensive instructions. 
case HloOpcode::kAtan2: - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: @@ -103,17 +105,19 @@ namespace xla { case HloOpcode::kMap: case HloOpcode::kParameter: case HloOpcode::kPower: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: case HloOpcode::kRemainder: case HloOpcode::kRng: case HloOpcode::kSelectAndScatter: + case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kSort: case HloOpcode::kTanh: case HloOpcode::kTrace: case HloOpcode::kWhile: - case HloOpcode::kSend: - case HloOpcode::kRecv: return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 6d5796a24b5209355debd80b912b7fa62d40837c..dc63a2224d659fa427d4d1a30c5dc0f94d643b36 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -69,13 +69,19 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { return pipeline.Run(hlo_module).status(); } -StatusOr> InterpreterCompiler::Compile( +StatusOr> InterpreterCompiler::RunHloPasses( + std::unique_ptr hlo_module, + se::StreamExecutor* /*stream_exec*/) { + VLOG(1) << "Run hlo passes on graph " << hlo_module->name(); + TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); + return std::move(hlo_module); +} + +StatusOr> InterpreterCompiler::RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); - VLOG(1) << "Generate graph " << hlo_module->name(); - - TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); + VLOG(1) << "Run backend " << hlo_module->name(); // 
Typically you would visit the HLO graph, building up a compiled equivalent // In this case we are using an HloEvaluator at execution time, so we don't diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index cfdc9b6256569b0137784b0d1db846a5f2339a5d..278cf5184227ae25518b1d46c0e16e4cce7bd1a8 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -43,8 +43,12 @@ class InterpreterCompiler : public Compiler { InterpreterCompiler() {} ~InterpreterCompiler() override {} - StatusOr> Compile( - std::unique_ptr hlo_modules, + StatusOr> RunHloPasses( + std::unique_ptr hlo_module, + perftools::gputools::StreamExecutor* stream_exec) override; + + StatusOr> RunBackend( + std::unique_ptr hlo_module, perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> Compile( diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 86dee8462fd4fdda580ada892e244f19177fb3e5..9183a1d1bfb8c2f6e1933c004f9c9f5f9ad8eced 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -42,7 +42,8 @@ namespace sep = ::perftools::gputools::interpreter; InterpreterExecutable::InterpreterExecutable( std::unique_ptr hlo_module) - : Executable(std::move(hlo_module)) {} + : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, + /*hlo_profile_index_map=*/nullptr) {} InterpreterExecutable::~InterpreterExecutable() {} @@ -89,7 +90,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - HloComputation* computation = module().entry_computation(); + const HloComputation* computation = module().entry_computation(); if (computation->num_parameters() != arguments.size()) { return tensorflow::errors::Internal( "Mismatch between 
argument count and graph parameter count."); @@ -156,10 +157,5 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream( return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } -std::unique_ptr InterpreterExecutable::CreateCostAnalysis() - const { - return MakeUnique(ShapeSizeBytes); -} - } // namespace interpreter } // namespace xla diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index c69b0d036d1058a6b24ee609a9923895d3246eec..0e87eb90bff4b896fc4bc0efc4fa7b851631be6f 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -61,8 +61,6 @@ class InterpreterExecutable : public Executable { static int64 ShapeSizeBytes(const Shape& shape); - std::unique_ptr CreateCostAnalysis() const override; - private: TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable); }; diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 0bb3259ef43915067e614e72038387e8300ecc41..f16651c969d4d982302a7c9ac9c4af066eb27521 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -85,7 +85,7 @@ bool InterpreterExecutor::HostCallback(Stream *stream, bool InterpreterExecutor::CreateStreamDependency(Stream *dependent, Stream *other) { AsExecutorStream(dependent)->EnqueueTask( - [other]() { other->BlockHostUntilDone(); }); + [other]() { SE_CHECK_OK(other->BlockHostUntilDoneWithStatus()); }); AsExecutorStream(dependent)->BlockUntilDone(); return true; } @@ -100,9 +100,9 @@ bool InterpreterExecutor::StopTimer(Stream *stream, Timer *timer) { return true; } -bool InterpreterExecutor::BlockHostUntilDone(Stream *stream) { +port::Status InterpreterExecutor::BlockHostUntilDoneWithStatus(Stream *stream) { AsExecutorStream(stream)->BlockUntilDone(); - return true; + return port::Status::OK(); } 
DeviceDescription *InterpreterExecutor::PopulateDeviceDescription() const { diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index c59b2ccb1505b78be0c459ac9311428d65cc7e44..d3753a6a65d64c3d77644367bbd82068d4cf3044 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -157,7 +157,7 @@ class InterpreterExecutor : public internal::StreamExecutorInterface { bool StartTimer(Stream *stream, Timer *timer) override; bool StopTimer(Stream *stream, Timer *timer) override; - bool BlockHostUntilDone(Stream *stream) override; + port::Status BlockHostUntilDoneWithStatus(Stream *stream) override; int PlatformDeviceCount() override { return 1; } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 7eda7c2284c2457703fcfcd4226172e41dd4ae01..328afe42bad64713013f761a6819ae8a47a52e04 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -1303,7 +1303,7 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); } - // Copy the root instrucion's result if the it does not match the result + // Copy the root instruction's result if the it does not match the result // layout constraint if (constraints.ResultLayout() != nullptr && !constraints.ResultLayout()->MatchesLayoutInShape( @@ -1328,6 +1328,20 @@ Status LayoutAssignment::RunOnComputation( << ")"; VLOG(2) << " ComputationLayout = " << computation_layout.ToString(); + // Clear existing layouts of the instructions. All layouts must be assigned by + // the LayoutAssignment pass, except for Infeed, Outfeed, Parameters and the + // computation result. 
The latter two are specified in computation_layout, so + // we only need to keep the existing layouts for Infeed and Outfeed. Clearing + // the layouts here avoids hiding potential bugs in the layout assignment pass + // that may accidently use the existing layout. + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kInfeed || + instruction->opcode() == HloOpcode::kOutfeed) { + continue; + } + LayoutUtil::ClearLayout(instruction->mutable_shape()); + } + // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(points_to_analysis, computation); diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index c39ff52230055ec322ecf77f8df8ebdea12cdb6c..d51c0d1dfb727801d6d2a8328eba60838373479f 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -131,10 +131,10 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); - auto constant_literal1 = test_utils::CreateR2LiteralWithLayout( - {{1.0, 2.0}, {3.0, 4.0}}, minor_to_major); - auto constant_literal2 = test_utils::CreateR2LiteralWithLayout( - {{5.0, 6.0}, {7.0, 8.0}}, minor_to_major); + auto constant_literal1 = Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); + auto constant_literal2 = Literal::CreateR2WithLayout( + {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); Shape ashape = constant_literal1->shape(); auto constant1 = builder.AddInstruction( @@ -181,12 +181,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { // Verify the layouts of a tuple are assigned properly (the element layouts // match their source). 
auto builder = HloComputation::Builder(TestName()); - auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - {0, 1}))); - auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - {1, 0}))); + auto constant0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); @@ -218,12 +218,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { TEST_F(LayoutAssignmentTest, TupleSelect) { // Verify layouts of a select with tuple operands is assigned properly. auto builder = HloComputation::Builder(TestName()); - auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - {0, 1}))); - auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - {1, 0}))); + auto constant0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple0 = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); auto tuple1 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index c27a8956a706febd1855854a2d0560754caf5c03..68c99256a246edcf43a8358f667fc4458b9b4fea 100644 
--- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -103,7 +103,7 @@ namespace { // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. // Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) -// where 'user' is a user of an alias of 'intruction' at 'index', and +// where 'user' is a user of an alias of 'instruction' at 'index', and // 'operand_index' is the operand index at which the alias appears in the // operand list of 'user'. std::vector> GetAllUsesOfInstructionAtIndex( @@ -215,7 +215,8 @@ bool CanShareOperandBufferWithUser( auto add_operand_it = std::find_if(add->operands().begin(), add->operands().end(), [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kDot || + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot || (operand->opcode() == HloOpcode::kFusion && operand->fusion_kind() == HloInstruction::FusionKind::kTransposeDot); @@ -242,6 +243,31 @@ bool CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kCall) { + // TODO(b/62548313): Remove when buffer assignment is module scoped and + // does not assign buffers to calls. + // Find called computation parameter associated with 'operand'. + const std::vector operand_indices = user->OperandIndices(operand); + if (operand_indices.size() > 1) { + return false; + } + CHECK_EQ(1, operand_indices.size()); + auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); + // Get all uses of 'operand' at 'index' in called computation. + auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index, + points_to_analysis); + + // Return true iff: + // *) There exists exactly one use of 'operand' in called computation. + // *) The unique use is by the root instruction of called computation. 
+ // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. + auto* callee_root = user->to_apply()->root_instruction(); + return param_uses.size() == 1 && param_uses[0].first == callee_root && + callee_root->IsElementwiseOnOperand(param_uses[0].second); + } // Check if 'user' is element-wise. return user->IsElementwise(); } @@ -294,7 +320,8 @@ bool CanShareOperandBufferWithUser(HloInstruction* operand, auto add_operand_it = std::find_if(add->operands().begin(), add->operands().end(), [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kDot || + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot || (operand->opcode() == HloOpcode::kFusion && operand->fusion_kind() == HloInstruction::FusionKind::kTransposeDot); @@ -320,6 +347,31 @@ bool CanShareOperandBufferWithUser(HloInstruction* operand, std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kCall) { + // Get all uses of value defined by 'operand' at 'operand_index'. + const auto& uses = + dataflow.GetValueDefinedAt(operand, operand_index).uses(); + // Return true iff: + // *) There exists two uses of 'operand'. + // *) One use is by 'user' (caller). + // *) One use is by root instruction of called computation (callee root). + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. 
+ const bool found_caller_use = + std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { + return use.instruction == user; + }) != uses.end(); + auto* callee_root = user->to_apply()->root_instruction(); + const bool found_elementwise_callee_use = + std::find_if( + uses.begin(), uses.end(), [callee_root](const HloUse& use) { + return use.instruction == callee_root && + callee_root->IsElementwiseOnOperand(use.operand_number); + }) != uses.end(); + return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; + } // Check if 'user' is element-wise. return user->IsElementwise(); } diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index b5e15906d3c085f773eb46b543515a614e63c59a..2c2a02f6375343d67dfb155bbb03729ff6e490d2 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -277,8 +277,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto b = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); @@ -312,8 +315,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { auto b_t = builder.AddInstruction( HloInstruction::CreateTranspose(data_shape, b, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b_t)); + HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); auto 
one = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); @@ -415,5 +421,44 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_)); } +// Tests that Call can alias operand buffer if the only use of the operand +// in the called computation is an elementwise instruction. +TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + // Build sub-computation with fusion root. + auto sub_builder = HloComputation::Builder(TestName() + "_sub"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "sub_param")); + auto one = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto ones = sub_builder.AddInstruction( + HloInstruction::CreateBroadcast(shape, one, {1})); + auto add = sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); + + module_ = CreateNewModule(); + auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); + sub_computation->CreateFusionInstruction({add, ones}, + HloInstruction::FusionKind::kLoop); + + // Build entry-computation with kCall which calls 'sub_computation'. 
+ auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto reverse = + builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(shape, {reverse}, sub_computation)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, + *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, + *dataflow_analysis_)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc new file mode 100644 index 0000000000000000000000000000000000000000..34f3419269abbc73cd0ddb13c723a8da38ab19ff --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_compiler.h" + +namespace xla { +StatusOr>> LLVMCompiler::Compile( + std::vector> modules, + std::vector> + stream_execs) { + std::vector> result; + for (size_t i = 0; i < modules.size(); i++) { + if (stream_execs[i].size() != 1) { + return Unimplemented( + "Model partitioning not implemented for the CPU/GPU compilers!"); + } + + TF_ASSIGN_OR_RETURN( + modules[i], RunHloPasses(std::move(modules[i]), stream_execs[i][0])); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + RunBackend(std::move(modules[i]), stream_execs[i][0])); + result.push_back(std::move(executable)); + } + + return {std::move(result)}; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index b2e72871c10192c84349b117797c7bd7e6ee251a..c5393cef4f961c5d04c32d0d4291732b8ec702f1 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -57,6 +57,21 @@ class LLVMCompiler : public Compiler { void RemovePostOptimizationHook() { user_post_optimization_hook_ = nullptr; } + // Bring in + // StatusOr> RunBackend( + // std::unique_ptr module, + // perftools::gputools::StreamExecutor* stream_exec) + // StatusOr> RunHloPasses( + // std::unique_ptr module, + // perftools::gputools::StreamExecutor* stream_exec) + using Compiler::RunBackend; + using Compiler::RunHloPasses; + + StatusOr>> Compile( + std::vector> modules, + std::vector> + stream_execs) override; + protected: ModuleHook user_pre_optimization_hook_; ModuleHook user_post_optimization_hook_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 075d4a1ab5e5f39394ade393d21525ca3e97136e..d878061f724de1c82f8285b0f082d0be4d5778df 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ 
b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -48,6 +48,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", @@ -155,6 +156,30 @@ cc_library( ], ) +cc_library( + name = "vector_support_library", + srcs = ["vector_support_library.cc"], + hdrs = ["vector_support_library.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "@llvm//:core", + ], +) + +cc_library( + name = "kernel_support_library", + srcs = ["kernel_support_library.cc"], + hdrs = ["kernel_support_library.h"], + deps = [ + ":llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index bdddc232ef74dfa37e2d5cc780b0fe11e7bc8e76..21bca1d6beff5b2804531724b94b123d4523c173 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -83,7 +83,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, if (std::find(parameter_instructions.begin(), parameter_instructions.end(), &hlo) != parameter_instructions.end()) { - array->AddInvariantLoad(llvm::MDNode::get(*context_, /*MDs=*/{})); + array->MarkInvariantOverWholeProgram(context_); } } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index e3f98ac13e76f0df465066422ca7918a0f218b60..7224bd689842d89563b374f3db3d4e314be18764 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ 
b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -256,10 +256,10 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata( llvm::Instruction* instruction) const { CHECK(llvm::isa(instruction) || llvm::isa(instruction)); + CHECK(!llvm::isa(instruction) || !is_invariant_) + << "Trying to create a store to an invariant IRArray."; for (const auto& kind_md_pair : metadata_) { - CHECK(kind_md_pair.first != llvm::LLVMContext::MD_invariant_load || - llvm::isa(instruction)); instruction->setMetadata(kind_md_pair.first, kind_md_pair.second); } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 1ed7e99a829f5b0daa709913554d2300503ca33e..387d4629125cbb791840e943013188d14159908a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -229,9 +229,33 @@ class IrArray { AddMetadata(llvm::LLVMContext::MD_noalias, noalias); } - void AddInvariantLoad(llvm::MDNode* invariant_load) { - CHECK_NE(invariant_load, nullptr); - AddMetadata(llvm::LLVMContext::MD_invariant_load, invariant_load); + // Promises LLVM that the data pointed to by this IrArray never changes after + // it's first loaded. + // + // The temporal scope of this promise is the "whole program" from LLVM's point + // of view, but how this translates to HLOs differs between backends. + // + // In the single-threaded CPU backend, we emit one function that + // runs all the HLOs in sequence, so the whole program is the whole HLO + // module. + // + // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO + // in the entry computation). From LLVM's perspective, launching a new kernel + // is like launching a new program, and so the whole program is one top-level + // HLO. Since the scope of the promise is smaller than in the CPU backend, we + // can mark more things as invariant in the GPU backend. 
+ // + // Marking loads as invariant is particularly helpful on GPUs because + // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's + // __ldg intrinsic). These loads use a special cache, and can be + // significantly faster than regular loads. + void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) { + if (is_invariant_) { + return; + } + is_invariant_ = true; + AddMetadata(llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(*context, {})); } const std::map& metadata() const { return metadata_; } @@ -261,6 +285,8 @@ class IrArray { // loads/stores for this array. They keys are the metadata kinds and the // values are the metadata nodes. std::map metadata_; + + bool is_invariant_ = false; }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f6f9810c3af712912a972ce5c34b7058cb6c675 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +void KernelSupportLibrary::For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator) { + If(ir_builder_->CreateICmpSLT(start, end), [&]() { + for_body_generator(start, /*is_first_iteration=*/true); + For(name, ir_builder_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { for_body_generator(iv, false); }); + }); +} + +void KernelSupportLibrary::For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& for_body_generator) { + if (peel_first_iteration) { + For(name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) { + for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration)); + }); + } else { + std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( + name, start, end, step, ir_builder_, + /*prevent_unrolling=*/prevent_unrolling_, + /*prevent_vectorization=*/prevent_vectorization_); + ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); + for_body_generator(loop->GetIndVarValue(), + /*is_first_iteration=*/ir_builder_->CreateICmpEQ( + loop->GetIndVarValue(), start)); + llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_); + } +} + +void KernelSupportLibrary::If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator) { + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(condition, "", ir_builder_); + ir_builder_->SetInsertPoint(&if_data.true_block->back()); + true_block_generator(); + ir_builder_->SetInsertPoint(&if_data.false_block->back()); + 
false_block_generator(); + llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); +} + +void KernelSupportLibrary::EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + KernelSupportLibrary::ArgumentVector arguments, + const std::function& + kernel_body_generator) { + llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + llvm::Function* function = + module->getFunction(llvm_ir::AsStringRef(kernel_name)); + + int64 null_arg_idx = -1; + std::vector sanitized_args; + sanitized_args.reserve(arguments.size()); + for (int64 i = 0, e = arguments.size(); i < e; i++) { + if (arguments[i]) { + sanitized_args.push_back(arguments[i]); + } else { + CHECK_EQ(null_arg_idx, -1); + null_arg_idx = i; + } + } + + if (!function) { + VLOG(2) << "Generating kernel for " << kernel_name; + std::vector arg_types; + std::transform(sanitized_args.begin(), sanitized_args.end(), + std::back_inserter(arg_types), + [](llvm::Value* arg) { return arg->getType(); }); + + auto* function_type = llvm::FunctionType::get( + ir_builder->getVoidTy(), arg_types, /*isVarArg=*/false); + + function = llvm_ir::CreateFunction( + function_type, llvm::GlobalValue::InternalLinkage, + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, kernel_name, module); + + llvm::IRBuilder<>::InsertPointGuard guard(*ir_builder); + + auto* entry_bb = + llvm::BasicBlock::Create(ir_builder->getContext(), "entry", function); + auto* return_inst = llvm::ReturnInst::Create(ir_builder->getContext(), + /*retVal=*/nullptr, entry_bb); + // Set the insert point to before return_inst. 
+ ir_builder->SetInsertPoint(return_inst); + + std::vector arg_values; + std::transform(function->arg_begin(), function->arg_end(), + std::back_inserter(arg_values), std::addressof); + if (null_arg_idx != -1) { + arg_values.insert(arg_values.begin() + null_arg_idx, nullptr); + } + kernel_body_generator(arg_values); + } else { + VLOG(3) << "Re-using kernel for " << kernel_name; + } + + ir_builder->CreateCall(function, llvm_ir::AsArrayRef(sanitized_args)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h new file mode 100644 index 0000000000000000000000000000000000000000..827e092a3fa9116c461716b27c309033f7988745 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -0,0 +1,182 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ + +#include + +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { +// A thin wrapper around llvm_loop.h to make code generating structured control +// flow more readable. +class KernelSupportLibrary { + public: + // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR. + // If `prevent_unrolling` is true then unrolling is explicitly disabled on + // every loop generated by this instance of KernelSupportLibrary. + explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder, + bool prevent_unrolling = true, + bool prevent_vectorization = true) + : ir_builder_(ir_builder), + prevent_unrolling_(prevent_unrolling), + prevent_vectorization_(prevent_vectorization) {} + + // Generates the following control flow structure: + // + // if (`start` < `end`) { + // `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`; + // for (i64 i = `start` + `step`; i s< `end`; i += `step`) + // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; + // } + void For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& + for_body_generator); + + void For( + tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& + for_body_generator) { + For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } + + // Generates the following control flow structure if `peel_first_iteration` is + // true: + // + // if (`start` < `end`) { + // 
`for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/,true)`; + // for (i64 i = `start` + `step`; i s< `end`; i += `step`) + // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/,false)`; + // } + // + // and the following if `peel_first_iteration` is false: + // + // for (i64 i = `start`; i s< `end`; i += `step`) + // `for_body_generator(/*ind_var=*/,i, + // /*is_first_iteration=*/,(i == `start`))`; + void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); + + void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + For(name, /*start=*/start, /*end=*/end, + /*step=*/ir_builder_->getInt64(step), peel_first_iteration, + for_body_generator); + } + + void For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator) { + For(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + } + + void For( + tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } + + // Generates the following control flow structure: + // + // if (`condition`) + // `true_block_generator()`; + // else + // `false_block_generator()`; + void If(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() {}); + + using ArgumentVector = tensorflow::gtl::ArraySlice; + + // Generates the following control flow structure: + // + // define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) { + // kernel_body_generator({arg0, arg1, ... 
arg`arguments.size()`}); + // } + // + // ... + // call @`kernel_name`(arguments[0], arguments[1] ...) + // ... + // + // If a function called `kernel_name` is already present in the module then + // that function is re-used. In that sense we're using the llvm::Module as a + // cache of outlined kernels, keyed by function name. + // + // If any of the values in `arguments` is nullptr (i.e. a nullptr + // llvm::Value*) then we ignore it when generating LLVM IR, and instead pass + // in a nullptr llvm::Value* in its position to `kernel_body_generator`. + // Currently we only support at most one nullptr value in `arguments`. + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + ArgumentVector arguments, + const std::function& kernel_body_generator); + + // Thin wrappers around the more general EmitAndCallOutlinedKernel above. + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, + const std::function& + kernel_body_generator) { + EmitAndCallOutlinedKernel( + enable_fast_math, optimize_for_size, ir_builder, kernel_name, + {arg0, arg1, arg2}, [&](ArgumentVector args) { + kernel_body_generator(args[0], args[1], args[2]); + }); + } + + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, + llvm::Value* arg3, + const std::function& kernel_body_generator) { + EmitAndCallOutlinedKernel( + enable_fast_math, optimize_for_size, ir_builder, kernel_name, + {arg0, arg1, arg2, arg3}, [&](ArgumentVector args) { + kernel_body_generator(args[0], args[1], args[2], args[3]); + }); + } + + private: + llvm::IRBuilder<>* ir_builder_; + bool prevent_unrolling_; 
+ bool prevent_vectorization_; +}; +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 83d35cb9efca0c27765045ce214e0e1060b18ed0..7b227ce294176cfbbf7308bbf65afe21814f3dea 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -34,21 +34,24 @@ namespace llvm_ir { ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - llvm::Value* step, bool prevent_unrolling) + llvm::Value* step, bool prevent_unrolling, + bool prevent_vectorization) : prefix_(prefix.ToString()), suffix_(suffix.ToString()), start_index_(start_index), end_index_(end_index), step_(step), insert_before_bb_(nullptr), - prevent_unrolling_(prevent_unrolling) {} + prevent_unrolling_(prevent_unrolling), + prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr ForLoop::EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling) { - std::unique_ptr loop(new ForLoop( - prefix, /*suffix=*/"", start_index, end_index, step, prevent_unrolling)); + bool prevent_unrolling, bool prevent_vectorization) { + std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, + end_index, step, prevent_unrolling, + prevent_vectorization)); loop->Emit(ir_builder); return loop; } @@ -127,14 +130,12 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { ir_builder->CreateStore(indvar_inc, indvar_address); llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_); - if (prevent_unrolling_) { - const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; - llvm::LLVMContext* ctx = &back_branch->getContext(); - + std::vector loop_metadata = 
GetLoopMetadata(ir_builder); + if (!loop_metadata.empty()) { + llvm::LLVMContext* ctx = &start_index_->getContext(); auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None); - auto no_unroll_node = llvm::MDNode::get( - *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}); - auto loop_id = llvm::MDNode::get(*ctx, {temp_node.get(), no_unroll_node}); + loop_metadata.insert(loop_metadata.begin(), temp_node.get()); + auto loop_id = llvm::MDNode::get(*ctx, loop_metadata); loop_id->replaceOperandWith(0, loop_id); back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id); } @@ -143,6 +144,27 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { ir_builder->SetInsertPoint(exit_bb_); } +std::vector ForLoop::GetLoopMetadata( + llvm::IRBuilder<>* ir_builder) { + const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; + const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable"; + llvm::LLVMContext* ctx = &start_index_->getContext(); + + std::vector result; + if (prevent_unrolling_) { + result.push_back(llvm::MDNode::get( + *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)})); + } + + if (prevent_vectorization_) { + result.push_back(llvm::MDNode::get( + *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName), + llvm::ConstantAsMetadata::get(ir_builder->getFalse())})); + } + + return result; +} + string ForLoop::GetQualifiedName(tensorflow::StringPiece name) { return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_)); } @@ -156,23 +178,25 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - bool prevent_unrolling) { + bool prevent_unrolling, + bool prevent_vectorization) { return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), - prevent_unrolling); + prevent_unrolling, prevent_vectorization); } std::unique_ptr 
ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling) { + bool prevent_unrolling, + bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } std::unique_ptr loop(new ForLoop( /*prefix=*/name_, suffix, start_index, end_index, stride, - prevent_unrolling)); + prevent_unrolling, prevent_vectorization)); loop->Emit(ir_builder_); if (outer_loop_preheader_bb_ == nullptr) { @@ -191,20 +215,24 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix, - bool prevent_unrolling) { + bool prevent_unrolling, + bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), prevent_unrolling); + ir_builder_->getInt64(end_index), prevent_unrolling, + prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling) { + bool prevent_unrolling, + bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), ir_builder_->getInt64(end_index), - ir_builder_->getInt64(stride), prevent_unrolling); + ir_builder_->getInt64(stride), prevent_unrolling, + prevent_vectorization); } IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 90f7c7df9e22d6404e9fdad2ce210506583bd427..20069ce5a28184a5a9216d1a3751d1cee547727d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -71,12 +71,10 @@ class 
ForLoop { // // If `prevent_unrolling` is true then emit metadata that directs LLVM to not // unroll the generated loop. - static std::unique_ptr EmitForLoop(tensorflow::StringPiece prefix, - llvm::Value* start_index, - llvm::Value* end_index, - llvm::Value* step, - llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = false); + static std::unique_ptr EmitForLoop( + tensorflow::StringPiece prefix, llvm::Value* start_index, + llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, + bool prevent_unrolling = false, bool prevent_vectorization = false); // The names of the blocks follow LLVM's conventions. Control flow amongst the // blocks for the example C code looks like: @@ -130,7 +128,7 @@ class ForLoop { ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, - bool prevent_unrolling); + bool prevent_unrolling, bool prevent_vectorization); // Emit the loop at the insert point of the builder. void Emit(llvm::IRBuilder<>* ir_builder); @@ -142,6 +140,10 @@ class ForLoop { // they are set. string GetQualifiedName(tensorflow::StringPiece name); + // Return a list of metadata nodes that should be associated with the + // llvm::Loop for this `ForLoop`. + std::vector GetLoopMetadata(llvm::IRBuilder<>* ir_builder); + string prefix_; string suffix_; llvm::Value* start_index_; @@ -160,6 +162,7 @@ class ForLoop { llvm::BasicBlock* exit_bb_; llvm::Value* indvar_; bool prevent_unrolling_; + bool prevent_vectorization_; TF_DISALLOW_COPY_AND_ASSIGN(ForLoop); }; @@ -185,24 +188,28 @@ class ForLoopNest { std::unique_ptr AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling = false); + bool prevent_unrolling = false, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. 
std::unique_ptr AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - bool prevent_unrolling = false); + bool prevent_unrolling = false, + bool prevent_vectorization = false); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. std::unique_ptr AddLoop(int64 start_index, int64 end_index, int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling = false); + bool prevent_unrolling = false, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix, - bool prevent_unrolling = false); + bool prevent_unrolling = false, + bool prevent_vectorization = false); // Add loops to iterate through the indices within the specified // shape. The returned index collects the induction variables of the diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 956c0d5f05288e32c626f247ce8356c60d17808d..9a0c94b1c73c48682c1e868d4518b3797b01bbed 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/Target/TargetOptions.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -141,6 +142,13 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, return llvm::Type::getInt8Ty(module->getContext()); case S16: case U16: + case BF16: + // For BF16 we just need some type that is 16 bits wide so that it will + // take up the right amount of space in memory. 
LLVM does not have a BF16 + // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so + // we can't map it directly to an LLVM type. We will not map a BF16 + // addition to an addition on this type (int16) - this is just the type + // used for storage. return llvm::Type::getInt16Ty(module->getContext()); case S32: case U32: @@ -163,8 +171,9 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, // z, and reinterpret_cast(z)[1] shall designate the // imaginary part of z. return llvm::StructType::create( - "complex64", llvm::Type::getFloatTy(module->getContext()), - llvm::Type::getFloatTy(module->getContext())); + {llvm::Type::getFloatTy(module->getContext()), + llvm::Type::getFloatTy(module->getContext())}, + "complex64", /*isPacked=*/true); } return cplx_t; } @@ -178,6 +187,21 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, } } +int GetSizeInBits(llvm::Type* type) { + const llvm::StructType* struct_ty = llvm::dyn_cast(type); + if (struct_ty) { + CHECK(struct_ty->isPacked()); + int bits = 0; + for (auto element_type : struct_ty->elements()) { + bits += GetSizeInBits(element_type); + } + return bits; + } + int bits = type->getPrimitiveSizeInBits(); + CHECK_GT(bits, 0) << "type is not sized"; + return bits; +} + llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module); if (ShapeUtil::IsTuple(shape)) { @@ -263,6 +287,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, value = llvm::ConstantFP::get(ir_element_type, literal.Get(*multi_index)); break; + case BF16: + value = llvm::ConstantInt::get( + ir_element_type, + tensorflow::bit_cast(literal.Get(*multi_index))); + break; case F64: value = llvm::ConstantFP::get(ir_element_type, literal.Get(*multi_index)); @@ -537,6 +566,14 @@ void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) { builder->SetInsertPoint(blk, 
blk->getFirstInsertionPt()); } +void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) { + if (llvm::Instruction* terminator = blk->getTerminator()) { + builder->SetInsertPoint(terminator); + } else { + builder->SetInsertPoint(blk); + } +} + llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor, llvm::IRBuilder<>* builder) { auto size = rotand->getType()->getPrimitiveSizeInBits(); @@ -620,14 +657,27 @@ std::map MergeMetadata( return result; } +static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-"); + + tensorflow::mutex_lock lock(mu); + return uniquer->GetUniqueName(prefix); +} + Status DumpIRToDirectory(const string& directory_name, const string& hlo_module_name, const llvm::Module& llvm_module, bool optimized) { - string safe_file_name_base = SanitizeFileName(hlo_module_name); + // We can end up compiling different modules with the same name when using + // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously + // dumped from the same process in such cases. + string unique_and_safe_file_name = GetProcessUniqueIrFileName( + tensorflow::strings::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", + optimized ? "with" : "no", "-opt")); + string ir_file_name = tensorflow::io::JoinPath( directory_name, - tensorflow::strings::StrCat("ir-", safe_file_name_base, "-", - optimized ? 
"with" : "no", "-opt.ll")); + tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll")); std::unique_ptr f; TF_RETURN_IF_ERROR( @@ -638,5 +688,32 @@ Status DumpIRToDirectory(const string& directory_name, return f->Close(); } +llvm::Function* CreateFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + bool enable_fast_math, bool optimize_for_size, + tensorflow::StringPiece name, + llvm::Module* module) { + llvm::Function* function = + llvm::Function::Create(function_type, linkage, AsStringRef(name), module); + function->setCallingConv(llvm::CallingConv::C); + function->addFnAttr("no-frame-pointer-elim", "false"); + + if (enable_fast_math) { + function->addFnAttr("unsafe-fp-math", "true"); + function->addFnAttr("no-infs-fp-math", "true"); + function->addFnAttr("no-nans-fp-math", "true"); + function->addFnAttr("no-signed-zeros-fp-math", "true"); + } + + // Add the optsize attribute to the function if optimizing for size. This + // controls internal behavior of some optimization passes (e.g. loop + // unrolling). + if (optimize_for_size) { + function->addFnAttr(llvm::Attribute::OptimizeForSize); + } + + return function; +} + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 304192b58e9331c2544f973bf65299111122aea8..6bdc6a01a2b487df3dd80a02e67f5bcf62dead31 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -129,6 +129,9 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, llvm::Module* module); +// Returns the type size in bits. If "type" is a struct, it must be packed. +int GetSizeInBits(llvm::Type* type); + // Returns the LLVM type which represents the given XLA shape. For example, // if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]]. 
llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module); @@ -243,6 +246,8 @@ llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper, void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder); +void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder); + // Create a bitwise rotation of `rotand` by `rotor`. llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor, llvm::IRBuilder<>* builder); @@ -276,6 +281,12 @@ Status DumpIRToDirectory(const string& directory_name, const string& hlo_module_name, const llvm::Module& llvm_module, bool optimized); +llvm::Function* CreateFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + bool enable_fast_math, bool optimize_for_size, + tensorflow::StringPiece name, + llvm::Module* module); + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/ops.h index 11e84d9cb5defbcb87a8f696d56c139686c960d8..f72f482e3128c61e53cc454e7da8b5795ba6f695 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.h @@ -40,11 +40,24 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, inline bool CanEmitFusedDynamicUpdateSliceInPlace( HloInstruction* fusion, const BufferAssignment& assignment) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); - return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop && - fusion->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice && - CanUpdateDynamicSliceInPlace(fusion->fused_expression_root(), - assignment); + HloInstruction* fused_root = fusion->fused_expression_root(); + if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice || + fusion->fusion_kind() != HloInstruction::FusionKind::kLoop) { + return false; + } + // Walk DynamicUpdateSlice operand(0) to fused parameter and get its + // associated operand. 
See if it shares an allocation with this operand. + HloInstruction* fusion_operand; + ShapeIndex index; + std::tie(fusion_operand, index) = + fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex(); + if (fusion_operand->opcode() != HloOpcode::kParameter) { + return false; + } + auto* operand = fusion->operand(fusion_operand->parameter_number()); + return assignment.HasAllocationAt(operand, index) && + assignment.HasAllocationAt(fusion, {}) && + assignment.SharesSliceAtIndex(fusion, {}, operand, index); } // Emits IR for running the given dynamic-update-slice op in-place -- that is, diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f6d8483da88ba4bf3f26961c0cbc8d855faa82c --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc @@ -0,0 +1,280 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h" + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, + int64 vector_size, + llvm::IRBuilder<>* ir_builder, + std::string name) + : vector_size_(vector_size), + primitive_type_(primitive_type), + ir_builder_(ir_builder), + name_(std::move(name)) { + scalar_type_ = llvm_ir::PrimitiveTypeToIrType( + primitive_type, ir_builder_->GetInsertBlock()->getModule()); + scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_); + vector_type_ = llvm::VectorType::get(scalar_type_, vector_size); + vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); +} + +llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { + CHECK(lhs->getType() == scalar_type() || lhs->getType() == vector_type()); + return MulInternal(lhs, rhs); +} + +llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs, + llvm::Value* rhs) { + if (scalar_type_->isFloatingPointTy()) { + return ir_builder()->CreateFMul(lhs, rhs, name()); + } else { + return ir_builder()->CreateMul(lhs, rhs, name()); + } +} + +llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) { + CHECK(lhs->getType() == scalar_type() || lhs->getType() == vector_type()); + return AddInternal(lhs, rhs); +} + +llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs, + llvm::Value* rhs) { + if (scalar_type_->isFloatingPointTy()) { + return ir_builder()->CreateFAdd(lhs, rhs, name()); + } else { + return ir_builder()->CreateAdd(lhs, rhs, name()); + } +} + +llvm::Value* VectorSupportLibrary::ComputeOffsetPointer( + llvm::Value* base_pointer, llvm::Value* offset_elements) { + if (base_pointer->getType() != scalar_pointer_type()) { + base_pointer = ir_builder()->CreateBitCast(base_pointer, + scalar_pointer_type(), 
name()); + } + return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements}, + name()); +} + +llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) { + if (pointer->getType() != vector_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name()); + } + return ir_builder()->CreateAlignedLoad( + pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); +} + +llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) { + if (pointer->getType() != scalar_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + } + return ir_builder()->CreateAlignedLoad( + pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); +} + +void VectorSupportLibrary::StoreVector(llvm::Value* value, + llvm::Value* pointer) { + if (pointer->getType() != vector_pointer_type()) { + pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type()); + } + ir_builder()->CreateAlignedStore( + value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); +} + +void VectorSupportLibrary::StoreScalar(llvm::Value* value, + llvm::Value* pointer) { + if (pointer->getType() != scalar_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + } + ir_builder()->CreateAlignedStore( + value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); +} + +llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) { + if (pointer->getType() != scalar_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + } + return ir_builder()->CreateVectorSplat( + vector_size(), ir_builder()->CreateLoad(pointer), name()); +} + +llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) { + llvm::SmallVector mask(vector_size(), nullptr); + for (unsigned i = vector_size(); i != 1; i >>= 1) { + // On every iteration, we shuffle half of the remaining lanes to the 
top + // half of shuffle, and add the old and the new vector. + + for (unsigned j = 0; j < vector_size(); ++j) { + if (j < (i / 2)) { + mask[j] = ir_builder()->getInt32(i / 2 + j); + } else { + mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty()); + } + } + + llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector( + vector, llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask), ""); + vector = Add(vector, half_remaining_lanes); + } + + return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0), + name()); +} + +llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs, + llvm::Value* rhs) { + CHECK_EQ(lhs->getType(), vector_type()); + CHECK_EQ(rhs->getType(), vector_type()); + CHECK_EQ(vector_size() % 2, 0); + + llvm::SmallVector mask_a, mask_b; + + // Adding the values shuffled using mask_a and mask_b gives us the + // AVX-style horizontal add we want. The masks work as documented + // in https://llvm.org/docs/LangRef.html#shufflevector-instruction + // + // Here are the masks for vector_size() == 8: + // + // index: |0 |1 |2 | 3 |4 |5 | 6 | 7 + // --------+--+--+--+---+--+--+---+--- + // mask_a: |0 |2 |8 |10 |4 |6 |12 |14 + // mask_b: |1 |3 |9 |11 |5 |7 |13 |15 + // + // So, as an example, the value at lane 3 of the result vector is + // the result of adding lane 10 and lane 11 in the combined lhs++rhs + // vector, which are the lanes 2 and 3 in the rhs vector. + for (int i = 0; i < vector_size(); i += 2) { + int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2); + mask_a.push_back(ir_builder()->getInt32(increment + i)); + mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + } + for (int i = 0; i < vector_size(); i += 2) { + int increment = i < vector_size() / 2 ? 
(vector_size() / 2) : vector_size(); + mask_a.push_back(ir_builder()->getInt32(increment + i)); + mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + } + + llvm::Value* shuffle_0 = ir_builder()->CreateShuffleVector( + lhs, rhs, llvm::ConstantVector::get(mask_a)); + llvm::Value* shuffle_1 = ir_builder()->CreateShuffleVector( + lhs, rhs, llvm::ConstantVector::get(mask_b)); + + return Add(shuffle_0, shuffle_1); +} + +llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) { + llvm::SmallVector mask; + for (int i = 0; i < vector_size() / 2; i++) { + mask.push_back(ir_builder()->getInt32(i)); + } + + return ir_builder()->CreateShuffleVector(vector, + llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); +} + +llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) { + llvm::SmallVector mask; + for (int i = 0; i < vector_size() / 2; i++) { + mask.push_back(ir_builder()->getInt32(i + vector_size() / 2)); + } + + return ir_builder()->CreateShuffleVector(vector, + llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); +} + +std::vector VectorSupportLibrary::ComputeHorizontalSums( + std::vector vectors, llvm::Value* init_values) { + // TODO(sanjoy): Move this magic constant to TargetMachineFeatures. 
+ const int kAvxVectorWidth = 8; + if (vector_size() == kAvxVectorWidth && vectors.size() == kAvxVectorWidth) { + return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values); + } + + std::vector result; + std::transform(vectors.begin(), vectors.end(), std::back_inserter(result), + [this](llvm::Value* vector) { return AddReduce(vector); }); + if (init_values) { + for (int64 i = 0, e = result.size(); i < e; i++) { + result[i] = Add(result[i], ir_builder()->CreateExtractElement( + init_values, ir_builder()->getInt32(i))); + } + } + return result; +} + +std::vector +VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( + std::vector vectors, llvm::Value* init_values) { + while (vectors.size() != 2) { + std::vector new_vectors; + for (int i = 0; i < vectors.size(); i += 2) { + new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1])); + } + + vectors = std::move(new_vectors); + } + + llvm::Value* low = + AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0])); + if (init_values) { + low = AddInternal(ExtractLowHalf(init_values), low); + } + llvm::Value* high = + AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1])); + if (init_values) { + high = AddInternal(ExtractHighHalf(init_values), high); + } + + std::vector results; + for (int i = 0; i < 8; i++) { + llvm::Value* scalar_result = ir_builder()->CreateExtractElement( + i < 4 ? 
low : high, ir_builder()->getInt32(i % 4), name()); + results.push_back(scalar_result); + } + + return results; +} + +llvm::Value* VectorSupportLibrary::GetZeroVector() { + return llvm::Constant::getNullValue(vector_type()); +} + +llvm::Value* VectorSupportLibrary::GetZeroScalar() { + return llvm::Constant::getNullValue(scalar_type()); +} + +LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder) + : ir_builder_(ir_builder) { + alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_); +} + +llvm::Value* LlvmVariable::Get() const { + return ir_builder_->CreateLoad(alloca_); +} + +void LlvmVariable::Set(llvm::Value* new_value) { + ir_builder_->CreateStore(new_value, alloca_); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h new file mode 100644 index 0000000000000000000000000000000000000000..f404687ab6864bd0702d142ff691a394b78278a5 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h @@ -0,0 +1,206 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ + +#include + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +// A thin wrapper around llvm_util.h to make code generating vector math flow +// more readable. +class VectorSupportLibrary { + public: + // This VectorSupportLibrary instance remembers `primitive_type` and + // `vector_size`, and these are implicitly used by the methods on this + // instance (i.e. LoadVector will load a vector of type <`vector_size` x + // `primitive_type`>). + VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size, + llvm::IRBuilder<>* ir_builder, std::string name); + + llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Mul(int64 lhs, llvm::Value* rhs) { + return Mul(ir_builder()->getInt64(lhs), rhs); + } + + llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Add(int64 lhs, llvm::Value* rhs) { + return Add(ir_builder()->getInt64(lhs), rhs); + } + + llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) { + return Add(c, Mul(a, b)); + } + + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, + llvm::Value* offset_elements); + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, + int64 offset_elements) { + return ComputeOffsetPointer(base_pointer, + ir_builder()->getInt64(offset_elements)); + } + + llvm::Value* LoadVector(llvm::Value* pointer); + + llvm::Value* LoadVector(llvm::Value* base_pointer, + llvm::Value* offset_elements) { + return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements)); + } + + llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) { + return LoadVector(base_pointer, 
ir_builder()->getInt64(offset_elements)); + } + + llvm::Value* LoadScalar(llvm::Value* pointer); + + llvm::Value* LoadScalar(llvm::Value* base_pointer, + llvm::Value* offset_elements) { + return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements)); + } + + llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) { + return LoadScalar(base_pointer, ir_builder()->getInt64(offset_elements)); + } + + void StoreVector(llvm::Value* value, llvm::Value* pointer); + + void StoreVector(llvm::Value* value, llvm::Value* base_pointer, + llvm::Value* offset_elements) { + StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements)); + } + + void StoreVector(llvm::Value* value, llvm::Value* base_pointer, + int64 offset_elements) { + StoreVector(value, base_pointer, ir_builder()->getInt64(offset_elements)); + } + + void StoreScalar(llvm::Value* value, llvm::Value* pointer); + void StoreScalar(llvm::Value* value, llvm::Value* base_pointer, + llvm::Value* offset_elements) { + StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements)); + } + + void StoreScalar(llvm::Value* value, llvm::Value* base_pointer, + int64 offset_elements) { + StoreScalar(value, base_pointer, ir_builder()->getInt64(offset_elements)); + } + + llvm::Value* LoadBroadcast(llvm::Value* pointer); + llvm::Value* LoadBroadcast(llvm::Value* base_pointer, + llvm::Value* offset_elements) { + return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements)); + } + llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) { + return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements)); + } + + // Compute the horizontal sum of each vector in `vectors`. The i'th element + // in the result vector is the (scalar) horizontal sum of the i'th vector in + // `vectors`. If `init_values` is not nullptr then the value in the i'th lane + // in `init_values` is added to the i'th horizontal sum. 
+ std::vector ComputeHorizontalSums( + std::vector vectors, llvm::Value* init_values = nullptr); + + llvm::Value* GetZeroVector(); + llvm::Value* GetZeroScalar(); + + llvm::IRBuilder<>* ir_builder() const { return ir_builder_; } + int64 vector_size() const { return vector_size_; } + llvm::Type* vector_type() const { return vector_type_; } + llvm::Type* vector_pointer_type() const { return vector_pointer_type_; } + llvm::Type* scalar_type() const { return scalar_type_; } + llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; } + + const std::string& name() const { return name_; } + + private: + llvm::Value* ExtractLowHalf(llvm::Value*); + llvm::Value* ExtractHighHalf(llvm::Value*); + + llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs); + + llvm::Value* AddReduce(llvm::Value* vector); + + // Perform an X86 AVX style horizontal add between `lhs` and `rhs`. The + // resulting IR for an 8-float wide vector is expected to lower to a single + // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in + // other cases. 
+ // + // For a vector width of 8, the result vector is computed as: + // Result[0] = Lhs[0] + Lhs[1] + // Result[1] = Lhs[2] + Lhs[3] + // Result[2] = Rhs[0] + Rhs[1] + // Result[3] = Rhs[2] + Rhs[3] + // Result[4] = Lhs[4] + Lhs[5] + // Result[5] = Lhs[6] + Lhs[7] + // Result[6] = Rhs[4] + Rhs[5] + // Result[7] = Rhs[6] + Rhs[7] + llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs); + + std::vector ComputeAvxOptimizedHorizontalSums( + std::vector vectors, llvm::Value* init_values); + + int64 vector_size_; + PrimitiveType primitive_type_; + llvm::IRBuilder<>* ir_builder_; + llvm::Type* vector_type_; + llvm::Type* vector_pointer_type_; + llvm::Type* scalar_type_; + llvm::Type* scalar_pointer_type_; + std::string name_; +}; + +// This wraps an alloca-backed stack variable which LLVM's SSA construction pass +// can later convert to a SSA value. +class LlvmVariable { + public: + LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder); + + llvm::Value* Get() const; + void Set(llvm::Value* new_value); + + private: + llvm::AllocaInst* alloca_; + llvm::IRBuilder<>* ir_builder_; +}; + +class VectorVariable : public LlvmVariable { + public: + VectorVariable(VectorSupportLibrary* vector_support, + llvm::Value* initial_value) + : LlvmVariable(vector_support->vector_type(), + vector_support->ir_builder()) { + Set(initial_value); + } +}; + +class ScalarVariable : public LlvmVariable { + public: + ScalarVariable(VectorSupportLibrary* vector_support, + llvm::Value* initial_value) + : LlvmVariable(vector_support->scalar_type(), + vector_support->ir_builder()) { + Set(initial_value); + } +}; +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index d4d35da9d636e6e204f36850e7987327ab258696..06f43bd3cb2376d34a3104133c868c4f4e5cc730 100644 --- a/tensorflow/compiler/xla/service/local_service.cc 
+++ b/tensorflow/compiler/xla/service/local_service.cc @@ -68,26 +68,6 @@ LocalService::LocalService(const ServiceOptions& options, std::unique_ptr execute_backend) : Service(options, std::move(execute_backend)) {} -namespace { -// Returns the space required to allocate a shape. If -// allocate_space_for_deep_copy the space includes all sub-buffers of -// a tuple. -int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy, - TransferManager* transfer_manager) { - int64 size = 0; - // TODO(b/33492279) remove once no devices represent result tuples as - // contiguous buffers. - if (allocate_space_for_deep_copy) { - ShapeUtil::ForEachSubshape( - shape, [&size, transfer_manager](const Shape& subshape, - const ShapeIndex& /*index*/) { - size += transfer_manager->GetByteSizeRequirement(subshape); - }); - } - return size; -} -} // namespace - StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index b92017c6cbc43d78ab4e5b32f25f5980b8d4ae56..6aca6ba38572c5311797fbb91acbbcd6610a3410 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -23,6 +23,23 @@ limitations under the License. namespace xla { +namespace { + +// Gather fusion instructions from 'instruction' into 'fusion_instructions'. 
+void GatherFusionInstructions( + HloInstruction* instruction, + std::vector* fusion_instructions) { + CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); + for (auto* fused : instruction->fused_instructions()) { + if (fused->opcode() == HloOpcode::kFusion) { + GatherFusionInstructions(fused, fusion_instructions); + } + } + fusion_instructions->push_back(instruction); +} + +} // namespace + /* static */ StatusOr> LogicalBufferAnalysis::Run(const HloModule* module) { std::unique_ptr analysis( @@ -41,15 +58,19 @@ Status LogicalBufferAnalysis::Analyze() { // We filter out fusion computations, and get to them through fusion // instructions. This is because it's possible to have orphaned (unreachable) // fusion computations, and we don't want to try to assign buffers to those. + std::vector fusion_instructions; for (auto* computation : module_->MakeNonfusionComputations()) { TF_RETURN_IF_ERROR(computation->Accept(this)); for (auto* instruction : computation->instructions()) { if (instruction->opcode() != HloOpcode::kFusion) { continue; } - TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + GatherFusionInstructions(instruction, &fusion_instructions); } } + for (auto* instruction : fusion_instructions) { + TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + } return Status::OK(); } @@ -104,6 +125,21 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) { return Status::OK(); } +Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) { + // RecvDone doesn't create a new buffer but rather aliases its input (Recv) + // tuple element at {0} to its output. + return Status::OK(); +} + +Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) { + // Send creates new buffers for the top-level tuple and the context (tuple + // element at {1}). Tuple element at {0} is an alias of the Send operand, so + // we don't need to create a new Logical Buffer for that. 
+ NewLogicalBuffer(send, /*index=*/{}); + NewLogicalBuffer(send, /*index=*/{1}); + return Status::OK(); +} + Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) { // A Tuple instruction only creates the top-level buffer. NewLogicalBuffer(tuple, /*index=*/{}); diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index a82e83ec5c3d2b0e011d85f3d03bea8fca870154..598d08b7203b25b194dfc3b3125ec58c96b2cd4c 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -60,6 +60,8 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleCopy(HloInstruction* copy) override; + Status HandleRecvDone(HloInstruction* recv_done) override; + Status HandleSend(HloInstruction* send) override; Status HandleSelect(HloInstruction* select) override; // A map from the buffer ID to the logical buffer diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index a0d08c288dbcc45e83a36ce7b094b04a9dbae532..7d8c05fffa4ab11d7dbf9956d2cb7ebd5bcdd3c4 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -17,12 +17,44 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +namespace { + +bool IsAllowed(char character) { + auto c = static_cast(character); + return (isalnum(c) != 0) || c == '_' || c == '.' 
|| c == '-'; +} + +} // namespace + +NameUniquer::NameUniquer(const string& separator) { + CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed)) + << "separator should comprises allowed characters only"; + separator_ = separator; +} + +/*static*/ string NameUniquer::GetSanitizedName(const string& name) { + string result = name; + CHECK(!result.empty()) << "name should not be empty"; + char c = static_cast(result[0]); + if (!isalpha(c) && c != '_') { + result[0] = '_'; + } + for (int i = 1; i < result.length(); i++) { + if (!IsAllowed(result[i])) { + result[i] = '_'; + } + } + return result; +} + string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { string root = prefix.empty() ? "name" : prefix.ToString(); + root = GetSanitizedName(root); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index ed379b52258463b960dea788721c2c4325ef0260..4139c2700b25e8600182a034a8ac6f4f041c12e6 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -28,14 +28,21 @@ namespace xla { // Simple stateful class that helps generate "unique" names. To use it, simply // call GetUniqueName as many times as needed. The names returned by // GetUniqueName are guaranteed to be distinct for this instance of the class. +// Note that the names will be sanitized to match regexp +// "[a-zA-Z_][a-zA-Z0-9_.-]*". class NameUniquer { public: - explicit NameUniquer(const string& separator = "__") - : separator_(separator) {} + // The separator must contain allowed characters only: "[a-zA-Z0-9_.-]". + explicit NameUniquer(const string& separator = "__"); - // Get a unique name in a string, with an optional prefix for convenience. + // Get a sanitized unique name in a string, with an optional prefix for + // convenience. 
string GetUniqueName(tensorflow::StringPiece prefix = ""); + // Sanitizes and returns the name. Unallowed characters will be replaced with + // '_'. The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + static string GetSanitizedName(const string& name); + private: // The string to use to separate the prefix of the name from the uniquing // integer value. diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 9f0747a6e2175a968d8f3661ac51512009e86f29..4258cf16876ab46dce6df062ab701b1b1a4a7580 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -60,12 +60,30 @@ TEST_F(NameUniquerTest, NumericSuffixes) { EXPECT_EQ("bar", uniquer.GetUniqueName("bar.-1000")); EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); +} + +TEST_F(NameUniquerTest, Sanitize) { + NameUniquer uniquer("_"); + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo_1", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); + EXPECT_EQ("foo_54", uniquer.GetUniqueName("foo_54")); + EXPECT_EQ("foo_54.1", uniquer.GetUniqueName("foo_54.1")); + EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo")); + + // Invalid characters will be replaced with '_'. + EXPECT_EQ("bar", uniquer.GetUniqueName("bar<-1000")); + EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000")); + EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1")); // Separator is only recognized in the middle of the prefix. - EXPECT_EQ(".10", uniquer.GetUniqueName(".10")); - EXPECT_EQ(".10.1", uniquer.GetUniqueName(".10")); - EXPECT_EQ("foobar.", uniquer.GetUniqueName("foobar.")); - EXPECT_EQ("foobar..1", uniquer.GetUniqueName("foobar.")); + EXPECT_EQ("_10", uniquer.GetUniqueName( + ".10")); // the leading '.' is replaced with '_'. 
+ EXPECT_EQ("_10_1", uniquer.GetUniqueName(".10")); + EXPECT_EQ("_10_2", uniquer.GetUniqueName("_10")); + EXPECT_EQ("foobar_", uniquer.GetUniqueName("foobar_")); + EXPECT_EQ("foobar__1", uniquer.GetUniqueName("foobar_")); } } // namespace diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 3a1818de82d3fd305e2c6b3bd1f2cf8125806a75..aa974ee61a27de9c19e97d8a6eb48f9261ce4bd9 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -33,10 +33,32 @@ namespace se = ::perftools::gputools; namespace xla { +using tensorflow::str_util::Lowercase; + // Minimum supported CUDA compute capability is 3.5. constexpr int kMinCudaComputeCapabilityMajor = 3; constexpr int kMinCudaComputeCapabilityMinor = 5; +// The name of the interpreter platform. +constexpr char kInterpreter[] = "interpreter"; + +namespace { + +string CanonicalPlatformName(const string& name) { + string platform_str = Lowercase(name); + // "cpu" and "host" mean the same thing. + if (platform_str == "cpu") { + platform_str = "host"; + } + // "gpu" and "cuda" mean the same thing. + if (platform_str == "gpu") { + platform_str = "cuda"; + } + return platform_str; +} + +} // namespace + /* static */ StatusOr> PlatformUtil::GetSupportedPlatforms() { se::MultiPlatformManager::PlatformMap platform_map; @@ -78,7 +100,7 @@ PlatformUtil::GetSupportedPlatforms() { return platforms; } -/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { +/* static */ StatusOr PlatformUtil::GetSolePlatform() { TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); if (platforms.empty()) { return NotFound("no platforms found"); @@ -87,13 +109,77 @@ PlatformUtil::GetSupportedPlatforms() { } // Multiple platforms present and we can't pick a reasonable default. 
- auto l = [](string* out, const se::Platform* p) { out->append(p->Name()); }; - string platforms_string = tensorflow::str_util::Join(platforms, ", ", l); + string platforms_string = tensorflow::str_util::Join( + platforms, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform found: %s", platforms_string.c_str()); } +/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { + TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); + if (platforms.empty()) { + return NotFound("no platforms found"); + } else if (platforms.size() == 1) { + return platforms[0]; + } else if (platforms.size() == 2) { + for (int i = 0; i < 2; i++) { + if (Lowercase(platforms[i]->Name()) == kInterpreter && + Lowercase(platforms[1 - i]->Name()) != kInterpreter) { + return platforms[1 - i]; + } + } + } + + // Multiple platforms present and we can't pick a reasonable default. + string platforms_string = tensorflow::str_util::Join( + platforms, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); + return InvalidArgument( + "must specify platform because more than one platform (except for the " + "interpreter platform) found: %s", + platforms_string.c_str()); +} + +/*static*/ StatusOr PlatformUtil::GetPlatform( + const string& platform_name) { + string platform_str = CanonicalPlatformName(platform_name); + TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); + for (se::Platform* platform : platforms) { + if (Lowercase(platform->Name()) == platform_str) { + return platform; + } + } + return InvalidArgument("platform %s not found", platform_name.c_str()); +} + +/*static*/ StatusOr PlatformUtil::GetPlatformExceptFor( + const string& platform_name) { + string platform_str = CanonicalPlatformName(platform_name); + + TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); + std::vector matched; + for (se::Platform* platform : 
platforms) { + if (Lowercase(platform->Name()) != platform_name) { + matched.push_back(platform); + } + } + if (matched.empty()) { + return InvalidArgument("unable to find platform that is not %s", + platform_name.c_str()); + } + if (matched.size() == 1) { + return matched[0]; + } + string matched_string = tensorflow::str_util::Join( + matched, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); + return InvalidArgument( + "found multiple platforms %s, but expected one platform except for %s", + matched_string.c_str(), platform_name.c_str()); +} + // Returns whether the device underlying the given StreamExecutor is supported // by XLA. static bool IsDeviceSupported(se::StreamExecutor* executor) { diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index eac573703085aca2801885cd9abbe0022f1c029e..69188820a70707d9c9be10b20fb7de92ad4d9873 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -16,11 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ +#include #include #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" namespace xla { @@ -34,10 +37,27 @@ class PlatformUtil { static StatusOr> GetSupportedPlatforms(); - // Convenience function which returns the default supported platform. If + // Convenience function which returns the default supported platform for + // tests. If exactly one supported platform is present, then this platform is + // the default platform. If exactly two platforms are present and one of them + // is the interpreter platform, then the other platform is the default + // platform. Otherwise returns an error. 
+ static StatusOr GetDefaultPlatform(); + + // Convenience function which returns the sole supported platform. If // exactly one supported platform is present, then this platform is the // default platform. Otherwise returns an error. - static StatusOr GetDefaultPlatform(); + static StatusOr GetSolePlatform(); + + // Returns the platform according to the given name. Returns error if there is + // no such platform. + static StatusOr GetPlatform( + const string& platform_name); + + // Returns exactly one platform that does not have given name. Returns error + // if there is no such platform, or there are multiple such platforms. + static StatusOr GetPlatformExceptFor( + const string& platform_name); // Returns a vector of StreamExecutors for the given platform. The vector is // indexed by device ordinal (device numbering used by StreamExecutor). If an diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 71afbee456b0f5eb67cb092d84f8e95ea1038c54..ecc3c0ff12718592bd8e8847eb5ef806d2b60821 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -430,9 +430,12 @@ StatusOr> Service::BuildExecutable( /*include_unreachable_instructions=*/ true)); + TF_ASSIGN_OR_RETURN( + module, backend->compiler()->RunHloPasses(std::move(module), executor)); + TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - backend->compiler()->Compile(std::move(module), executor)); + backend->compiler()->RunBackend(std::move(module), executor)); if (!other_directory_path.empty()) { executable->set_session_module(std::move(session_module)); @@ -563,8 +566,10 @@ Service::ExecuteParallelAndRegisterResult( // Wait for all executions to complete. 
for (int64 i = 0; i < streams.size(); ++i) { - if (!streams[i]->BlockHostUntilDone()) { - return InternalError("failed to complete execution for stream %lld", i); + Status block_status = streams[i]->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError("failed to complete execution for stream %lld: %s", + i, block_status.error_message().c_str()); } } @@ -573,29 +578,15 @@ Service::ExecuteParallelAndRegisterResult( for (auto& index_to_profiled_stream : index_to_profiled_streams) { int64 device = index_to_profiled_stream.first; se::Stream* stream = index_to_profiled_stream.second; - HloExecutionProfile hlo_profile; - TF_RETURN_IF_ERROR(executables[device]->PopulateExecutionProfile( - &hlo_profile, stream->parent())); - - std::unordered_set profiled_computations = - hlo_profile.profiled_computations(); - // To ensure we have print the profiles in a stable order, iterate over the - // computations in post order. - auto& module = executables[device]->module(); - std::list all_computations = - module.MakeComputationPostOrder(); - for (xla::HloComputation* computation : all_computations) { - if (profiled_computations.count(computation) > 0) { - string profile_string = hlo_profile.ToString( - *computation, streams[0]->parent()->GetDeviceDescription(), - executables[device]->CreateCostAnalysis().get()); - if (!profile_string.empty()) { - LOG(INFO) << "HLO profile for execution on device " << device - << ":\n"; - XLA_LOG_LINES(tensorflow::INFO, profile_string); - } - } - } + Executable* executable = executables[device]; + const HloModule& module = executable->module(); + HloExecutionProfile hlo_profile(&executable->hlo_profile_printer(), + &executable->hlo_profile_index_map()); + TF_RETURN_IF_ERROR( + executable->PopulateExecutionProfile(&hlo_profile, stream->parent())); + XLA_LOG_LINES( + tensorflow::INFO, + hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription())); hlo_graph_dumper::MaybeDumpHloModule(module, "Service::Execute", &hlo_profile); 
} @@ -677,6 +668,7 @@ StatusOr Service::ExecuteAndRegisterResult( result, executable->ExecuteOnStreamWrapper( &run_options[0], profile, arguments)); } else { + // TODO(b/69985541): Support profiling also on this path. std::vector< tensorflow::gtl::ArraySlice> repeated_arguments(options_.number_of_replicas(), arguments); @@ -1053,18 +1045,29 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, return tensorflow::Status::OK(); } +namespace { + +// Creates a clone of the given shaped buffer with the given device ordinal. The +// shape and DeviceMemoryBase values of the clone are identical to the original. +std::unique_ptr CloneShapedBufferOnDevice( + const ShapedBuffer& shaped_buffer, int device_ordinal) { + auto clone = MakeUnique( + shaped_buffer.shape(), shaped_buffer.platform(), device_ordinal); + ShapeUtil::ForEachSubshape( + shaped_buffer.shape(), [&clone, &shaped_buffer](const Shape& /*subshape*/, + const ShapeIndex& index) { + clone->AddBufferAtIndex(shaped_buffer.buffer(index), index); + }); + return clone; +} + +} // namespace + tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { Literal literal = Literal(arg->literal()); const Shape& shape = literal.shape(); - if (ShapeUtil::IsTuple(shape) && options_.number_of_replicas() > 1) { - // TODO(b/32990684): Tuple transfers to host end up allocating further - // buffers - implement that correctly. - return Unimplemented( - "Tuple transfers to the device not supported with replication."); - } - std::vector replicas; if (arg->has_device_handle()) { TF_ASSIGN_OR_RETURN(replicas, @@ -1074,24 +1077,45 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); } - // Allocate memory on the device, using the stream executor. 
The size of the - // allocation is obtained by examining the shape of the literal passed from - // the client. An allocation handle is returned in the response. - int64 allocation_size = - execute_backend_->transfer_manager()->GetByteSizeRequirement(shape); - - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation, - execute_backend_->memory_allocator()->Allocate( - replicas[0]->device_ordinal(), allocation_size)); - + // All memory allocation is done on the first replica. The allocations in all + // other replicas mirror the first's. + int master_device_ordinal = replicas[0]->device_ordinal(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr shaped_buffer, + ShapedBuffer::Allocate( + execute_backend_->transfer_manager()->HostShapeToDeviceShape(shape), + execute_backend_->memory_allocator(), master_device_ordinal, + [this](const Shape& shape) { + return execute_backend_->transfer_manager()->GetByteSizeRequirement( + shape); + })); + + // The allocation tracker only keeps track of the top-level buffer of the + // shape so pass in the buffer at shape index {}. + // TODO(b/37515654): Allocation tracker should hold a ShapedBuffer. *result->mutable_data() = allocation_tracker_.Register( - execute_backend_.get(), replicas[0]->device_ordinal(), allocation, shape, - StrCat("TransferToServer literal of size ", allocation_size)); + execute_backend_.get(), master_device_ordinal, + shaped_buffer->buffer(/*index=*/{}), shape, + StrCat("TransferToServer literal of shape ", + ShapeUtil::HumanString(shape))); + // Transfer the data to the replicas. for (se::StreamExecutor* executor : replicas) { - TF_RETURN_IF_ERROR( - execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, literal, &allocation)); + if (executor->device_ordinal() == master_device_ordinal) { + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, literal, *shaped_buffer)); + } else { + // The replica is not the master.
Create a cloned shaped buffer with + // the replica's device ordinal. This is required because + // TransferLiteralToDevice verifies that the device ordinal of the shaped + // buffer matches that of the executor. + std::unique_ptr clone = + CloneShapedBufferOnDevice(*shaped_buffer, executor->device_ordinal()); + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, literal, *clone)); + } } return tensorflow::Status::OK(); } @@ -1368,6 +1392,17 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddConcatenateInstruction(arg->concatenate_request()); break; + case OpRequest::kConditionalRequest: { + TF_ASSIGN_OR_RETURN(UserComputation * true_computation, + computation_tracker_.Resolve( + arg->conditional_request().true_computation())); + TF_ASSIGN_OR_RETURN(UserComputation * false_computation, + computation_tracker_.Resolve( + arg->conditional_request().false_computation())); + handle_status = computation->AddConditionalInstruction( + arg->conditional_request(), *true_computation, *false_computation); + break; + } case OpRequest::kConstantRequest: handle_status = computation->AddConstantInstruction(arg->constant_request()); @@ -1376,6 +1411,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddConvertInstruction(arg->convert_request()); break; + case OpRequest::kBitcastConvertRequest: + handle_status = computation->AddBitcastConvertInstruction( + arg->bitcast_convert_request()); + break; case OpRequest::kConvolveRequest: handle_status = computation->AddConvolveInstruction(arg->convolve_request()); @@ -1388,6 +1427,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddCustomCallInstruction(arg->custom_call_request()); break; + case OpRequest::kDotRequest: + handle_status = computation->AddDotInstruction(arg->dot_request()); + break; case
OpRequest::kDynamicSliceRequest: handle_status = computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); @@ -1508,8 +1550,12 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddRecvInstruction(arg->recv_request()); break; } + case OpRequest::kFftRequest: + return Unimplemented("FftRequest not implemented in XLA service."); + case OpRequest::OP_NOT_SET: + return InvalidArgument("XLA service received OpRequest with OP_NOT_SET"); default: - return InvalidArgument("Unsupported operation"); + return InvalidArgument("Unsupported operation in XLA service"); } TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 6646be2e9aa43763b93bcea7a1df9d10580f162c..47f4f0ade594089aa71717ef1e122886b0a6c7ac 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -272,8 +272,6 @@ class Service : public ServiceInterface { // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. - // has_hybrid_result is used to initialize the same-named field in - // HloModuleConfig -- see that class for documentation. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 791d17365b1d756714b5feb0439e6919d9f23edc..9c1b951d017569a6dc89bc6583c72b5e42f0c07c 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -29,8 +29,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -89,8 +91,6 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { return BINOP_ATAN2; case HloOpcode::kComplex: return BINOP_COMPLEX; - case HloOpcode::kDot: - return BINOP_DOT; case HloOpcode::kMultiply: return BINOP_MUL; case HloOpcode::kAdd: @@ -440,6 +440,37 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { + auto old_element_type = operand_shape.element_type(); + if (primitive_util::IsComplexType(old_element_type) && + !primitive_util::IsComplexType(new_element_type)) { + return Unimplemented( + "Unsupported conversion from complex to real type: %s => %s", + ShapeUtil::HumanString(operand_shape).c_str(), + PrimitiveType_Name(new_element_type).c_str()); + } + if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { + // Note: we may want to support tuple conversions via this operation in the + // future, by recursing into the tuple elements to check all sub-conversions + // are valid. For now we just reject them, though. 
+ return InvalidArgument( + "cannot convert from or to tuple type; requested conversion: %s => %s", + ShapeUtil::HumanString(operand_shape).c_str(), + PrimitiveType_Name(new_element_type).c_str()); + } + + return ShapeUtil::ChangeElementType(operand_shape, new_element_type); +} + +/* static */ StatusOr ShapeInference::InferBitcastConvertShape( + const Shape& operand_shape, PrimitiveType new_element_type) { + auto old_element_type = operand_shape.element_type(); + if (primitive_util::IsComplexType(old_element_type) != + primitive_util::IsComplexType(new_element_type)) { + return Unimplemented( + "Unsupported conversion between real and complex types: %s => %s", + ShapeUtil::HumanString(operand_shape).c_str(), + PrimitiveType_Name(new_element_type).c_str()); + } if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions @@ -449,6 +480,13 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } + if (primitive_util::BitWidth(old_element_type) != + primitive_util::BitWidth(new_element_type)) { + return InvalidArgument( + "cannot bitcast types with different bit-widths: %s => %s", + PrimitiveType_Name(old_element_type).c_str(), + PrimitiveType_Name(new_element_type).c_str()); + } return ShapeUtil::ChangeElementType(operand_shape, new_element_type); } @@ -510,8 +548,113 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions); } -/* static */ StatusOr ShapeInference::InferDotOpShape(const Shape& lhs, - const Shape& rhs) { +// Current DotDimensionNumbers Requirements: +// +// Contracting Dimensions: +// *) Exactly one contracting dimension on both lhs and rhs. +// *) Contracting dimension size must be the same on both lhs and rhs. 
+// *) Contracting dimension numbers do not need to be the same (i.e. transposes +// are passed on to emitter implementations). +// +// Batch Dimensions: +// *) Same number of batch dimensions on both lhs and rhs. +// *) Same batch dimension numbers (and sizes) on both lhs and rhs. +// *) Batch dimension numbers must be ordered before contracting and +// non-contracting/non-batch dimension numbers. +// +// Non-Contracting-Non-Batch Dimensions: +// *) Can be 0 (matrix-vector) or 1 (matrix-matrix). +// + +namespace { + +Status ValidateDotDimensionNumbers( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers) { + // Check that dimension numbers are in range. + auto dims_in_range = + [](const int64 rank, tensorflow::gtl::ArraySlice contracting_dims, + tensorflow::gtl::ArraySlice batch_dims) -> bool { + auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; + return std::all_of(contracting_dims.begin(), contracting_dims.end(), + in_range) && + std::all_of(batch_dims.begin(), batch_dims.end(), in_range); + }; + + tensorflow::gtl::ArraySlice lhs_contracting_dimensions = + AsInt64Slice(dimension_numbers.lhs_contracting_dimensions()); + tensorflow::gtl::ArraySlice rhs_contracting_dimensions = + AsInt64Slice(dimension_numbers.rhs_contracting_dimensions()); + tensorflow::gtl::ArraySlice lhs_batch_dimensions = + AsInt64Slice(dimension_numbers.lhs_batch_dimensions()); + tensorflow::gtl::ArraySlice rhs_batch_dimensions = + AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); + + if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, + lhs_batch_dimensions) || + !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, + rhs_batch_dimensions)) { + return InvalidArgument("A dimension number is out of range in dot: %s", + dimension_numbers.DebugString().c_str()); + } + + // Check that dimension numbers are unique. 
+ auto dims_unique = [](tensorflow::gtl::ArraySlice contracting_dims, + tensorflow::gtl::ArraySlice batch_dims) -> bool { + tensorflow::gtl::FlatSet dim_set; + auto is_unique = [&dim_set](int64 i) -> bool { + return dim_set.insert(i).second; + }; + return std::all_of(contracting_dims.begin(), contracting_dims.end(), + is_unique) && + std::all_of(batch_dims.begin(), batch_dims.end(), is_unique); + }; + + if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || + !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { + return InvalidArgument("A dimension number is not unique in dot: %s", + dimension_numbers.DebugString().c_str()); + } + + // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. + const int64 lhs_non_contracting_non_batch_dims = + ShapeUtil::Rank(lhs) - + dimension_numbers.lhs_contracting_dimensions_size() - + dimension_numbers.lhs_batch_dimensions_size(); + const int64 rhs_non_contracting_non_batch_dims = + ShapeUtil::Rank(rhs) - + dimension_numbers.rhs_contracting_dimensions_size() - + dimension_numbers.rhs_batch_dimensions_size(); + if (lhs_non_contracting_non_batch_dims < 0 || + lhs_non_contracting_non_batch_dims > 1 || + rhs_non_contracting_non_batch_dims < 0 || + rhs_non_contracting_non_batch_dims > 1) { + return InvalidArgument( + "batch and contracting dimension number mismatch " + "with rank "); + } + + // Check that batch dimension numbers are ordered before all others, and + // that they are monotonically increasing. 
+ std::vector batch_dim_numbers(lhs_batch_dimensions.size()); + std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0); + if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), + lhs_batch_dimensions.begin()) || + !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), + rhs_batch_dimensions.begin())) { + return InvalidArgument( + "batch dimension numbers must precede non-batch dimensions and be" + "monotonically increasing."); + } + + return Status::OK(); +} + +} // namespace + +/* static */ StatusOr ShapeInference::InferDotOpShape( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot")); @@ -531,37 +674,62 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return fail("element types do not match"); } - if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 || - ShapeUtil::Rank(rhs) < 1 || ShapeUtil::Rank(rhs) > 2) { - return fail("dot only supports rank 1 or 2"); + if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) { + return fail("dot only supports rank 1 or above."); + } + + // Validate basic properties of dot dimension numbers. + TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); + + // Check that there is only one contracting dimension for both lhs and rhs. + if (dimension_numbers.lhs_contracting_dimensions_size() != + dimension_numbers.rhs_contracting_dimensions_size() || + dimension_numbers.lhs_contracting_dimensions_size() != 1) { + return fail("must specify one contracting dimension for both lhs and rhs."); + } + + // Check that contracting dimension sizes match. 
+ const int64 lhs_contracting_dimension = + dimension_numbers.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = + dimension_numbers.rhs_contracting_dimensions(0); + if (lhs.dimensions(lhs_contracting_dimension) != + rhs.dimensions(rhs_contracting_dimension)) { + return fail("contracting dimension sizes do not match."); } - // Determine the index of the contracted dimensions for input tensors. - // dimensions -1 of lhs and dimension 0 of rhs are contracted. - int64 lhs_contracted_dimension = ShapeUtil::GetDimensionNumber(lhs, -1); - int64 rhs_contracted_dimension = 0; + // Check that number of batch dimensions match. + if (dimension_numbers.lhs_batch_dimensions_size() != + dimension_numbers.rhs_batch_dimensions_size()) { + return fail("must the same number of batch dimensions for lhs and rhs."); + } - // Check if the contracted dimension sizes are the same. - if ((lhs_contracted_dimension < ShapeUtil::Rank(lhs) && - rhs_contracted_dimension < ShapeUtil::Rank(rhs)) && - lhs.dimensions(lhs_contracted_dimension) != - rhs.dimensions(rhs_contracted_dimension)) { - return fail("contracted dimensions mismatch"); + // Check that batch dimension numbers and sizes match. + for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { + if (dimension_numbers.lhs_batch_dimensions(i) != + dimension_numbers.rhs_batch_dimensions(i) || + lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { + return fail("batch dimension numbers and sizes must match for lhs/rhs."); + } } // The ranks of lhs and rhs are decremented by 1 respectively due to the // contraction, and added for the rank of the result. When an input tensor is // a scalar, its contribution to the rank of the result is 0. // Generate the result dimensions in order, rhs dimensions followed by lhs - // dimensions except the contracted dimensions. + // dimensions except the contracted and batch dimensions. 
std::vector dimensions; + std::unordered_set rhs_batch_dims( + dimension_numbers.rhs_batch_dimensions().begin(), + dimension_numbers.rhs_batch_dimensions().end()); for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) { - if (i != lhs_contracted_dimension) { + if (i != lhs_contracting_dimension) { dimensions.push_back(lhs.dimensions(i)); } } for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) { - if (i != rhs_contracted_dimension) { + if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) { dimensions.push_back(rhs.dimensions(i)); } } @@ -770,11 +938,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of binary operation")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of binary operation")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + lhs, tensorflow::strings::StrCat("lhs of binary operation ", + BinaryOperation_Name(operation)))); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + rhs, tensorflow::strings::StrCat("rhs of binary operation ", + BinaryOperation_Name(operation)))); switch (operation) { - case BINOP_DOT: - return InferDotOpShape(lhs, rhs); case BINOP_MAX: case BINOP_MIN: case BINOP_SUB: @@ -1402,7 +1572,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } - if (dnums.spatial_dimensions_size() != + if (dnums.input_spatial_dimensions_size() != dnums.kernel_spatial_dimensions_size()) { return InvalidArgument( "Both arguments to convolution must have same number of dimensions.\n" @@ -1410,7 +1580,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( window.DebugString().c_str()); } - const int num_spatial_dims = dnums.spatial_dimensions_size(); + const int num_spatial_dims = dnums.input_spatial_dimensions_size(); if (window.dimensions_size() != num_spatial_dims) { 
return InvalidArgument( "Window must have same number of dimensions as dimension numbers.\n" @@ -1439,8 +1609,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( std::vector input_dnums(num_dims); input_dnums[0] = dnums.input_batch_dimension(); input_dnums[1] = dnums.input_feature_dimension(); - std::copy(dnums.spatial_dimensions().begin(), - dnums.spatial_dimensions().end(), input_dnums.begin() + 2); + std::copy(dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2); std::sort(input_dnums.begin(), input_dnums.end()); std::vector window_dnums(num_dims); @@ -1450,12 +1620,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2); std::sort(window_dnums.begin(), window_dnums.end()); + std::vector output_dnums(num_dims); + output_dnums[0] = dnums.output_batch_dimension(); + output_dnums[1] = dnums.output_feature_dimension(); + std::copy(dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2); + std::sort(output_dnums.begin(), output_dnums.end()); + std::vector expected_dnums(num_dims); std::iota(expected_dnums.begin(), expected_dnums.end(), 0); const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; }; if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) || - !std::all_of(window_dnums.begin(), window_dnums.end(), in_range)) { + !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) || + !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { return InvalidArgument( "A dimension number is out of range in convolution: %s", dnums.DebugString().c_str()); @@ -1473,10 +1651,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "once: %s", dnums.DebugString().c_str()); } + if (output_dnums != expected_dnums) { + return InvalidArgument( + "Output dimensions of convolution must contain each dimension exactly " + "once: %s", + 
dnums.DebugString().c_str()); + } std::vector input_spatial_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { - input_spatial_dims[i] = lhs.dimensions(dnums.spatial_dimensions(i)); + input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i)); } const int64 input_features = lhs.dimensions(dnums.input_feature_dimension()); const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension()); @@ -1524,17 +1708,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( dimensions[dnums.output_batch_dimension()] = input_batch; dimensions[dnums.output_feature_dimension()] = kernel_output_features; for (int i = 0; i < num_spatial_dims; ++i) { - dimensions[dnums.spatial_dimensions(i)] = window_output_shape.dimensions(i); + dimensions[dnums.output_spatial_dimensions(i)] = + window_output_shape.dimensions(i); } return ShapeUtil::MakeShape(lhs.element_type(), dimensions); } /* static */ StatusOr ShapeInference::InferCrossReplicaSumShape( - const Shape& operand) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand, "operand of cross replica sum")); - return operand; + tensorflow::gtl::ArraySlice operand_shapes) { + for (const Shape* operand_shape : operand_shapes) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum")); + } + if (operand_shapes.size() == 1) { + return *operand_shapes[0]; + } + std::vector operand_shape_values; + for (const Shape* operand_shape : operand_shapes) { + operand_shape_values.push_back(*operand_shape); + } + return ShapeUtil::MakeTupleShape(operand_shape_values); } /* static */ StatusOr ShapeInference::InferReduceShape( @@ -1900,6 +2094,64 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return init; } +/* static */ StatusOr ShapeInference::InferConditionalShape( + const Shape& predicate, const Shape& true_operand, + const Shape& false_operand, const ProgramShape& true_computation, + const ProgramShape& false_computation) { + if 
(!ShapeUtil::ShapeIs(predicate, PRED, {})) { + return InvalidArgument("predicate must be a boolean; got %s.", + ShapeUtil::HumanString(predicate).c_str()); + } + + if (true_computation.parameters_size() != 1) { + return InvalidArgument("true_computation must take 1 argument; got %d.", + true_computation.parameters_size()); + } + if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) { + auto true_shape_string = [&]() { + return tensorflow::strings::Printf( + "true_operand: %s; true_computation: %s", + ShapeUtil::HumanString(true_operand).c_str(), + ShapeUtil::HumanString(true_computation).c_str()); + }; + return InvalidArgument( + "true_operand must match the shape of the only parameter of " + "true_computation: got %s.", + true_shape_string().c_str()); + } + + if (false_computation.parameters_size() != 1) { + return InvalidArgument("false_computation must take 1 argument; got %d.", + false_computation.parameters_size()); + } + if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) { + auto false_shape_string = [&]() { + return tensorflow::strings::Printf( + "false_operand: %s; false_computation: %s", + ShapeUtil::HumanString(false_operand).c_str(), + ShapeUtil::HumanString(false_computation).c_str()); + }; + return InvalidArgument( + "false_operand must match the shape of the only parameter of " + "false_computation: got %s.", + false_shape_string().c_str()); + } + if (!ShapeUtil::Compatible(true_computation.result(), + false_computation.result())) { + auto shape_string = [&]() { + return tensorflow::strings::Printf( + "true_computation result: %s; false_computation result: %s.", + ShapeUtil::HumanString(true_computation.result()).c_str(), + ShapeUtil::HumanString(false_computation.result()).c_str()); + }; + return InvalidArgument( + "the result of true_computation and false_computation must have the " + "same shape: got %s.", + shape_string().c_str()); + } + return true_computation.result(); +} + /* static */ StatusOr 
ShapeInference::InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast")); @@ -1943,7 +2195,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( - "Reshape dimensions not a permutation of the operand dimensions."); + "Reshape dimensions [%s] are not a permutation of the operand " + "dimensions (operand shape is %s).", + tensorflow::str_util::Join(dimensions, ",").c_str(), + ShapeUtil::HumanString(operand).c_str()); } return inferred_shape; diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index d5d497176d6c340d8c8f34cdacf6a9e32040c387..c06340d2d5df239642eb0af4836df64a898a1eaf 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -109,8 +109,10 @@ class ShapeInference { const Shape& lhs, const Shape& rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); - // Infers the shape produced a cross replica sum with the given operand shape. - static StatusOr InferCrossReplicaSumShape(const Shape& operand); + // Infers the shape produced a cross replica sum with the given operand + // shapes. + static StatusOr InferCrossReplicaSumShape( + tensorflow::gtl::ArraySlice operand_shapes); // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. @@ -178,6 +180,12 @@ class ShapeInference { const ProgramShape& body, const Shape& init); + // Infers the shape produced by a conditional operation. + static StatusOr InferConditionalShape( + const Shape& predicate, const Shape& true_operand, + const Shape& false_operand, const ProgramShape& true_computation, + const ProgramShape& false_computation); + // Infers the shape produced by a broadcast operation. 
static StatusOr InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes); @@ -204,6 +212,13 @@ class ShapeInference { static StatusOr InferConvertShape(const Shape& operand_shape, PrimitiveType new_element_type); + // Helper that validates the given operand shape can be bitcast converted to + // the target output_shape via a bitcast convert instruction -- the + // requirement is that the shape is identical except for the element type and + // the element types have identical bit-widths. + static StatusOr InferBitcastConvertShape( + const Shape& operand_shape, PrimitiveType new_element_type); + // Helper that validates the input data type for a reduce-precision operation, // and returns the result shape. static StatusOr InferReducePrecisionShape(const Shape& operand_shape, @@ -222,11 +237,13 @@ class ShapeInference { tensorflow::gtl::ArraySlice arg_shapes, const ProgramShape& to_apply); - private: // Helper that infers the shape produced by performing a dot operation with // the given LHS and RHS shapes. - static StatusOr InferDotOpShape(const Shape& lhs, const Shape& rhs); + static StatusOr InferDotOpShape( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. 
// Note: By "element-wise" we mean operations that look at a single element in diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index d12f7bd1453890db3280e54719a6ce811006336d..99d87f3b550ae72befe254f23fad080dd210aaf4 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -395,8 +395,10 @@ TEST_F(ShapeInferenceTest, Convolve) { dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); dnums.set_output_feature_dimension(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); // Dimension order: x1, batch, feature, x0 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); @@ -437,8 +439,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); dnums.set_output_feature_dimension(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); // Dimension order: x1, batch, feature, x0 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); @@ -480,8 +484,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); dnums.set_output_feature_dimension(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); // Dimension order: x1, batch, feature, x0 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 
4}); @@ -524,8 +530,10 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dnums.set_output_batch_dimension(3); dnums.set_input_feature_dimension(2); dnums.set_output_feature_dimension(2); - dnums.add_spatial_dimensions(0); - dnums.add_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(0); + dnums.add_output_spatial_dimensions(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); dnums.set_kernel_input_feature_dimension(0); // duplicated with kernel_x0 dnums.set_kernel_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); @@ -890,8 +898,11 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) { // scalar vector: error TEST_F(ShapeInferenceTest, ScalarDotVector) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {}); + ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("dot only supports rank")); @@ -899,61 +910,199 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { // 3D 2D: error TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("dot only supports rank")); + HasSubstr("batch and contracting dimension number mismatch")); } // vector vector -> scalar TEST_F(ShapeInferenceTest, VectorDotVector) { + DotDimensionNumbers dot_dnums; + 
dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_64_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); auto inferred_status_mismatch = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_32_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix vector -> vector TEST_F(ShapeInferenceTest, MatrixDotVector) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_32_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // vector matrix -> vector TEST_F(ShapeInferenceTest, VectorDotMatrix) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_32_, matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_)); - auto inferred_status_mismatch = 
ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix matrix -> matrix TEST_F(ShapeInferenceTest, MatrixDotMatrix) { - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_64_48_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status_match = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE( ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_)) << "inferred: " << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) << " expected: " << ShapeUtil::HumanString(matrix_64_48_); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } +// BatchMatMul with two batch dimensions and one contracting dimension. 
+TEST_F(ShapeInferenceTest, DotGeneral) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(1); + + dot_dnums.add_rhs_contracting_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status_match = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_IS_OK(inferred_status_match.status()); + ASSERT_TRUE( + ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape)) + << "inferred: " + << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) + << " expected: " << ShapeUtil::HumanString(output_shape); +} + +// BatchMatMul with two contracting dimensions fails. +TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("must specify one contracting dimension for both " + "lhs and rhs")); +} + +// BatchMatMul with different batch dimension sizes fails. 
+TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers and sizes must match")); +} + +// BatchMatMul with different batch dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers must precede non-batch")); +} + +// BatchMatMul with out-of-range dimension numbers fails. 
+TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is out of range")); +} + +// BatchMatMul with non-unique dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is not unique")); +} + TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { // Test variations of broadcasting a vector for a binary add with a // matrix. 
@@ -1288,5 +1437,80 @@ TEST_F(ShapeInferenceTest, Transpose) { ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); } +TEST_F(ShapeInferenceTest, Conditional) { + auto inferred_status0 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_IS_OK(inferred_status0.status()); + EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); + + auto inferred_status1 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, vector_32_, + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), + ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)); + EXPECT_IS_OK(inferred_status1.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie())); + + auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); + auto inferred_status2 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, tuple_f32_v32, + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)); + EXPECT_IS_OK(inferred_status2.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie())); + + auto inferred_status_error0 = ShapeInference::InferConditionalShape( + s32_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_FALSE(inferred_status_error0.ok()); + EXPECT_THAT(inferred_status_error0.status().error_message(), + HasSubstr("predicate must be a boolean")); + + auto inferred_status_error1 = ShapeInference::InferConditionalShape( + pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_, + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)); + EXPECT_FALSE(inferred_status_error1.ok()); + EXPECT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("true_computation must take 1 
argument")); + + auto inferred_status_error2 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_64_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_FALSE(inferred_status_error2.ok()); + EXPECT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("true_operand must match the shape of the only " + "parameter of true_computation")); + + auto inferred_status_error3 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)); + EXPECT_FALSE(inferred_status_error3.ok()); + EXPECT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("false_computation must take 1 argument")); + + auto inferred_status_error4 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_)); + EXPECT_FALSE(inferred_status_error4.ok()); + EXPECT_THAT(inferred_status_error4.status().error_message(), + HasSubstr("false_operand must match the shape of the only " + "parameter of false_computation")); + + auto inferred_status_error5 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)); + EXPECT_FALSE(inferred_status_error5.ok()); + EXPECT_THAT(inferred_status_error5.status().error_message(), + HasSubstr("the result of true_computation and false_computation " + "must have the same shape")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index a2a442eb1a33d976a114f68d112a7d8f3b540f4b..aa0a24a2833ec0b152f32f26f32e57ec6f7b5d14 100644 --- 
a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -21,17 +21,19 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace se = ::perftools::gputools; namespace xla { +using ::tensorflow::strings::Appendf; + /* static */ StatusOr> ShapedBuffer::MakeArrayShapedBuffer(const Shape& shape, const se::Platform* platform, @@ -49,6 +51,34 @@ ShapedBuffer::MakeArrayShapedBuffer(const Shape& shape, return std::move(shaped_buffer); } +/* static */ StatusOr> ShapedBuffer::Allocate( + const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, + const std::function& shape_size_fn) { + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Shape must have a layout: %s", + ShapeUtil::HumanStringWithLayout(shape).c_str()); + } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); + auto shaped_buffer = WrapUnique( + new ShapedBuffer(shape, allocator->platform(), device_ordinal)); + + // Allocate an appropriate sized buffer for each element in the shape + // including the tuple pointer arrays. 
+ for (auto& pair : shaped_buffer->shape_index_to_buffer_entry_) { + const ShapeIndex& index = pair.first; + size_t& buffer_entry = pair.second; + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase memory_base, + allocator->Allocate(shaped_buffer->device_ordinal(), + shape_size_fn(ShapeUtil::GetSubshape( + shaped_buffer->shape(), index)))); + shaped_buffer->buffers_.push_back(memory_base); + buffer_entry = shaped_buffer->buffers_.size() - 1; + } + + return std::move(shaped_buffer); +} + ShapedBuffer::ShapedBuffer(const Shape& shape, const se::Platform* platform, int device_ordinal) : shape_(shape), @@ -63,6 +93,14 @@ void ShapedBuffer::clear() { } } +void ShapedBuffer::AddBufferAtIndex( + const perftools::gputools::DeviceMemoryBase& buffer, + const ShapeIndex& shape_index) { + *mutable_shape_index_to_buffer_entry()->mutable_element(shape_index) = + buffers().size(); + mutable_buffers()->push_back(buffer); +} + const se::DeviceMemoryBase& ShapedBuffer::buffer( const ShapeIndex& index) const { return buffers_[shape_index_to_buffer_entry_.element(index)]; @@ -72,67 +110,37 @@ se::DeviceMemoryBase* ShapedBuffer::mutable_buffer(const ShapeIndex& index) { return &buffers_[shape_index_to_buffer_entry_.element(index)]; } -/* static */ StatusOr> -ScopedShapedBuffer::Allocate(const Shape& shape, - DeviceMemoryAllocator* allocator, - int device_ordinal) { - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Shape must have a layout: %s", - ShapeUtil::HumanStringWithLayout(shape).c_str()); - } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); - auto shaped_buffer = - WrapUnique(new ScopedShapedBuffer(shape, allocator, device_ordinal)); - - // Allocate an appropriate sized buffer for each element in the shape - // including the tuple pointer arrays. Gather tuple element addresses in - // 'element_addresses'. These will be written in the respective tuple's array - // of pointers on the device. 
- TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, - TransferManager::GetForPlatform(allocator->platform())); - ShapeTree> element_addresses(shape); - for (auto& pair : shaped_buffer->shape_index_to_buffer_entry_) { - const ShapeIndex& index = pair.first; - size_t& buffer_entry = pair.second; - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase memory_base, - shaped_buffer->allocator_->Allocate( - shaped_buffer->device_ordinal(), - transfer_manager->GetByteSizeRequirement( - ShapeUtil::GetSubshape(shaped_buffer->shape(), index)))); - shaped_buffer->buffers_.push_back(memory_base); - buffer_entry = shaped_buffer->buffers_.size() - 1; - - // If this is a tuple element, then push the address on to the - // vector of tuple element addresses. - if (!index.empty()) { - ShapeIndex parent_index = index; - parent_index.pop_back(); - element_addresses.mutable_element(parent_index)->push_back(memory_base); - } - } - - // Fill in the tuple pointer arrays with the addresses of their respective - // elements. 
- TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - allocator->platform()->ExecutorForDevice( - shaped_buffer->device_ordinal())); - for (const auto& pair : element_addresses) { - const ShapeIndex& index = pair.first; - const std::vector& addresses = pair.second; - const Shape& subshape = ShapeUtil::GetSubshape(shape, index); +string ShapedBuffer::ToString() const { + string s = "ShapedBuffer(" + platform_->Name() + "):\n"; + ShapeUtil::ForEachSubshape( + shape(), [this, &s](const Shape& subshape, const ShapeIndex& index) { + string shape_str; + if (ShapeUtil::IsTuple(subshape)) { + shape_str = "tuple"; + } else { + shape_str = ShapeUtil::HumanStringWithLayout(subshape); + } + const se::DeviceMemoryBase& memory = buffer(index); + Appendf(&s, " %s%p (%lld bytes) : %s\n", + string(index.size() * 2, ' ').c_str(), memory.opaque(), + memory.size(), shape_str.c_str()); + }); + return s; +} - if (addresses.empty()) { - TF_RET_CHECK(!ShapeUtil::IsTuple(subshape) || - ShapeUtil::TupleElementCount(subshape) == 0); - continue; - } - TF_RET_CHECK(ShapeUtil::IsTuple(subshape)); - TF_RETURN_IF_ERROR(transfer_manager->WriteTuplePointersToDevice( - executor, addresses, subshape, shaped_buffer->mutable_buffer(index))); - } +std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) { + out << buffer.ToString(); + return out; +} - return std::move(shaped_buffer); +/* static */ StatusOr> +ScopedShapedBuffer::Allocate( + const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, + const std::function& shape_size_fn) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr unscoped_buffer, + ShapedBuffer::Allocate(shape, allocator, device_ordinal, shape_size_fn)); + return MakeScoped(unscoped_buffer.get(), allocator); } /* static */ diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index e5ea06fb136fa714eab0f340f98b7191a4c5caa3..ca8bfff674d2fad0fc5731cb2dc30b60bcf11997 100644 --- 
a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -17,6 +17,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_ #include +#include +#include #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/shape_tree.h" @@ -41,6 +43,14 @@ class ShapedBuffer { const Shape& shape, const perftools::gputools::Platform* platform, int device_ordinal, const perftools::gputools::DeviceMemoryBase& buffer); + // Return a newly allocated ShapedBuffer of an arbitrary shape. Array buffers + // (leaves in the shape) are allocated and uninitialized. Tuple buffers (if + // any) are allocated and initialized to the backend-specific representation + // of an array of pointers to the tuple elements. + static StatusOr> Allocate( + const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, + const std::function& shape_size_fn); + ShapedBuffer(const Shape& shape, const perftools::gputools::Platform* platform, int device_ordinal); @@ -75,6 +85,12 @@ class ShapedBuffer { // Set all device memory pointers in the object to null. void clear(); + // Adds a new buffer at the given shape index. + void AddBufferAtIndex(const perftools::gputools::DeviceMemoryBase& buffer, + const ShapeIndex& shape_index); + + string ToString() const; + protected: // The shape of the device buffer with layout. const Shape shape_; @@ -95,17 +111,17 @@ class ShapedBuffer { ShapeTree shape_index_to_buffer_entry_; }; +std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); + // ShapedBuffer derived class which allocates all internal buffers on // construction and deallocates the memory when the object is // destructed. class ScopedShapedBuffer : public ShapedBuffer { public: - // Return a newly allocated ScopedShapedBuffer of an arbitrary shape. Array - // buffers (leaves in the shape) are allocated and uninitialized. 
Tuple - // buffers (if any) are allocated and initialized to the backend-specific - // representation of an array of pointers to the tuple elements. + // Identical to ShapedBuffer::Allocate. static StatusOr> Allocate( - const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal); + const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, + const std::function& shape_size_fn); // Takes a ShapedBuffer and returns a ScopedShapedBuffer which manages the // deallocation of the device memory held in the shaped buffer. All device diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 4da0a0d36841a6dfaed5c7eebdfb9e6980ad1090..d5f53ad56fb019d0ae7c27fc28706f05614ece68 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -28,12 +28,9 @@ limitations under the License. namespace se = ::perftools::gputools; namespace xla { - -/* static */ tensorflow::mutex* -TransferManager::platform_transfer_manager_mutex() { - static tensorflow::mutex* m = new tensorflow::mutex; - return m; -} +/* static */ tensorflow::mutex + TransferManager::platform_transfer_manager_mutex_( + tensorflow::LINKER_INITIALIZED); /* static */ std::map* @@ -47,7 +44,7 @@ TransferManager::GetPlatformTransferManagers() { se::Platform::Id platform_id, TransferManagerCreationFunction creation_function) { tensorflow::mutex_lock lock( - *TransferManager::platform_transfer_manager_mutex()); + TransferManager::platform_transfer_manager_mutex_); auto* managers = GetPlatformTransferManagers(); CHECK(managers->find(platform_id) == managers->end()); (*managers)[platform_id].creation_function = creation_function; @@ -56,7 +53,7 @@ TransferManager::GetPlatformTransferManagers() { /* static */ StatusOr TransferManager::GetForPlatform( const se::Platform* platform) { tensorflow::mutex_lock lock( - *TransferManager::platform_transfer_manager_mutex()); + 
TransferManager::platform_transfer_manager_mutex_); auto* managers = GetPlatformTransferManagers(); auto it = managers->find(platform->id()); @@ -75,6 +72,39 @@ TransferManager::GetPlatformTransferManagers() { return it->second.manager.get(); } +Status TransferManager::WriteTupleIndexTables( + perftools::gputools::StreamExecutor* executor, + const ShapedBuffer& device_buffer) { + VLOG(2) << "Writing tuple index tables to ShapedBuffer rooted at " + << device_buffer.buffer(/*index=*/{}).opaque() + << "; shape: " << ShapeUtil::HumanString(device_buffer.shape()); + + TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + + return ShapeUtil::ForEachSubshapeWithStatus( + device_buffer.shape(), + [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { + if (ShapeUtil::IsTuple(device_subshape)) { + se::DeviceMemoryBase device_memory = device_buffer.buffer(index); + TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == + device_memory.size()); + + std::vector elements; + ShapeIndex element_index = index; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(device_subshape); + ++i) { + element_index.push_back(i); + elements.push_back(device_buffer.buffer(element_index)); + element_index.pop_back(); + } + return WriteTuplePointersToDevice(executor, elements, device_subshape, + &device_memory); + } + + return Status::OK(); + }); +} + Status TransferManager::TransferBufferFromDevice( se::StreamExecutor* executor, const se::DeviceMemoryBase& source, int64 size, void* destination) { diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 057bdffe93164e9bb7271157556961575666359d..be9b769ac8cf3cf1fcfd13dfe9f1458e55a5323d 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -47,6 +48,8 @@ class TransferManager { // executor. device_shape is the shape, including layout, of the data on the // device, while literal_shape will be the shape for the literal. device_shape // and literal_shape must be compatible, but need not have the same layout. + // TODO(b/66694934): Remove TransferLiteral* methods which accept bare + // DeviceMemoryBase. virtual Status TransferLiteralFromDevice( perftools::gputools::StreamExecutor* executor, const perftools::gputools::DeviceMemoryBase& region, @@ -59,6 +62,28 @@ class TransferManager { perftools::gputools::StreamExecutor* executor, const Literal& literal, perftools::gputools::DeviceMemoryBase* region) = 0; + // Returns the shape of the on-device representation for the given shape on + // the host. This is intended for use with ShapedBuffer where buffers are + // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user + // needing to consider device-specific behaviors. + virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const { + return host_shape; + } + + // Transfers the data held in the given ShapedBuffer into the provided literal + // using the provided executor. literal_shape will be the shape for the + // literal. The shape of the ShapedBuffer and DeviceShape(literal_shape) must + // be compatible, but need not have the same layout. + virtual StatusOr> TransferLiteralFromDevice( + perftools::gputools::StreamExecutor* executor, + const ShapedBuffer& device_buffer) = 0; + + // Transfers the given literal into the previously allocated device memory + // represented by the given ShapedBuffer using the given executor. 
+ virtual Status TransferLiteralToDevice( + perftools::gputools::StreamExecutor* executor, const Literal& literal, + const ShapedBuffer& device_buffer) = 0; + // Transfers the given literal into the Infeed interface of the device, // using the given executor. virtual Status TransferLiteralToInfeed( @@ -97,15 +122,11 @@ class TransferManager { const perftools::gputools::DeviceMemoryBase& source, const Shape& shape) = 0; - // Writes the given device-memory pointers in 'elements' to the given region - // to construct a tuple in the platform-specific tuple representation. This - // can handle nested tuples as well. In the nested case, the element - // DeviceMemoryBase points to another array of pointers on the device. - virtual Status WriteTuplePointersToDevice( - perftools::gputools::StreamExecutor* executor, - tensorflow::gtl::ArraySlice - elements, - const Shape& shape, perftools::gputools::DeviceMemoryBase* region) = 0; + // Given an allocated ShapedBuffer, constructs the tuple index table(s) in + // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the + // ShapedBuffer is array-shaped this method does nothing. + Status WriteTupleIndexTables(perftools::gputools::StreamExecutor* executor, + const ShapedBuffer& device_buffer); // Returns all buffer pointers that the tuple `source` refers to. Unlike // ShallowCopyTupleFromDevice, this function gather buffer pointers in nested @@ -121,23 +142,6 @@ class TransferManager { // region for a host-to-device transfer. virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; - // Transfer a memory block of the given size from the device source into the - // 'destination' buffer. - // - // size is the size to transfer to destination in bytes. 
- virtual Status TransferBufferFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, int64 size, - void* destination); - - // Transfer a memory block of the given size from 'source' buffer to the given - // destination of the device. - // - // size is the size to transfer from source in bytes. - virtual Status TransferBufferToDevice( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source, perftools::gputools::DeviceMemoryBase* destination); - typedef std::unique_ptr (*TransferManagerCreationFunction)(); ///// @@ -157,12 +161,37 @@ class TransferManager { static StatusOr GetForPlatform( const perftools::gputools::Platform* platform); + protected: + // Transfer a memory block of the given size from the device source into the + // 'destination' buffer. + // + // size is the size to transfer to destination in bytes. + virtual Status TransferBufferFromDevice( + perftools::gputools::StreamExecutor* executor, + const perftools::gputools::DeviceMemoryBase& source, int64 size, + void* destination); + + // Transfer a memory block of the given size from 'source' buffer to the given + // destination of the device. + // + // size is the size to transfer from source in bytes. + virtual Status TransferBufferToDevice( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source, perftools::gputools::DeviceMemoryBase* destination); + + // Writes the given device-memory pointers in 'elements' to the given region + // to construct a tuple in the platform-specific tuple representation. This + // can handle nested tuples as well. In the nested case, the element + // DeviceMemoryBase points to another array of pointers on the device. 
+ virtual Status WriteTuplePointersToDevice( + perftools::gputools::StreamExecutor* executor, + tensorflow::gtl::ArraySlice + elements, + const Shape& shape, perftools::gputools::DeviceMemoryBase* region) = 0; + private: - // Routine that returns the mutex that guards the - // platform-to-transfer manager map. Done as a routine to - // ensure correct initialization ordering, since RegisterTransferManager - // can be called during program initialization time. - static tensorflow::mutex* platform_transfer_manager_mutex(); + // The mutex that guards the platform-to-transfer manager map. + static tensorflow::mutex platform_transfer_manager_mutex_; // State kept for each kind of TransferManager. Registration functions // set up creation_function, and then we use that to lazily create diff --git a/tensorflow/compiler/xla/service/transfer_manager_test.cc b/tensorflow/compiler/xla/service/transfer_manager_test.cc deleted file mode 100644 index c25a0861e9b90bc0f2cde43933e14204aa4e3598..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/transfer_manager_test.cc +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/generic_transfer_manager.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" - -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" - -namespace se = ::perftools::gputools; - -namespace xla { - -namespace { - -class CpuTransferManagerTest : public ::testing::Test { - protected: - CpuTransferManagerTest() - : transfer_manager_(se::host::kHostPlatformId, - /*pointer_size=*/sizeof(void*)) { - se::Platform* platform = - se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId) - .ValueOrDie(); - stream_exec_ = - platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .ValueOrDie(); - } - - ~CpuTransferManagerTest() override {} - - se::StreamExecutor* stream_exec_; - GenericTransferManager transfer_manager_; -}; - -TEST_F(CpuTransferManagerTest, TransferR0U32ToDevice) { - std::vector storage(sizeof(uint32), '\x00'); - se::DeviceMemoryBase memptr(storage.data(), storage.size()); - std::unique_ptr literal = Literal::CreateR0(42); - TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, - &memptr)); - - CHECK_EQ(42, *reinterpret_cast(&storage[0])); -} - -TEST_F(CpuTransferManagerTest, TransferR1F32ToDevice) { - std::vector storage(4 * sizeof(float), '\x00'); - se::DeviceMemoryBase memptr(storage.data(), storage.size()); - std::unique_ptr literal = - Literal::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); - TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, - &memptr)); - - CHECK_EQ(1.25f, 
*reinterpret_cast(&storage[0])); - CHECK_EQ(2.5f, *reinterpret_cast(&storage[sizeof(float)])); - CHECK_EQ(-17.0f, *reinterpret_cast(&storage[2 * sizeof(float)])); - CHECK_EQ(-20.125f, *reinterpret_cast(&storage[3 * sizeof(float)])); -} - -TEST_F(CpuTransferManagerTest, TransferR1U8ToDevice) { - std::vector storage(16, '\x00'); - se::DeviceMemoryBase memptr(storage.data(), storage.size()); - const char* str = "0123456789abcdef"; - std::unique_ptr literal = Literal::CreateR1U8(str); - TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, - &memptr)); - - CHECK_EQ('0', storage[0]); - CHECK_EQ('8', storage[8]); - CHECK_EQ('f', storage[15]); -} - -TEST_F(CpuTransferManagerTest, TransferR0U32FromDevice) { - std::vector storage(1, 42); - se::DeviceMemoryBase memptr(storage.data(), - storage.size() * sizeof(storage[0])); - Literal literal; - const Shape shape = ShapeUtil::MakeShape(U32, {}); - TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( - stream_exec_, memptr, shape, shape, &literal)); - - LiteralTestUtil::ExpectR0Equal(42, literal); -} - -TEST_F(CpuTransferManagerTest, TransferR1F32FromDevice) { - std::vector storage{1.25f, 2.5f, -17.0f, -20.125f}; - se::DeviceMemoryBase memptr(storage.data(), - storage.size() * sizeof(storage[0])); - Literal literal; - const Shape shape = ShapeUtil::MakeShape(F32, {4}); - TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( - stream_exec_, memptr, shape, shape, &literal)); - - LiteralTestUtil::ExpectR1Equal({1.25, 2.5, -17.0, -20.125}, literal); -} - -TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) { - std::vector storage{'k', 'l', 'm', 'n'}; - se::DeviceMemoryBase memptr(storage.data(), - storage.size() * sizeof(storage[0])); - Literal literal; - const Shape shape = ShapeUtil::MakeShape(U8, {4}); - TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( - stream_exec_, memptr, shape, shape, &literal)); - CHECK_EQ("klmn", literal.u8s_string()); -} - -TEST_F(CpuTransferManagerTest, 
TransferBufferFromDevice) { - std::vector storage{1, 5, 42}; - int64 size = storage.size() * sizeof(storage[0]); - se::DeviceMemoryBase memptr(storage.data(), size); - - std::vector dest(3, 0); - TF_CHECK_OK(transfer_manager_.TransferBufferFromDevice(stream_exec_, memptr, - size, dest.data())); - ASSERT_EQ(1, dest[0]); - ASSERT_EQ(5, dest[1]); - ASSERT_EQ(42, dest[2]); -} - -TEST_F(CpuTransferManagerTest, TransferBufferToDevice) { - int64 size = 3 * sizeof(uint64); - std::vector storage(size, 0); - se::DeviceMemoryBase memptr(storage.data(), size); - - std::vector dest{1, 5, 42}; - TF_CHECK_OK(transfer_manager_.TransferBufferToDevice(stream_exec_, size, - dest.data(), &memptr)); - std::vector* storage64 = - reinterpret_cast*>(&storage); - ASSERT_EQ(1, (*storage64)[0]); - ASSERT_EQ(5, (*storage64)[1]); - ASSERT_EQ(42, (*storage64)[2]); -} - -// TODO(b/24679870): add similar tests for GPUs - -} // namespace - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 8c2640adf52f10c387e7a9c09c0d73a09c054919..83185ac49e9b7c386d10d1cbc4e20dcdfdfd6cae 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -42,7 +42,7 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( TransposeFolding::OperandIndices operand_set; for (int64 i = 0; i < dot.operand_count(); ++i) { auto& operand = *dot.operand(i); - if (operand.IsRank2Transpose() && operand.user_count() == 1) { + if (operand.IsRank2Transpose()) { operand_set.push_back(i); } } @@ -58,27 +58,10 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution( return {}; } - const ConvolutionDimensionNumbers& dnums = - convolution.convolution_dimension_numbers(); - TransposeFolding::OperandIndices operand_set; for (int64 i = 0; i < convolution.operand_count(); ++i) { auto& operand = *convolution.operand(i); - if (operand.opcode() == HloOpcode::kTranspose && 
- operand.user_count() == 1) { - const auto& transpose_dimensions = operand.dimensions(); - // We can transpose the LHS so long as it doesn't move around spatial - // dimensions because ConvolutionDimensionNumbers doesn't have different - // fields for input and output spatial dimensions. - if (i == 0 && - std::any_of(dnums.spatial_dimensions().begin(), - dnums.spatial_dimensions().end(), - [&](const int64 spatial_dimension) { - return transpose_dimensions[spatial_dimension] != - spatial_dimension; - })) { - continue; - } + if (operand.opcode() == HloOpcode::kTranspose) { operand_set.push_back(i); } } @@ -118,6 +101,10 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto& convolution = *pair.first; auto& operand_indices = pair.second; + if (operand_indices.empty()) { + return false; + } + const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); ConvolutionDimensionNumbers new_dnums = dnums; @@ -137,8 +124,9 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { transpose_dimensions[dnums.input_batch_dimension()]); new_dnums.set_input_feature_dimension( transpose_dimensions[dnums.input_feature_dimension()]); - for (const auto& spatial_dimension : dnums.spatial_dimensions()) { - CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]); + for (auto& input_spatial_dimension : + *new_dnums.mutable_input_spatial_dimensions()) { + input_spatial_dimension = transpose_dimensions[input_spatial_dimension]; } new_lhs = &transpose_operand; } else { diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 00462f9be1e9beb2f2694060ebfaa70b0b9dd4a0..caa1a111ad880b9dee62c1c94e32e8275c196fbf 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -64,9 +64,12 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) { HloInstruction* transpose_y = 
builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/x, /*rhs=*/transpose_y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, + /*rhs=*/transpose_y, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -104,9 +107,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { HloInstruction* transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 3}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/transpose0, /*rhs=*/transpose1)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(F32, {1, 3}), + /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -169,9 +175,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/x, /*rhs=*/transpose_y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + 
HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, + /*rhs=*/transpose_y, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -362,10 +371,82 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { EXPECT_EQ( dnums.input_batch_dimension(), new_conv->convolution_dimension_numbers().input_feature_dimension()); - EXPECT_EQ(dnums.spatial_dimensions(0), - new_conv->convolution_dimension_numbers().spatial_dimensions(0)); - EXPECT_EQ(dnums.spatial_dimensions(1), - new_conv->convolution_dimension_numbers().spatial_dimensions(1)); + EXPECT_EQ( + dnums.input_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(0)); + EXPECT_EQ( + dnums.input_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(1)); + EXPECT_EQ( + dnums.output_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); +} + +// Test that a transpose of every dimension in the activations gets folded into +// convolution. 
+TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), + /*name=*/"y")); + HloInstruction* transpose_x = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2})); + auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + Window window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + dim->set_stride(1); + dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); + } + StatusOr conv_shape = ShapeInference::InferConvolveShape( + transpose_x->shape(), y->shape(), window, dnums); + EXPECT_IS_OK(conv_shape); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(conv)); + FoldTranspose(&module); + + // Instructions after folding: x, y, and the convolution. 
+ std::unordered_set instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); + EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + EXPECT_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.input_feature_dimension(), + new_conv->convolution_dimension_numbers().input_batch_dimension()); + EXPECT_EQ( + dnums.input_batch_dimension(), + new_conv->convolution_dimension_numbers().input_feature_dimension()); + EXPECT_EQ( + dnums.input_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(1)); + EXPECT_EQ( + dnums.input_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); } } // namespace diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index df537bd7c15a1f15ed77ca9be6ce70fbfd2e63be..0c848566478a25d4862cb0698e029dacd71f7a6a 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -120,6 +120,23 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index, tree_.mutable_element(index)->tuple_sources.insert(tuple); } +namespace { + +// Gather fusion instructions from 'instruction' into 'fusion_instructions'. 
+void GatherFusionInstructions( + HloInstruction* instruction, + std::vector* fusion_instructions) { + CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); + for (auto* fused : instruction->fused_instructions()) { + if (fused->opcode() == HloOpcode::kFusion) { + GatherFusionInstructions(fused, fusion_instructions); + } + } + fusion_instructions->push_back(instruction); +} + +} // namespace + /* static */ StatusOr> TuplePointsToAnalysis::Run(const HloModule* module) { auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module); @@ -137,20 +154,23 @@ Status TuplePointsToAnalysis::Analyze() { logical_buffer_aliases_.resize( logical_buffer_analysis_->num_logical_buffers()); + std::vector fusion_instructions; for (auto* computation : module_->MakeNonfusionComputations()) { TF_RETURN_IF_ERROR(computation->Accept(this)); TF_RETURN_IF_ERROR( PopulateDefinedBuffersAndAliases(computation->instructions())); - // Run points-to analysis on fusion instructions in 'computation'. for (auto* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kFusion) { - continue; + if (instruction->opcode() == HloOpcode::kFusion) { + GatherFusionInstructions(instruction, &fusion_instructions); } - TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); - TF_RETURN_IF_ERROR( - PopulateDefinedBuffersAndAliases(instruction->fused_instructions())); } } + // Run points-to analysis on fusion instructions in 'computation'. + for (auto* instruction : fusion_instructions) { + TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + TF_RETURN_IF_ERROR( + PopulateDefinedBuffersAndAliases(instruction->fused_instructions())); + } XLA_VLOG_LINES(3, ToString()); @@ -253,6 +273,64 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { + // RecvDone aliases its input (Recv) tuple element {0} to its output. 
+ PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done); + const PointsToSet& operand_points_to_set = + GetPointsToSet(recv_done->operand(0)); + + // Recursively copy the points to set of the operand tuple {0}. + points_to_set.ForEachMutableElement( + [this, &points_to_set, &operand_points_to_set]( + const ShapeIndex& index, PointsToSet::BufferList* buffers) { + ShapeIndex src_index({0}); + for (auto element : index) { + src_index.push_back(element); + } + *buffers = operand_points_to_set.element(src_index); + for (auto& tuple_source : + operand_points_to_set.tuple_sources(src_index)) { + points_to_set.add_tuple_source(index, tuple_source); + } + }); + return Status::OK(); +} + +Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { + // Send creates a tuple of {aliased operand, U32 context}. + PointsToSet& points_to_set = CreateEmptyPointsToSet(send); + + // Creates the points to set for the tuple and its element at {1}. + auto top_buffer = points_to_set.mutable_element(ShapeIndex({})); + top_buffer->push_back( + &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({}))); + points_to_set.add_tuple_source({}, send); + + auto context_buffer = points_to_set.mutable_element(ShapeIndex({1})); + context_buffer->push_back( + &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1}))); + + // Recursively copy the points to set of the operand to output tuple {0}. 
+ const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0)); + operand_points_to_set.ForEachElement( + [&points_to_set, &operand_points_to_set]( + const ShapeIndex& src_index, + const PointsToSet::BufferList& points_to) { + ShapeIndex target_index({0}); + for (auto element : src_index) { + target_index.push_back(element); + } + *points_to_set.mutable_element(target_index) = points_to; + + for (HloInstruction* tuple : + operand_points_to_set.tuple_sources(src_index)) { + points_to_set.add_tuple_source(target_index, tuple); + } + }); + + return Status::OK(); +} + Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) { tensorflow::gtl::ArraySlice operands(tuple->operands()); PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index e6157a1ed11b5df24458fe820a4e0e329eb86ae4..8928de107eed8c40bbe2130e26fe83ca3802d2f6 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -251,6 +251,8 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleCopy(HloInstruction* copy) override; + Status HandleRecvDone(HloInstruction* recv_done) override; + Status HandleSend(HloInstruction* send) override; Status HandleSelect(HloInstruction* select) override; string ToString() const; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 694ed57fa24d59bd0a28c7bb9b67af8165e90363..dec446d4dac650ba43992f7870764eedc80cb2cf 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -313,6 +313,51 @@ 
TEST_F(TuplePointsToAnalysisTest, TupleCopy) { {constant1, constant2, copy}); } +TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) { + // Send forwards its operand to the output tuple at {0}. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto send = builder.AddInstruction( + HloInstruction::CreateSend(constant, /*channel_id=*/0)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send).IsDistinct()); + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send_done).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send_done).IsDistinct()); + + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(send).element({}), {send}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(send).element({0}), {constant}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(send_done).CreateFlattenedSet(), + {send_done}); + ExpectHasBufferAliases(constant, {}, {{constant, {}}, {send, {0}}}); +} + +TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) { + // RecvDone forwards its operand tuple element at {0} to the output. 
+ auto builder = HloComputation::Builder(TestName()); + auto recv = builder.AddInstruction(HloInstruction::CreateRecv( + ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv).IsDistinct()); + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv_done).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv_done).IsDistinct()); + + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(recv).element({}), {recv}); + ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}}); +} + TEST_F(TuplePointsToAnalysisTest, TupleSelect) { // Select from two different tuples. This should create an ambiguous points to // set containing the union of both sides. diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index e9d182509b5356d32b667b7921e2843d30faeb9b..e6893c8133b17cac3ca381df58d417eef15b60c4 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -88,8 +88,6 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { return HloOpcode::kAtan2; case BINOP_COMPLEX: return HloOpcode::kComplex; - case BINOP_DOT: - return HloOpcode::kDot; case BINOP_MUL: return HloOpcode::kMultiply; case BINOP_ADD: @@ -765,6 +763,54 @@ StatusOr UserComputation::AddWhileInstruction( return handle; } +StatusOr UserComputation::AddConditionalInstruction( + const ConditionalRequest& conditional_request, + const UserComputation& true_computation, + const UserComputation& false_computation) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* pred, + LookUpRequest(conditional_request.predicate())); + TF_ASSIGN_OR_RETURN(const 
OperationRequest* true_operand, + LookUpRequest(conditional_request.true_operand())); + TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand, + LookUpRequest(conditional_request.false_operand())); + + VersionedComputationHandle::Version true_computation_version = + true_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr true_computation_shape, + true_computation.ComputeProgramShape(true_computation_version)); + + VersionedComputationHandle::Version false_computation_version = + false_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr false_computation_shape, + false_computation.ComputeProgramShape(false_computation_version)); + + TF_ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferConditionalShape( + pred->output_shape(), true_operand->output_shape(), + false_operand->output_shape(), + *true_computation_shape, *false_computation_shape)); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + request.add_embedded_computation_versions(true_computation_version); + request.add_embedded_computation_versions(false_computation_version); + *request.mutable_request()->mutable_conditional_request() = + conditional_request; + + VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << conditional_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddBroadcastInstruction( const BroadcastRequest& broadcast_request) { tensorflow::mutex_lock lock(mutex_); @@ -994,6 +1040,32 @@ StatusOr UserComputation::AddConvertInstruction( return handle; } +StatusOr UserComputation::AddBitcastConvertInstruction( + const ConvertRequest& convert_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + 
LookUpRequest(convert_request.operand())); + + TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( + operand->output_shape(), + convert_request.new_element_type())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = new_shape; + *request.mutable_request()->mutable_bitcast_convert_request() = + convert_request; + + VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << convert_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddReducePrecisionInstruction( const ReducePrecisionRequest& reduce_precision_request) { tensorflow::mutex_lock lock(mutex_); @@ -1056,7 +1128,7 @@ StatusOr UserComputation::AddCrossReplicaSumInstruction( TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(cross_replica_sum_request.operand())); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - operand->output_shape())); + {&operand->output_shape()})); ComputationDataHandle handle = CreateComputationDataHandle(); @@ -1181,6 +1253,33 @@ StatusOr UserComputation::AddCustomCallInstruction( return handle; } +StatusOr UserComputation::AddDotInstruction( + const DotRequest& dot_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, + LookUpRequest(dot_request.lhs())); + TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, + LookUpRequest(dot_request.rhs())); + + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape( + lhs->output_shape(), rhs->output_shape(), + dot_request.dimension_numbers())); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = 
handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_dot_request() = dot_request; + + VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << dot_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddUnaryInstruction( const UnaryOpRequest& unary_request) { tensorflow::mutex_lock lock(mutex_); @@ -1603,6 +1702,15 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + PureFunctionalVisitor(session_computation, dot_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, dot_request.rhs(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::kSendRequest: { *is_functional = false; break; @@ -1713,6 +1821,14 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kBitcastConvertRequest: { + const ConvertRequest& convert_request = + request.request().bitcast_convert_request(); + PureFunctionalVisitor(session_computation, convert_request.operand(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::kWhileRequest: { const WhileRequest& while_request = request.request().while_request(); PureFunctionalVisitor(session_computation, while_request.init(), @@ -1723,6 +1839,23 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + PureFunctionalVisitor(session_computation, + conditional_request.predicate(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + conditional_request.true_operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + 
conditional_request.false_operand(), num_parameters, + visited, is_functional); + // TODO(b/32495713): We aren't checking the true and false computations + // themselves. + break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -1951,6 +2084,21 @@ UserComputation::GetEmbeddedComputations( break; } + case OpRequest::kConditionalRequest: { + CHECK_EQ(2, request.embedded_computation_versions_size()); + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + const VersionedComputationHandle true_computation_versioned_handle = { + conditional_request.true_computation(), + request.embedded_computation_versions(0)}; + computations.push_back(true_computation_versioned_handle); + const VersionedComputationHandle false_computation_versioned_handle = + {conditional_request.false_computation(), + request.embedded_computation_versions(1)}; + computations.push_back(false_computation_versioned_handle); + break; + } + default: // No embedded computation. break; @@ -2037,6 +2185,16 @@ Status UserComputation::RemapEmbeddedComputations( TF_RETURN_IF_ERROR(update(while_request->mutable_body())); break; } + case OpRequest::kConditionalRequest: { + TF_RET_CHECK(2 == request.embedded_computation_versions_size()); + ConditionalRequest* conditional_request = + request.mutable_request()->mutable_conditional_request(); + TF_RETURN_IF_ERROR( + update(conditional_request->mutable_true_computation())); + TF_RETURN_IF_ERROR( + update(conditional_request->mutable_false_computation())); + break; + } default: // No embedded computation. 
TF_RET_CHECK(0 == request.embedded_computation_versions_size()); @@ -2370,12 +2528,28 @@ static void ForEachOperand( break; } + case OpRequest::kBitcastConvertRequest: { + const ConvertRequest& convert_request = + request.request().bitcast_convert_request(); + apply(convert_request.operand()); + break; + } + case OpRequest::kWhileRequest: { const WhileRequest& while_request = request.request().while_request(); apply(while_request.init()); break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + apply(conditional_request.predicate()); + apply(conditional_request.true_operand()); + apply(conditional_request.false_operand()); + break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -2412,6 +2586,13 @@ static void ForEachOperand( break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + apply(dot_request.rhs()); + apply(dot_request.lhs()); + break; + } + case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); @@ -2538,6 +2719,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( if (ShapeUtil::IsScalar(operand->shape())) { HloInstruction* broadcast = hlo_builder_.AddInstruction( HloInstruction::CreateBroadcast(broadcast_shape, operand, {})); + broadcast->set_metadata(operand->metadata()); if (operand->has_sharding()) { broadcast->set_sharding(operand->sharding()); } @@ -2558,6 +2740,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( ShapeUtil::MakeShape(operand->shape().element_type(), reshaped_dimensions), operand)); + reshaped_operand->set_metadata(operand->metadata()); if (operand->has_sharding()) { reshaped_operand->set_sharding(operand->sharding()); } @@ -2565,6 +2748,7 @@ HloInstruction* 
ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( HloInstruction* broadcast = hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( broadcast_shape, reshaped_operand, broadcast_dimensions)); + broadcast->set_metadata(operand->metadata()); if (operand->has_sharding()) { broadcast->set_sharding(operand->sharding()); } @@ -2688,13 +2872,22 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + HloInstruction* lhs = lookup_instruction(dot_request.lhs()); + HloInstruction* rhs = lookup_instruction(dot_request.rhs()); + hlo_instruction = add_instruction(HloInstruction::CreateDot( + request.output_shape(), lhs, rhs, dot_request.dimension_numbers())); + break; + } + case OpRequest::kCrossReplicaSumRequest: { const CrossReplicaSumRequest& cross_replica_sum_request = request.request().cross_replica_sum_request(); HloInstruction* operand = lookup_instruction(cross_replica_sum_request.operand()); hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), operand)); + request.output_shape(), {operand})); break; } @@ -2927,8 +3120,9 @@ void ComputationLowerer::Visit( case OpRequest::kRecvRequest: { const RecvRequest& recv_request = request.request().recv_request(); - hlo_instruction = add_instruction(HloInstruction::CreateRecv( + HloInstruction* recv = add_instruction(HloInstruction::CreateRecv( request.output_shape(), recv_request.channel_handle().handle())); + hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv)); break; } @@ -2950,6 +3144,15 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kBitcastConvertRequest: { + const ConvertRequest& convert_request = + request.request().bitcast_convert_request(); + HloInstruction* operand = lookup_instruction(convert_request.operand()); + hlo_instruction = add_instruction(HloInstruction::CreateBitcastConvert( + request.output_shape(), operand)); + break; + } + 
case OpRequest::kWhileRequest: { const WhileRequest& while_request = request.request().while_request(); CHECK_EQ(2, request.embedded_computation_versions_size()); @@ -2967,6 +3170,30 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + CHECK_EQ(2, request.embedded_computation_versions_size()); + VersionedComputationHandle::Version true_computation_version = + request.embedded_computation_versions(0); + HloComputation* true_computation = ResolveComputation( + conditional_request.true_computation(), true_computation_version); + VersionedComputationHandle::Version false_computation_version = + request.embedded_computation_versions(1); + HloComputation* false_computation = ResolveComputation( + conditional_request.false_computation(), false_computation_version); + HloInstruction* predicate = + lookup_instruction(conditional_request.predicate()); + HloInstruction* true_operand = + lookup_instruction(conditional_request.true_operand()); + HloInstruction* false_operand = + lookup_instruction(conditional_request.false_operand()); + hlo_instruction = add_instruction(HloInstruction::CreateConditional( + request.output_shape(), predicate, true_operand, true_computation, + false_operand, false_computation)); + break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -2974,6 +3201,25 @@ void ComputationLowerer::Visit( HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs()); HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs()); auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); + + if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { + if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { + // lhs side is being implicitly broadcast. Change to explicit. 
+ lhs = + ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); + } + + if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { + rhs = + ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); + } + + if (!ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) { + ehs = + ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape()); + } + } + hlo_instruction = add_instruction(HloInstruction::CreateTernary( request.output_shape(), hlo_opcode, lhs, rhs, ehs)); break; @@ -3078,8 +3324,7 @@ void ComputationLowerer::Visit( lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; } - if (debug_options_.xla_eliminate_hlo_implicit_broadcast() && - binary_op_request.binop() != BINOP_DOT) { + if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { // lhs side is being implicitly broadcast. Change to explicit. 
lhs = @@ -3120,8 +3365,9 @@ void ComputationLowerer::Visit( case OpRequest::kSendRequest: { const SendRequest& send_request = request.request().send_request(); HloInstruction* operand = lookup_instruction(send_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateSend( + HloInstruction* send = add_instruction(HloInstruction::CreateSend( operand, send_request.channel_handle().handle())); + hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send)); break; } @@ -3132,7 +3378,7 @@ void ComputationLowerer::Visit( LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); } (*instructions)[handle.handle()] = hlo_instruction; -} +} // NOLINT(readability/fn_size) } // namespace diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index ac879ce55a75f6241a39f935b79017be46c1816b..8a78d520e19024f5e397d6e0c2f4e0523264e176 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -70,7 +70,7 @@ class UserComputation { // Enqueues a pad instruction onto this user computation. StatusOr AddPadInstruction( - const PadRequest& parameter_request); + const PadRequest& pad_request); // Enqueues a tracing instruction onto this user computation. // Returns an error status if the operand cannot be resolved. @@ -105,7 +105,7 @@ class UserComputation { // Enqueues a ternary instruction onto this user computation. // Returns an error status if the operand indices are out of bounds. StatusOr AddTernaryInstruction( - const TernaryOpRequest& request); + const TernaryOpRequest& ternary_request); // Enqueues a variadic instruction onto this user computation. // Returns an error status if the operand indices are out of bounds. @@ -153,6 +153,10 @@ class UserComputation { StatusOr AddCustomCallInstruction( const CustomCallRequest& custom_call_request); + // Enqueues a dot instruction onto this user computation. 
+ StatusOr AddDotInstruction( + const DotRequest& dot_request); + // Enqueues a broadcast instruction onto this user computation. StatusOr AddBroadcastInstruction( const BroadcastRequest& broadcast_request); @@ -179,26 +183,30 @@ class UserComputation { // Enqueues a concatenate instruction onto this user computation. StatusOr AddConcatenateInstruction( - const ConcatenateRequest& slice_request); + const ConcatenateRequest& concatenate_request); // Enqueues a convert instruction onto this user computation. StatusOr AddConvertInstruction( const ConvertRequest& convert_request); + // Enqueues a bitcast element instruction onto this user computation. + StatusOr AddBitcastConvertInstruction( + const ConvertRequest& convert_request); + // Enqueues a reduce instruction onto this user computation. StatusOr AddReduceInstruction( const ReduceRequest& reduce_request, - const UserComputation& reduction_computation); + const UserComputation& to_apply_computation); // Enqueues a windowed reduce instruction onto this user computation. StatusOr AddReduceWindowInstruction( const ReduceWindowRequest& reduce_window_request, - const UserComputation& reduction_computation); + const UserComputation& to_apply_computation); // Enqueues a select-and-scatter instruction onto this user // computation. StatusOr AddSelectAndScatterInstruction( - const SelectAndScatterRequest& scatter_to_selected_window_element_request, + const SelectAndScatterRequest& select_and_scatter_request, const UserComputation& select_computation, const UserComputation& scatter_computation); @@ -212,6 +220,12 @@ class UserComputation { const UserComputation& condition_computation, const UserComputation& body_computation); + // Enqueues a conditional instruction on this user computation. + StatusOr AddConditionalInstruction( + const ConditionalRequest& conditional_request, + const UserComputation& true_computation, + const UserComputation& false_computation); + // Enqueues a Send instruction onto this user computation. 
Status AddSendInstruction(const SendRequest& send_request); diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index 5afaf226ae0cce7e9afc966c6b4adf838aeebc91..e45673300b6c5f85be4153f2db821d8abbced7cd 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -334,50 +334,5 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { operands[1]->opcode() == HloOpcode::kBroadcast); } -TEST_F(UserComputationTest, SkipDotInEliminatingImplicitBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // %a = Param({1, 3}); - // %b = Param({3, 1}); - // %dot = Dot(%a, %b); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {3, 1}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - BinaryOpRequest dot; - dot.set_binop(BINOP_DOT); - *dot.mutable_lhs() = a_handle; - *dot.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(dot).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. 
- TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - EXPECT_EQ(3, hlo_computation->instruction_count()); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 65734f91bc6ce5d9fa00dae22544dd1f169d861c..b2fd64a4d9f3dc343b2e44b5efa31aacc6085042 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -58,7 +58,9 @@ static bool ContainsSendOrRecv(const HloComputation* comp) { static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { if (instr->opcode() == HloOpcode::kSend || - instr->opcode() == HloOpcode::kRecv) { + instr->opcode() == HloOpcode::kSendDone || + instr->opcode() == HloOpcode::kRecv || + instr->opcode() == HloOpcode::kRecvDone) { return true; } for (const auto& subcomp : instr->called_computations()) { @@ -287,7 +289,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Don't try this transformation if the while loop isn't removable, since if // it succeeds ultimately we're going to have to replace the old while loop // with a new one. - if (!while_op->parent()->IsRemovable(while_op)) { + if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { VLOG(2) << "Can't remove dead parameters from non-removable while op."; return false; } @@ -340,7 +342,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // // Careful: HloInstruction::operand_index returns the first index the // operand appears in, but it may appear more than once! 
- if (user->user_count() == 1 && user->users()[0] == while_body_root && + if (user->user_count() == 1 && user->users().front() == while_body_root && while_body_root->operand_index(user) == user->tuple_index() && std::count(while_body_root->operands().begin(), while_body_root->operands().end(), user) == 1) { @@ -401,6 +403,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Compute the shape of the while op after we remove the dead indices. std::vector new_while_tuple_elem_shapes; + new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size()); for (int64 old_idx : new_to_old_tuple_idx) { new_while_tuple_elem_shapes.push_back( while_init->shape().tuple_shapes(old_idx)); @@ -441,7 +444,8 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // This is a GTE of an index that we've removed. Remove it from the // cloned computation. CHECK(user->user_count() == 0 || - user->user_count() == 1 && user->users()[0] == while_body_root) + user->user_count() == 1 && + user->users().front() == while_body_root) << "Instruction " << user->ToStringNoMetadata() << " should be unused (except by root of while body), but has " "users: {" @@ -467,6 +471,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { std::unordered_map> while_body_replacements = make_while_computation_replacements(while_body); std::vector new_while_body_root_elems; + new_while_body_root_elems.reserve(new_to_old_tuple_idx.size()); for (int64 old_idx : new_to_old_tuple_idx) { new_while_body_root_elems.push_back( while_body_root->mutable_operand(old_idx)); @@ -481,6 +486,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // clean this up in the common case where while_init is a tuple op. (It's // definitely tuple-shaped, but it's not necessarily a tuple op.) 
std::vector new_while_init_elems; + new_while_init_elems.reserve(new_to_old_tuple_idx.size()); for (int64 old_idx : new_to_old_tuple_idx) { new_while_init_elems.push_back( computation->AddInstruction(HloInstruction::CreateGetTupleElement( @@ -552,7 +558,7 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { // the loop aren't removed, just cloned and added back to the loop. // Nevertheless our infrastructure sees loop simplification as removal of // these nodes and currently doesn't allow it. - if (!while_op->parent()->IsRemovable(while_op)) { + if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { VLOG(2) << "Not attempting to remove while loop it is not removable: " << while_op->ToShortString(); return false; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 8e1a2dcde129e9a022789eb7b192319901b9db4a..d99b31dc0037968bc88d5f22d53309a6a4546963 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -144,10 +144,11 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); - while_body->AddInstruction(HloInstruction::CreateSend( + auto* send = while_body->AddInstruction(HloInstruction::CreateSend( while_body->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(true))), /*channel_id=*/0)); + while_body->AddInstruction(HloInstruction::CreateSendDone(send)); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } @@ -156,9 +157,10 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); - while_body->AddInstruction( + auto* recv = 
while_body->AddInstruction( HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0)); + while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 5bf9842a6ce7be747f58c10f302f85c6f82ac6f9..789eba5780d37e1fd4d80ec881855951c8bba0eb 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -32,13 +32,13 @@ tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { return tensorflow::Status::OK(); } -tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* other_shape) const { - if (!ShapeUtil::Compatible(*other_shape, shape_)) { +tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { + if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - ShapeUtil::HumanString(*other_shape).c_str(), + ShapeUtil::HumanString(*to_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } - *other_shape = shape_; + *to_shape = shape_; return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index 92564660f21bf1b596c4b9ca04c07eaca27ed192..4c83750f3e6f3c735db66d8e0b86ae3f43e5ca11 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -38,18 +38,19 @@ class ShapeLayout { explicit ShapeLayout(const Shape& shape) : shape_(shape) {} // Assigns the layouts in this ShapeLayout to the Layout fields of the given - // shape. 'shape' and the shape of the ShapeLayout object must be compatible. - tensorflow::Status AssignLayoutToShape(Shape* shape) const; + // shape. 'to_shape' and the shape of the ShapeLayout object must be + // compatible. 
+ tensorflow::Status AssignLayoutToShape(Shape* to_shape) const; // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. If the given shape is not compatible // with the ShapeLayout's shape, then false is returned. bool MatchesLayoutInShape(const Shape& shape) const; - // Copies the layout from the given shape into this ShapeLayout. 'shape' must - // be compatible with the ShapeLayout's shape, and 'shape' must have a layout - // (LayoutUtil::HasLayout). - tensorflow::Status CopyLayoutFromShape(const Shape& shape); + // Copies the layout from the given shape into this ShapeLayout. 'other_shape' + // must be compatible with the ShapeLayout's shape, and 'other_shape' must + // have a layout (LayoutUtil::HasLayout). + tensorflow::Status CopyLayoutFromShape(const Shape& other_shape); // Clears (Layout::Clear) all the Layouts stored in this object. void Clear(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 64a36471b9f1b35517c29c01554e02c5d1035086..d752619bd65751779c24f061e44e206d66b01465 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -116,6 +116,7 @@ class ShapeTree { ShapeTree(const Shape* shape, const T& init_value); ShapeTree(const ShapeTree& other) { *this = other; } + ShapeTree(ShapeTree&&) = default; ShapeTree& operator=(const ShapeTree& other) { root_ = other.root_; @@ -132,6 +133,8 @@ class ShapeTree { return *this; } + ShapeTree& operator=(ShapeTree&& other) = default; + // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). const T& element(const ShapeIndex& index) const; @@ -152,28 +155,57 @@ class ShapeTree { using const_iterator = ShapeTreeIterator; // begin/end for iterating over all nodes. 
- iterator begin() { return iterator(&root_, /*iterate_leaves_only=*/false); } - iterator end() { return iterator(nullptr, /*iterate_leaves_only=*/false); } + iterator begin() { + return iterator(&root_, /*iterate_leaves_only=*/false, + /*reverse=*/false); + } + iterator end() { + return iterator(nullptr, /*iterate_leaves_only=*/false, + /*reverse=*/false); + } const_iterator begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/false); + return const_iterator(&root_, /*iterate_leaves_only=*/false, + /*reverse=*/false); } const_iterator end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/false); + return const_iterator(nullptr, /*iterate_leaves_only=*/false, + /*reverse=*/false); + } + + // rbegin/rend for iterating over all nodes in reverse. + iterator rbegin() { + return iterator(&root_, /*iterate_leaves_only=*/false, + /*reverse=*/true); + } + iterator rend() { + return iterator(nullptr, /*iterate_leaves_only=*/false, + /*reverse=*/true); + } + const_iterator rbegin() const { + return const_iterator(&root_, /*iterate_leaves_only=*/false, + /*reverse=*/true); + } + const_iterator rend() const { + return const_iterator(nullptr, /*iterate_leaves_only=*/false, + /*reverse=*/true); } // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no // children). 
iterator leaf_begin() { - return iterator(&root_, /*iterate_leaves_only=*/true); + return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/false); } iterator leaf_end() { - return iterator(nullptr, /*iterate_leaves_only=*/true); + return iterator(nullptr, /*iterate_leaves_only=*/true, + /*reverse=*/false); } const_iterator leaf_begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/true); + return const_iterator(&root_, /*iterate_leaves_only=*/true, + /*reverse=*/false); } const_iterator leaf_end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/true); + return const_iterator(nullptr, /*iterate_leaves_only=*/true, + /*reverse=*/false); } // range-based iterator for leaf_begin()/leaf_end(). tensorflow::gtl::iterator_range leaves() { @@ -183,6 +215,22 @@ class ShapeTree { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } + iterator leaf_rbegin() { + return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/true); + } + iterator leaf_rend() { + return iterator(nullptr, /*iterate_leaves_only=*/true, + /*reverse=*/true); + } + const_iterator leaf_rbegin() const { + return const_iterator(&root_, /*iterate_leaves_only=*/true, + /*reverse=*/true); + } + const_iterator leaf_rend() const { + return const_iterator(nullptr, /*iterate_leaves_only=*/true, + /*reverse=*/true); + } + // Recursively traverses the shape and calls the given function at each // element. The function has the following arguments: // @@ -190,7 +238,7 @@ class ShapeTree { // (or compatible). // index : the index of the element in the shape. See ShapeUtil::GetSubshape // for definition of index. - // data : The data value at this elemnt. + // data : The data value at this element. template void ForEachElement(const Fn& func) const; @@ -277,42 +325,61 @@ class ShapeTreeIterator : public std::iteratorchildren.empty() && iterate_leaves_only) { - ++*this; + // interior tree nodes, only leaves. 
If reverse is true, the iterator will + // visit nodes in the reverse of pre-order traversal. + ShapeTreeIterator(NodeType* node, bool iterate_leaves_only, bool reverse) + : node_(node), + iterate_leaves_only_(iterate_leaves_only), + reverse_(reverse) { + if (node_) { + if (reverse_) { + while (!node_->children.empty()) { + const int child_index = node_->children.size() - 1; + stack_.push_back({node_, child_index}); + node_ = node_->children[child_index].get(); + } + } else { + if (!node_->children.empty() && iterate_leaves_only) { + ++*this; + } + } } } ShapeTreeIterator(const ShapeTreeIterator& other) : node_(other.node_), stack_(other.stack_), - iterate_leaves_only_(other.iterate_leaves_only_) {} + iterate_leaves_only_(other.iterate_leaves_only_), + reverse_(other.reverse_) {} ShapeTreeIterator& operator++() { CHECK_NE(nullptr, node_) << "walking off the end() of an iterator!"; - // We're doing a pre-order walk, so if our current node has children take - // the first child. - if (!node_->children.empty()) { - stack_.push_back({node_, /*child-index=*/0}); - node_ = node_->children[0].get(); - if (node_->children.empty() || !iterate_leaves_only_) { - return *this; - } else { - // This is a non-leaf; tail-recurse. - return ++(*this); + if (reverse_) { + while (!stack_.empty()) { + node_ = stack_.back().first; + int64 next_child_index = stack_.back().second - 1; + stack_.pop_back(); + if (next_child_index < 0) { + if (!iterate_leaves_only_) { + // All children are visited, yield . + return *this; + } + } else { + stack_.push_back({node_, next_child_index}); + node_ = node_->children[next_child_index].get(); + while (!node_->children.empty()) { + const int child_index = node_->children.size() - 1; + stack_.push_back({node_, child_index}); + node_ = node_->children[child_index].get(); + } + return *this; + } } - } - // Otherwise we are currently at a leaf. Walk back up until a node contains - // a child we haven't visited yet. 
- while (!stack_.empty()) { - node_ = stack_.back().first; - int64 next_child_index = stack_.back().second + 1; - stack_.pop_back(); - if (node_->children.size() > next_child_index) { - stack_.push_back({node_, next_child_index}); - node_ = node_->children[next_child_index].get(); - + } else { + // We're doing a pre-order walk, so if our current node has children take + // the first child. + if (!node_->children.empty()) { + stack_.push_back({node_, /*child-index=*/0}); + node_ = node_->children[0].get(); if (node_->children.empty() || !iterate_leaves_only_) { return *this; } else { @@ -320,6 +387,24 @@ class ShapeTreeIterator : public std::iteratorchildren.size() > next_child_index) { + stack_.push_back({node_, next_child_index}); + node_ = node_->children[next_child_index].get(); + + if (node_->children.empty() || !iterate_leaves_only_) { + return *this; + } else { + // This is a non-leaf; tail-recurse. + return ++(*this); + } + } + } } // We've walked off the end of the tree. Set node_ to nullptr to signify // end(). @@ -361,6 +446,8 @@ class ShapeTreeIterator : public std::iterator> stack_; // True if we should not include interior nodes in our walk. bool iterate_leaves_only_; + // True if we should yield the reverse of the pre-order traversal. + bool reverse_; // Placeholder for the current value. Ideally this wouldn't exist and would // just be an rvalue, but operator -> needs to return a pointer to something. 
// We cannot just use a plain old value_type as it contains a reference so diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 7b4b5cb0fb5e1564ca12ac6e3b901e94ea4c8db6..4b6ab772811f4a6c6ffc1d10befc7122f883b8f9 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -456,6 +456,26 @@ TEST_F(ShapeTreeTest, IterateOrder) { {2, 1}})); } +TEST_F(ShapeTreeTest, ReverseIterateOrder) { + ShapeTree t(nested_tuple_shape_, 42); + std::vector v; + for (auto it = t.rbegin(); it != t.rend(); ++it) { + v.push_back(it->first); + } + EXPECT_EQ(v, (std::vector{ + {2, 1}, + {2, 0, 1}, + {2, 0, 0}, + {2, 0}, + {2}, + {1, 1}, + {1, 0}, + {1}, + {0}, + {}, + })); +} + TEST_F(ShapeTreeTest, IterateOrderLeaves) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; @@ -466,5 +486,21 @@ TEST_F(ShapeTreeTest, IterateOrderLeaves) { {0}, {1, 0}, {1, 1}, {2, 0, 0}, {2, 0, 1}, {2, 1}})); } +TEST_F(ShapeTreeTest, ReverseIterateOrderLeaves) { + ShapeTree t(nested_tuple_shape_, 42); + std::vector v; + for (auto it = t.leaf_rbegin(); it != t.leaf_rend(); ++it) { + v.push_back(it->first); + } + EXPECT_EQ(v, (std::vector{ + {2, 1}, + {2, 0, 1}, + {2, 0, 0}, + {1, 1}, + {1, 0}, + {0}, + })); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index b5eb81dfc6a4117909dcb18fdbe61443b1a1eb95..fe5166643df573ab8cbbea56ac791bccf5b7a4a8 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -263,6 +264,7 @@ StatusOr MakeShapeWithLayoutInternal( case S32: case S64: case F16: + case BF16: case F32: case F64: return true; @@ -328,6 +330,14 @@ StatusOr MakeShapeWithLayoutInternal( return MakeTupleShape(new_elements); } +// Returns the shape of a real or imaginary component. 
+/* static */ Shape ShapeUtil::ComplexComponentShape( + const Shape& complex_shape) { + CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape); + return ChangeElementType(complex_shape, primitive_util::ComplexComponentType( + complex_shape.element_type())); +} + /* static */ bool ShapeUtil::ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list dimensions) { @@ -395,6 +405,26 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) { static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator(); return gen->LowercaseName(s); } + +StatusOr StringToPrimitiveType(const string& name) { + static std::unordered_map* name_to_type = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i)) { + auto value = static_cast(i); + (*map)[LowercasePrimitiveTypeName(value)] = value; + } + } + return map; + }(); + auto found = name_to_type->find(name); + if (found == name_to_type->end()) { + return InvalidArgument("Invalid element type string: \"%s\".", + name.c_str()); + } + return found->second; +} + } // namespace /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { @@ -499,17 +529,10 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { comma_list_to_int64s(dimensions_string)); // Extract the primitive element type. 
- PrimitiveType primitive_type = PRIMITIVE_TYPE_INVALID; - for (PrimitiveType i = - static_cast(PRIMITIVE_TYPE_INVALID + 1); - i < TUPLE; i = static_cast(i + 1)) { - if (tensorflow::str_util::Lowercase(PrimitiveType_Name(i)) == - element_type_string) { - primitive_type = i; - break; - } - } - if (primitive_type == PRIMITIVE_TYPE_INVALID) { + TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, + StringToPrimitiveType(element_type_string)); + if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE || + primitive_type == OPAQUE) { return InvalidArgument("Invalid element type string: \"%s\".", element_type_string.c_str()); } @@ -552,6 +575,16 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs); } +/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, + const Shape& rhs) { + if (lhs.element_type() == TUPLE) { + return rhs.element_type() == TUPLE && + ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + CompatibleIgnoringElementType); + } + return SameDimensions(lhs, rhs); +} + /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, int64 dimension_number) { return shape.dimensions(GetDimensionNumber(shape, dimension_number)); @@ -591,6 +624,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return sizeof(uint32); case U64: return sizeof(uint64); + case BF16: + return sizeof(float) / 2; case F16: return sizeof(float) / 2; case F32: @@ -681,9 +716,9 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return LayoutUtil::ValidateLayoutInShape(shape); } -/* static */ Shape ShapeUtil::ChangeElementType(const Shape& shape, +/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original, PrimitiveType type) { - Shape new_shape = shape; + Shape new_shape = original; new_shape.set_element_type(type); return new_shape; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h 
index 8f8d4a73c9ecb3f4236f3877323ad1127bb0b9c2..666c7da697c7cbad4dc30a7b3feb2b2804562442 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -68,6 +68,9 @@ class ShapeIndex { const int64* data() const { return indices_.data(); } + int64 back() const { return indices_.back(); } + int64& back() { return indices_.back(); } + const int64& operator[](size_t i) const { return indices_[i]; } int64& operator[](size_t i) { return indices_[i]; } @@ -167,7 +170,7 @@ class ShapeUtil { // As above, but for program shapes, returns a string for the form: // // (param_name: f32[42x12], ...) -> f32[24x42] - static string HumanString(const ProgramShape& shape); + static string HumanString(const ProgramShape& program_shape); // Parses a ShapeUtil::HumanString-format shape string back into a shape // object. @@ -187,6 +190,11 @@ class ShapeUtil { // compatibility. static bool Compatible(const Shape& lhs, const Shape& rhs); + // Returns true if the rank and dimension sizes are identical. Element type + // and layout are ignored. Tuple elements are compared recursively for + // compatibility. + static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs); + // Returns whether the lhs and rhs shapes are identical protobufs. static bool Equal(const Shape& lhs, const Shape& rhs); @@ -343,6 +351,10 @@ class ShapeUtil { // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32). static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit); + // Returns the shape of the real/imaginary components of the given complex + // shape. + static Shape ComplexComponentShape(const Shape& complex_shape); + // Shorthand for testing whether a shape is of a given element type and // sequence of dimensions. 
// diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 0ba542ad1bec290c35c52a8dd5177893770310fd..4bce7ca51d0534cbcad6faac12818c5f3e94b29e 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -145,6 +145,7 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) { Shape tuple2 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2)); } TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) { @@ -153,6 +154,7 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) { Shape tuple2 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(S32, {3, 2})}); EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); + EXPECT_TRUE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2)); } TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) { diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index 5fa2211ac66177514ac8ecabfa8791e7c8c014a2..f9d25945bc617507735fb6c4d011c39723497f69 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -32,26 +32,26 @@ namespace { class Base1 { public: virtual ~Base1() {} - int pad; + int pad_; }; class Base2 { public: virtual ~Base2() {} - int yetotherpad; + int yetotherpad_; }; class Derived : public Base1, public Base2 { public: ~Derived() override {} - int evenmorepad; + int evenmorepad_; }; class CopyNoAssign { public: - explicit CopyNoAssign(int value) : foo(value) {} - CopyNoAssign(const CopyNoAssign& other) : foo(other.foo) {} - int foo; + explicit CopyNoAssign(int value) : foo_(value) {} + CopyNoAssign(const CopyNoAssign& other) : foo_(other.foo_) {} + int foo_; private: const CopyNoAssign& 
operator=(const CopyNoAssign&); @@ -253,7 +253,7 @@ TEST(StatusOr, TestCopyCtorNonAssignable) { StatusOr original(value); StatusOr copy(original); EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(original.ValueOrDie().foo, copy.ValueOrDie().foo); + EXPECT_EQ(original.ValueOrDie().foo_, copy.ValueOrDie().foo_); } TEST(StatusOr, TestCopyCtorStatusOKConverting) { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 4e1be24b61cc436b0baf62cc6e28ad8d13fe71ac..6af01ae80d9ac8cdf8e7ba5cff4c24ef1d31cf94 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -61,14 +61,19 @@ generate_backend_test_macros() cc_library( name = "test_utils", - testonly = True, + srcs = ["test_utils.cc"], hdrs = ["test_utils.h"], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_headers_lib", ], ) @@ -100,7 +105,9 @@ cc_library( hdrs = ["hlo_test_base.h"], deps = [ ":literal_test_util", + ":test_utils", "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -110,6 +117,9 @@ cc_library( "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", 
"//tensorflow/core:test", @@ -427,6 +437,27 @@ xla_test( ], ) +xla_test( + name = "conditional_test", + srcs = ["conditional_test.cc"], + # Currently, Conditional is supported only in CPU and GPU backends. + backends = [ + "cpu", + "gpu", + ], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], @@ -508,6 +539,7 @@ xla_test( name = "array_elementwise_ops_test", srcs = ["array_elementwise_ops_test.cc"], shard_count = 25, + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -766,6 +798,41 @@ xla_test( ], ) +xla_test( + name = "bfloat16_test", + srcs = ["bfloat16_test.cc"], + blacklisted_backends = [ + "gpu", + ], + shard_count = 40, + deps = [ + ":test_utils", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/service:hlo", + 
"//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "slice_test", srcs = ["slice_test.cc"], @@ -1226,6 +1293,23 @@ xla_test( ], ) +xla_test( + name = "bitcast_convert_test", + srcs = ["bitcast_convert_test.cc"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + xla_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.cc"], @@ -1290,6 +1374,7 @@ xla_test( srcs = ["client_test.cc"], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", @@ -1343,22 +1428,23 @@ xla_test( ], ) -xla_test( +tf_cc_test( name = "llvm_compiler_test", srcs = ["llvm_compiler_test.cc"], - backends = [ - "cpu", - "gpu", - "cpu_parallel", - ], + tags = ["requires-gpu-sm35"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:llvm_compiler", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/gpu:gpu_compiler", "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor", "@llvm//:core", ], ) @@ -1596,6 +1682,65 @@ tf_cc_test( ], ) +xla_test( + name = "transfer_manager_test", + srcs = ["transfer_manager_test.cc"], + deps = [ + ":literal_test_util", + ":local_client_test_base", + ":xla_internal_test_main", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:generic_transfer_manager", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +# A demo of textual IR based test. +xla_test( + name = "sample_text_test", + srcs = ["sample_text_test.cc"], + # You can leave this empty if you want to test all supported backends. + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":hlo_test_base", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + +# A demo of test that loads an hlo module from a file and compares results on gpu and cpu. 
+tf_cc_test( + name = "sample_file_test", + srcs = ["sample_file_test.cc"], + data = ["isolated_convolution.hlo"], + tags = ["requires-gpu-sm35"], + deps = [ + ":hlo_test_base", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:cpu_plugin", # reference backend + "//tensorflow/compiler/xla/service:gpu_plugin", # test backend + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 0b700fbb6ffbde147c71b76d37f334a53c91f2fd..c6e8b24d1211743d07878d388522feacf9c0e7f1 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -82,6 +82,25 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { {}); } +XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto result = builder.Neg(a); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); + auto result = builder.Neg(a); + + ComputeAndCompareR1( + &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, + {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); @@ -145,6 +164,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, 
AddTwoConstantZeroElementF32s) { ComputeAndCompareR1(&builder, {}, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); + auto b = builder.ConstantR1( + {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); + auto add = builder.Add(a, b); + + ComputeAndCompareR1( + &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto add = builder.Add(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { const int count = GetParam(); ComputationBuilder builder(client_, TestName()); @@ -222,6 +263,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) { ComputeAndCompareR1(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); + auto b = builder.ConstantR1( + {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); + auto add = builder.Sub(a, b); + + ComputeAndCompareR1( + &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto add = builder.Sub(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); @@ -385,6 +448,27 @@ 
XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } } +XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); + auto b = builder.ConstantR1( + {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); + auto div = builder.Div(a, b); + + ComputeAndCompareR1( + &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto div = builder.Div(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1( @@ -496,6 +580,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { ComputeAndCompareR1(&builder, expected, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); + auto b = builder.ConstantR1( + {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); + auto add = builder.Mul(a, b); + + ComputeAndCompareR1( + &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto add = builder.Mul(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({false, false, true, true}); @@ -886,6 +992,53 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { ComputeAndCompareR1(&builder, {}, {}); } 
+XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { + SetFastMathDisabled(true); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = builder.ConstantR1({{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + auto compare = builder.Eq(lhs, rhs); + + ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({}); + auto rhs = builder.ConstantR1({}); + auto compare = builder.Eq(lhs, rhs); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { + // Disable fast-math because we're operating on NaNs. + SetFastMathDisabled(true); + + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = builder.ConstantR1({{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + auto compare = builder.Ne(lhs, rhs); + + ComputeAndCompareR1(&builder, {true, true, false, true, true}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { // Disable fast-math because we're operating on NaNs. SetFastMathDisabled(true); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 028d1251b455b82a291c236f7866e52e27d3590e..7525bc4bdfbaa942ea8af29af31829ae8742e833 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -39,6 +39,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -203,6 +204,15 @@ struct BatchNormTestParam { int64 feature_index; float random_value_mean; float random_value_var; + + friend ::std::ostream& operator<<(::std::ostream& os, + const BatchNormTestParam& p) { + os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, "; + os << "feature_index=" << p.feature_index << ", "; + os << "random_value_mean=" << p.random_value_mean << ", "; + os << "random_value_var=" << p.random_value_var; + return os; + } }; // Tests to test the fused operation of BatchNorm. diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac3f3f4c9ddb03d003a44f5abd7a2e26c42f490d --- /dev/null +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class Bfloat16Test : public ClientLibraryTestBase { + protected: + const ErrorSpec error_spec_{0.001, 0.001}; +}; + +XLA_TEST_F(Bfloat16Test, ScalarOperation) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR0(static_cast(2.0f)); + auto y = builder.ConstantR0(static_cast(1.0f)); + builder.Add(x, y); + + ComputeAndCompareR0(&builder, static_cast(3.0f), {}, + error_spec_); +} + +XLA_TEST_F(Bfloat16Test, LogOperation) { + ComputationBuilder builder(client_, TestName()); + auto x 
= builder.ConstantR0(static_cast(4.0f)); + builder.Log(x); + + ComputeAndCompareR0(&builder, static_cast(1.387f), {}, + error_spec_); +} + +XLA_TEST_F(Bfloat16Test, NegateScalarF16) { + ComputationBuilder builder(client_, TestName()); + builder.Neg(builder.ConstantR0(static_cast(2.1f))); + + ComputeAndCompareR0(&builder, static_cast(-2.1f), {}, + error_spec_); +} + +XLA_TEST_F(Bfloat16Test, BatchNormTraining) { + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + {{{{static_cast(1.f)}, {static_cast(2.f)}}, + {{static_cast(3.f)}, {static_cast(4.f)}}}, + {{{static_cast(5.f)}, {static_cast(6.f)}}, + {{static_cast(7.f)}, {static_cast(8.f)}}}}); + + auto scale = builder.ConstantR1( + {static_cast(2.0f), static_cast(3.0f)}); + + auto offset = builder.ConstantR1( + {static_cast(1.0f), static_cast(2.0f)}); + + auto tuple = builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR4( + {{{{static_cast(-1.7f)}, {static_cast(-2.04f)}}, + {{static_cast(0.105f)}, {static_cast(0.65f)}}}, + {{{static_cast(1.89f)}, {static_cast(3.35f)}}, + {{static_cast(3.7f)}, {static_cast(6.04f)}}}}) + .get(), + Literal::CreateR1( + {static_cast(4), static_cast(5)}) + .get(), + Literal::CreateR1( + {static_cast(5), static_cast(5)}) + .get()}); + + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); +} + +XLA_TEST_F(Bfloat16Test, BatchNormGrad) { + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + Array4D(2, 2, 2, 1, static_cast(0.0f))); + + auto scale = builder.ConstantR1( + {static_cast(1.0f), static_cast(1.0f)}); + + auto mean = builder.ConstantR1( + {static_cast(0.0f), static_cast(0.0f)}); + + auto var = builder.ConstantR1( + {static_cast(1.0f), static_cast(1.0f)}); + + auto grad_output = builder.ConstantR4FromArray4D( + 
{{{{static_cast(1.f)}, {static_cast(2.f)}}, + {{static_cast(3.f)}, {static_cast(4.f)}}}, + {{{static_cast(5.f)}, {static_cast(6.f)}}, + {{static_cast(7.f)}, {static_cast(8.f)}}}}); + + builder.BatchNormGrad(operand, scale, mean, var, grad_output, + /*epsilon=*/0.0, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR4( + {{{{static_cast(-3.f)}, {static_cast(-3.f)}}, + {{static_cast(-1.f)}, {static_cast(-1.f)}}}, + {{{static_cast(1.f)}, {static_cast(1.f)}}, + {{static_cast(3.f)}, {static_cast(3.f)}}}}) + .get(), + Literal::CreateR1( + {static_cast(0), static_cast(0)}) + .get(), + Literal::CreateR1( + {static_cast(16), static_cast(20)}) + .get()}); + + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0d94d65c1015fb54ada3fdfc95d0c31d0a0f158b --- /dev/null +++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc @@ -0,0 +1,141 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class BitcastConvertTest : public ClientLibraryTestBase { + public: + explicit BitcastConvertTest(perftools::gputools::Platform* platform = nullptr) + : ClientLibraryTestBase(platform) { + mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); + mutable_debug_options()->add_xla_disable_hlo_passes("inline"); + } +}; + +TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42, 64}); + builder.BitcastConvertType(a, S32); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.0f, 64.0f}); + builder.BitcastConvertType(a, F32); + + std::vector expected = {42.0f, 64.0f}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1({0, static_cast(0x80000000), 0x3F800000, + static_cast(0xBF800000), 0x3F000000, + static_cast(0xBF000000)}); + builder.BitcastConvertType(a, F32); + + std::vector expected = {0.0f, -0.0f, 1.0f, -1.0f, 0.5f, -0.5f}; + ComputeAndCompareR1(&builder, expected, {}); 
+} + +XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + builder.BitcastConvertType(a, F32); + + std::vector expected = {}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.6, 64.4}); + builder.BitcastConvertType(a, S32); + + std::vector expected = {0x422a6666, 0x4280cccd}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, ConvertS32Extremes) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {std::numeric_limits::min(), std::numeric_limits::max()}); + builder.BitcastConvertType(a, F32); + + std::vector expected = {-0.0f, NAN}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0, 0)); +} + +TEST_F(BitcastConvertTest, ConvertMapToS32) { + ComputationBuilder builder(client_, TestName()); + auto b = builder.CreateSubBuilder("convert"); + auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); + b->BitcastConvertType(param, S32); + auto a = builder.ConstantR1({42.0f, 64.0f}); + builder.Map({a}, b->BuildAndNoteError(), {0}); + + std::vector expected = {0x42280000, 0x42800000}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, ConvertMapToF32) { + ComputationBuilder builder(client_, TestName()); + auto b = builder.CreateSubBuilder("convert"); + auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); + b->BitcastConvertType(param, F32); + auto a = builder.ConstantR1({0x42280000, 0x42800000}); + builder.Map({a}, b->BuildAndNoteError(), {0}); + + std::vector expected = {42.0f, 64.0f}; + ComputeAndCompareR1(&builder, expected, {}); +} + +// Regression test for b/31758660. 
When ReshapeMover transforms +// input -> reshape -> convert +// to +// input -> convert -> reshape +// the new convert should have the same element type as the old convert. +TEST_F(BitcastConvertTest, ConvertReshape) { + ComputationBuilder builder(client_, TestName()); + auto input = builder.ConstantR1({0x42280000}); + auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); + builder.BitcastConvertType(reshape, F32); + + ComputeAndCompareR0(&builder, 42.0f, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 36d10fff5400b78fa3ea9a03f6b9cd73059f1427..610302ac1256a57db6ed6e18016a4136973e3891 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -29,6 +29,7 @@ def xla_test(name, deps, xla_test_library_deps=[], backends=[], + blacklisted_backends=[], args=[], tags=[], copts=[], @@ -92,17 +93,24 @@ def xla_test(name, backends: A list of backends to generate tests for. Supported values: "cpu", "cpu_parallel", "gpu". If this list is empty, the test will be generated for all supported backends. + blacklisted_backends: A list of backends to NOT generate tests for. args: Test arguments for the target. tags: Tags for the target. - backend_args: A dict mapping backend name to list of additional args to - use for that target. + copts: Additional copts to pass to the build. + data: Additional data to pass to the build. backend_tags: A dict mapping backend name to list of additional tags to use for that target. + backend_args: A dict mapping backend name to list of additional args to + use for that target. + **kwargs: Additional keyword arguments to pass to native.cc_test. 
""" test_names = [] if not backends: backends = all_backends + backends = [backend for backend in backends + if backend not in blacklisted_backends] + native.cc_library( name="%s_lib" % name, srcs=srcs, @@ -248,5 +256,6 @@ def generate_backend_test_macros(backends=[]): deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", "//tensorflow/core:test", ]) diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 065bce7e3146c93568bbce2b0e7e23ddddc4ea31..50bf185936808fbd9c49f7fbd5ab0c0b4a76504b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -262,20 +262,39 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( expected.shape().element_type() == PRED) << ShapeUtil::HumanString(expected.shape()); } + // We allow using a float expected literal for a bfloat16 output. In this + // case, we need to convert the expected literal to bfloat16. 
+ const Literal* expected_ptr = &expected; + std::unique_ptr converted_expected; + Shape layout_shape; + if (use_bfloat16_) { + converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + expected_ptr = converted_expected.get(); + if (shape_with_layout != nullptr) { + layout_shape = *shape_with_layout; + ShapeUtil::ForEachMutableSubshape( + &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); + shape_with_layout = &layout_shape; + } + } auto expect_equal = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectEqual(expected, actual, error_message); + LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message); }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( - computation, expected, arguments, expect_equal); + computation, *expected_ptr, arguments, expect_equal); } if (execution_options_.debug_options().xla_test_all_input_layouts()) { return ComputeAndCompareLiteralWithAllInputLayouts( - computation, expected, arguments, expect_equal, shape_with_layout); + computation, *expected_ptr, arguments, expect_equal, shape_with_layout); } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectEqual(expected, *actual); + LiteralTestUtil::ExpectEqual(*expected_ptr, *actual); return tensorflow::Status::OK(); } @@ -286,20 +305,39 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())); TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + // We allow using a float expected literal for a bfloat16 output. In this + // case, we need to convert the expected literal to bfloat16. 
+ const Literal* expected_ptr = &expected; + std::unique_ptr converted_expected; + Shape layout_shape; + if (use_bfloat16_) { + converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + expected_ptr = converted_expected.get(); + if (shape_with_layout != nullptr) { + layout_shape = *shape_with_layout; + ShapeUtil::ForEachMutableSubshape( + &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); + shape_with_layout = &layout_shape; + } + } auto expect_near = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectNear(expected, actual, error, error_message); + LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message); }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { - return ComputeAndCompareLiteralWithAllOutputLayouts(computation, expected, - arguments, expect_near); + return ComputeAndCompareLiteralWithAllOutputLayouts( + computation, *expected_ptr, arguments, expect_near); } if (execution_options_.debug_options().xla_test_all_input_layouts()) { return ComputeAndCompareLiteralWithAllInputLayouts( - computation, expected, arguments, expect_near, shape_with_layout); + computation, *expected_ptr, arguments, expect_near, shape_with_layout); } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectNear(expected, *actual, error); + LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error); return tensorflow::Status::OK(); } @@ -346,10 +384,67 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( LiteralTestUtil::ExpectNearTuple(expected, *actual, error); } +void ClientLibraryTestBase::ComputeAndCompare( + ComputationBuilder* builder, const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments) { + auto status_or_data = ComputeValueAndReference(builder, operand, arguments); + EXPECT_IS_OK(status_or_data); + if 
(!status_or_data.ok()) { + return; + } + std::unique_ptr reference, result; + std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); + LiteralTestUtil::ExpectEqual(*reference, *result); +} + +void ClientLibraryTestBase::ComputeAndCompare( + ComputationBuilder* builder, const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + auto status_or_data = ComputeValueAndReference(builder, operand, arguments); + EXPECT_IS_OK(status_or_data); + if (!status_or_data.ok()) { + return; + } + std::unique_ptr reference, result; + std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); + LiteralTestUtil::ExpectNear(*reference, *result, error); +} + +StatusOr, std::unique_ptr>> +ClientLibraryTestBase::ComputeValueAndReference( + ComputationBuilder* builder, const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments) { + // Transfer the arguments to the executor service. We put the unique_ptr's + // into a vector to keep the data alive on the service until the end of this + // function. + std::vector> argument_data; + for (const auto& arg : arguments) { + TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg)); + argument_data.push_back(std::move(data)); + } + + // Create raw pointers to the GlobalData for the rest of the call stack. 
+ std::vector argument_data_ptr; + std::transform( + argument_data.begin(), argument_data.end(), + std::back_inserter(argument_data_ptr), + [](const std::unique_ptr& data) { return data.get(); }); + + TF_ASSIGN_OR_RETURN( + auto reference, + builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments)); + TF_ASSIGN_OR_RETURN(auto result, + ExecuteAndTransfer(builder, argument_data_ptr)); + return std::make_pair(std::move(reference), std::move(result)); +} + Computation ClientLibraryTestBase::CreateScalarRelu() { ComputationBuilder builder(client_, "relu"); - auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value"); - auto zero = builder.ConstantR0(0.0); + auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto z_value = builder.Parameter(0, shape, "z_value"); + auto zero = use_bfloat16_ + ? builder.ConstantR0(static_cast(0.0f)) + : builder.ConstantR0(0.0f); builder.Max(z_value, zero); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -358,8 +453,9 @@ Computation ClientLibraryTestBase::CreateScalarRelu() { Computation ClientLibraryTestBase::CreateScalarMax() { ComputationBuilder builder(client_, "max"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? 
BF16 : F32, {}); + auto x = builder.Parameter(0, shape, "x"); + auto y = builder.Parameter(1, shape, "y"); builder.Max(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -368,11 +464,12 @@ Computation ClientLibraryTestBase::CreateScalarMax() { Computation ClientLibraryTestBase::CreateScalarReluSensitivity() { ComputationBuilder builder(client_, "relu_sensitivity"); - auto activation = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "activation"); - auto backprop = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "backprop"); - auto zero = builder.ConstantR0(0.0); + auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto activation = builder.Parameter(0, shape, "activation"); + auto backprop = builder.Parameter(1, shape, "backprop"); + auto zero = use_bfloat16_ + ? builder.ConstantR0(static_cast(0.0f)) + : builder.ConstantR0(0.0f); auto activation_gtz = builder.Gt(activation, zero); builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); @@ -407,4 +504,27 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, return array; } +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + ComputationBuilder* builder, ComputationDataHandle* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr converted_literal; + if (use_bfloat16_) { + converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr data = + client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; +} + +ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( + const Literal& literal, ComputationBuilder* builder) { + return builder->ConstantLiteral( + use_bfloat16_ ? 
*LiteralTestUtil::ConvertF32ToBF16(literal) : literal); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 7cfc276ec19e3b177f87a08e716cb34b7676dd6b..4d0cf8bf71cf22d7c046bb22754a8d4e299ed9db 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -194,7 +194,17 @@ class ClientLibraryTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments); void ComputeAndCompareTuple( ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec abs_error); + tensorflow::gtl::ArraySlice arguments, ErrorSpec error); + + // Convenience method for running a built computation and comparing the result + // with the HloEvaluator. + void ComputeAndCompare(ComputationBuilder* builder, + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments); + void ComputeAndCompare(ComputationBuilder* builder, + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments, + ErrorSpec error); // Create scalar operations for use in reductions. Computation CreateScalarRelu(); @@ -235,51 +245,102 @@ class ClientLibraryTestBase : public ::testing::Test { const int rows, const int cols, const int rows_padded, const int cols_padded); - // Create a parameter instruction that wraps a given value and then stores + // Creates a parameter instruction, transfers the literal for the parameter to + // server, then stores into "data_handle" the global handle for that + // parameter. When the use_bfloat16 flag is set but the literal has F32 + // elements, the literal will be converted to BF16 before being transferred. 
+ std::unique_ptr CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + ComputationBuilder* builder, ComputationDataHandle* data_handle); + + // Creates a constant instruction with the given literal. When the + // use_bfloat16 flag is set but the literal has F32 elements, the elements + // will be converted to BF16s. + ComputationDataHandle CreateConstantFromLiteral(const Literal& literal, + ComputationBuilder* builder); + + // Creates a constant instruction with the given array. When the use_bfloat16 + // flag is set but the array has float elements, the elements will be + // converted to bfloat16s. + template + ComputationDataHandle CreateConstantFromArray(const Array& array, + ComputationBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); + } + + // Same as CreateConstantFromArray, but for scalars. + template + ComputationDataHandle CreateConstantFromScalar(NativeT value, + ComputationBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateR0(value), + builder); + } + + // Creates a parameter instruction that wraps a given value and then stores // into "data_handle" the global handle for that parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template std::unique_ptr CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); - // Create a parameter instruction that wraps the given values and then stores + // Creates a parameter instruction that wraps the given values and then stores // into "data_handle" the global handle for that parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. 
+ // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template std::unique_ptr CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); - // Create a parameter instruction that wraps the given constant array + // Creates a parameter instruction that wraps the given constant array // "array_2d" and then stores to "data_handle" the global handle for that // parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template std::unique_ptr CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); - // Create a parameter instruction that wraps the given constant array + // Creates a parameter instruction that wraps the given constant array // "array_3d" and then stores to "data_handle" the global handle for that // parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template std::unique_ptr CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); + // Getter and setter for the use_bfloat16 flag, which indicates whether to run + // tests with all float-type input/output converted to bfloat16. + bool use_bfloat16() const { return use_bfloat16_; } + void set_use_bfloat16(bool value) { use_bfloat16_ = value; } + + // The float type used in this test, BF16 or F32 according to use_bfloat16. + PrimitiveType FloatType() const { return use_bfloat16_ ? 
BF16 : F32; } + Client* client_; ExecutionOptions execution_options_; @@ -298,6 +359,17 @@ class ClientLibraryTestBase : public ::testing::Test { const std::function& verify_output, const Shape* output_with_layout = nullptr); + + // Executes the computation and calculates the expected reference value using + // the HloEvaluator. Returns two literal in the order of (expected, actual). + StatusOr, std::unique_ptr>> + ComputeValueAndReference(ComputationBuilder* builder, + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments); + + // Whether to run tests with all float-type input/output converted to + // bfloat16. + bool use_bfloat16_ = false; }; template @@ -315,8 +387,10 @@ void ClientLibraryTestBase::ComputeAndCompareR0( ComputationBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || - std::is_same::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same::value || + std::is_same::value || + std::is_same::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = Literal::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -338,8 +412,10 @@ void ClientLibraryTestBase::ComputeAndCompareR1( ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || - std::is_same::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same::value || + std::is_same::value || + std::is_same::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -362,6 +438,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { 
static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -386,6 +463,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -410,6 +488,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -423,6 +502,9 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr literal = Literal::CreateR0(value); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -435,6 +517,9 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr literal = Literal::CreateR1(values); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -447,6 +532,9 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const 
string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -459,6 +547,9 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -469,8 +560,7 @@ template std::vector ClientLibraryTestBase::CreatePseudorandomR1( const int width, NativeT min_value, NativeT max_value, uint32 seed) { std::vector result(width); - test_utils::PseudorandomGenerator generator(min_value, max_value, - seed); + PseudorandomGenerator generator(min_value, max_value, seed); for (int i = 0; i < width; ++i) { result[i] = generator.get(); } @@ -482,8 +572,7 @@ std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( const int rows, const int cols, NativeT min_value, NativeT max_value, uint32 seed) { auto result = MakeUnique>(rows, cols); - test_utils::PseudorandomGenerator generator(min_value, max_value, - seed); + PseudorandomGenerator generator(min_value, max_value, seed); for (int y = 0; y < rows; ++y) { for (int x = 0; x < cols; ++x) { (*result)(y, x) = generator.get(); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 0853feeebd6f7a249cf767e1f8a63675d4bddd27..8853ed9e5780672d4006c326291767b8b5253f56 100644 --- 
a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -20,10 +20,12 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -42,26 +44,26 @@ TEST_F(ClientTest, ExecuteWithLayout) { for (const std::vector& transfer_layout : layouts) { b.Add(b.ConstantR2({{1, 2}, {3, 4}}), b.ConstantR2({{10, 20}, {30, 40}})); - auto computation = b.Build(); - ASSERT_TRUE(computation.ok()) << computation.status(); + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, execute_layout); - std::unique_ptr data = - client_->Execute(computation.ValueOrDie(), {}, &execution_options) - .ConsumeValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr data, + client_->Execute(computation, {}, &execution_options)); std::unique_ptr expected_literal = - test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, - transfer_layout); + Literal::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); - auto computed = client_->Transfer(*data, &expected_literal->shape()); + TF_ASSERT_OK_AND_ASSIGN( + auto computed, client_->Transfer(*data, &expected_literal->shape())); - LiteralTestUtil::AssertEqualShapesAndLayouts( - expected_literal->shape(), 
computed.ValueOrDie()->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), + computed->shape()); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed); } } } @@ -72,8 +74,7 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { b.Tuple({b.ConstantR2({{1, 2}, {3, 4}}), b.ConstantR2({{10, 20}, {30, 40}})}); - auto computation = b.Build(); - ASSERT_TRUE(computation.ok()) << computation.status(); + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); ExecutionOptions execution_options = execution_options_; // Create a result shape with one element column major and the other row @@ -85,10 +86,9 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{1, 0})}); - auto result = - client_ - ->ExecuteAndTransfer(computation.ValueOrDie(), {}, &execution_options) - .ConsumeValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, result->tuple_literals(0)); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, @@ -107,5 +107,42 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { /*minor_to_major=*/{1, 0}))); } +TEST_F(ClientTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) { + Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg; + Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr const_arg, + client_->TransferToServer(*Literal::CreateR2({{5, 6}, {7, 8}}))); + + ComputationBuilder b(client_, TestName() + ".add"); + b.Add(b.Parameter(0, shape, "param_0"), + b.ConstantR2({{1, 2}, {3, 4}})); + TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build()); + + // We can't really test parallel execution on CPU since all of the cores in a + // CPU are presented as a single device. 
So for now we test "parallel" + // execution on a single device. + std::vector computation_instances; + TF_ASSERT_OK_AND_ASSIGN(std::vector devices, + client_->GetDeviceHandles(1)); + ASSERT_EQ(devices.size(), 1); + + ExecutionOptions options = execution_options_; + *options.add_device_handles() = devices[0]; + computation_instances.push_back(Client::ComputationInstance( + add_with_one_arg, {const_arg.get()}, options, nullptr)); + + TF_ASSERT_OK_AND_ASSIGN(auto results, + client_->ExecuteParallel(computation_instances)); + auto expected_result = Literal::CreateR2({{6, 8}, {10, 12}}); + + TF_ASSERT_OK_AND_ASSIGN( + auto result_literal, + client_->Transfer(*results[0], &expected_result->shape())); + + LiteralTestUtil::ExpectEqual(*expected_result, *result_literal); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index 43ea7f6019415a171123ee0315533b8a3b1ff984..e472408dcf7ed5fec74e886fd0092ce47ee2e7eb 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -19,8 +19,11 @@ namespace xla { StatusOr> CodegenTestBase::CompileToExecutable( std::unique_ptr hlo_module) { - return backend().compiler()->Compile(std::move(hlo_module), - backend().default_stream_executor()); + TF_ASSIGN_OR_RETURN(hlo_module, backend().compiler()->RunHloPasses( + std::move(hlo_module), + backend().default_stream_executor())); + return backend().compiler()->RunBackend(std::move(hlo_module), + backend().default_stream_executor()); } StatusOr> diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 707e439245c29a1ddf80bfd9205aa14b0d4765f6..0f780fa87ef98fd5c48726ef83fa8efc1e90fbf7 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -138,13 +138,13 @@ 
XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { // layouts. Use these arrays as parameters to a simple computation. If the // layout of the array changes then computation should be recompiled (cache // miss). - auto rowmaj_array = test_utils::CreateR2LiteralWithLayout( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0}); + auto rowmaj_array = Literal::CreateR2WithLayout( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); auto rowmaj_handle = client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); - auto colmaj_array = test_utils::CreateR2LiteralWithLayout( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1}); + auto colmaj_array = Literal::CreateR2WithLayout( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); auto colmaj_handle = client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index d423c78476dde18d209b5efac9e8f77da41bfeb4..5226a78386824a94572d3e5cc3329677108a910a 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -264,8 +264,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) { ASSERT_TRUE(computed.ok()) << computed.status(); std::unique_ptr expected_literal = - test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, - layout); + Literal::CreateR2WithLayout({{11, 22}, {33, 44}}, + LayoutUtil::MakeLayout(layout)); LiteralTestUtil::AssertEqualShapesAndLayouts( expected_literal->shape(), computed.ValueOrDie()->shape()); LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cbfacaea53952b02596eb3e84b13a5749335651d --- /dev/null +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -0,0 
+1,238 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class ConditionalOpTest : public ClientLibraryTestBase { + protected: + Computation CreateR0F32ConstantComputation(float value) { + ComputationBuilder builder(client_, "Constant"); + builder.Parameter(0, empty_tuple_, "tuple"); + builder.ConstantR0(value); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0F32IdentityComputation() { + ComputationBuilder builder(client_, "Identity"); + builder.Parameter(0, r0f32_, "x"); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0F32CeilComputation() { + ComputationBuilder builder(client_, "Ceil"); + auto param = builder.Parameter(0, r0f32_, "param"); + builder.Ceil(param); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0F32FloorComputation() { + ComputationBuilder 
builder(client_, "Ceil"); + auto param = builder.Parameter(0, r0f32_, "param"); + builder.Floor(param); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateAddTupleComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + builder.Add(x, y); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateAddR0Computation() { + return CreateAddTupleComputation("AddR0", tuple_2_r0f32_); + } + + Computation CreateAddR1Computation() { + return CreateAddTupleComputation("AddR1", tuple_2_r1s2f32_); + } + + Computation CreateSubTupleComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + builder.Sub(x, y); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateSubR0Computation() { + return CreateSubTupleComputation("SubR0", tuple_2_r0f32_); + } + + Computation CreateSubR1Computation() { + return CreateSubTupleComputation("SubR1", tuple_2_r1s2f32_); + } + + Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); + Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}); + Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(F32, {2})}); + Shape empty_tuple_ = ShapeUtil::MakeTupleShape({}); + ErrorSpec error_spec_{0.001}; +}; + +// Test true and false 
computations that do not take any parameters. +XLA_TEST_F(ConditionalOpTest, Parameters0) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operands = builder.Tuple({}); + auto true_computation = CreateR0F32ConstantComputation(56.0f); + auto false_computation = CreateR0F32ConstantComputation(12.0f); + auto result = builder.Conditional(pred, operands, true_computation, operands, + false_computation); + + ComputeAndCompareR0(&builder, 56.0f, {}, error_spec_); +} + +// Test true and false computations that take in 1 parameter. +XLA_TEST_F(ConditionalOpTest, Parameters1) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto identity = CreateR0F32IdentityComputation(); + auto result = + builder.Conditional(pred, operand1, identity, operand2, identity); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 parameters and predicate is +// true. +XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + auto result = builder.Conditional(pred, operands, CreateAddR0Computation(), + operands, CreateSubR0Computation()); + + ComputeAndCompareR0(&builder, 68.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 parameters and predicate is +// false. 
+XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + auto result = builder.Conditional(pred, operands, CreateAddR0Computation(), + operands, CreateSubR0Computation()); + + ComputeAndCompareR0(&builder, 44.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 array parameters and +// predicate is true. +XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR1({24.0f, 56.0f}); + auto operand2 = builder.ConstantR1({10.0f, 11.0f}); + auto operands = builder.Tuple({operand1, operand2}); + auto result = builder.Conditional(pred, operands, CreateAddR1Computation(), + operands, CreateSubR1Computation()); + + ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {}, error_spec_); +} + +// Test true and false computations that take in 2 array parameters and +// predicate is false. +XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR1({24.0f, 56.0f}); + auto operand2 = builder.ConstantR1({10.0f, 11.0f}); + auto operands = builder.Tuple({operand1, operand2}); + auto result = builder.Conditional(pred, operands, CreateAddR1Computation(), + operands, CreateSubR1Computation()); + + ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {}, error_spec_); +} + +// Test the case where one conditional is nested within another. 
+XLA_TEST_F(ConditionalOpTest, NestedConditionals) { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); + ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); + auto pred_cond = inner_builder.GetTupleElement(param0, 0); + auto true_operand = inner_builder.GetTupleElement(param0, 1); + auto false_operand = inner_builder.GetTupleElement(param0, 2); + inner_builder.Conditional(pred_cond, true_operand, + CreateR0F32CeilComputation(), false_operand, + CreateR0F32FloorComputation()); + auto inner_builder_result = inner_builder.Build(); + + ComputationBuilder builder(client_, TestName()); + auto pred1 = builder.ConstantR0(true); + auto pred2 = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(1.1f); + auto operand2 = builder.ConstantR0(12.2f); + auto operand3 = builder.ConstantR0(43.3f); + auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); + builder.Conditional(pred1, tuple_operand, + inner_builder_result.ConsumeValueOrDie(), operand3, + CreateR0F32IdentityComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test a mismatch in the shape of the true operand and true computation. 
+XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + builder.Conditional(pred, operands, CreateAddR1Computation(), operands, + CreateSubR0Computation()); + + auto result = builder.Build(); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr("true_operand must match the shape of the " + "only parameter of true_computation")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index b0a63bccbb93f226175beff2e30e2a243fdca1d3..896b34fb6e2762c14bd9ec2bf1ba13c548d4cf60 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -39,8 +39,8 @@ class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {}; // Tests the convolution operation with invalid input dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 2, 0, 2, 2, 3, 0, 1, 2, - 3); + ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3, 0, + 1, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("input are not unique")); @@ -49,13 +49,23 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { // Tests the convolution operation with invalid weight dimension numbers. 
TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 1, 0, 1, 2, 3, 2, 3, 2, - 3); + ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 1, 2, 3, 0, + 2, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("weight are not unique")); } +// Tests the convolution operation with invalid output dimension numbers. +TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) { + auto dimension_numbers_status = + ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 2, 2, 3, 0, + 1, 2, 3); + ASSERT_FALSE(dimension_numbers_status.ok()); + ASSERT_THAT(dimension_numbers_status.status().error_message(), + ::testing::HasSubstr("output are not unique")); +} + XLA_TEST_F(ConvolutionDimensionNumbersTest, TwoConvsWithDifferentDimensionNumbers) { auto input_array = MakeUnique>(2, 3, 5, 5); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 0cc2e5fb7e655884f3334426a684dd3ce00d4052..2924c08615fa706bb19addf04bf58e1d5dd5a659 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -82,177 +82,127 @@ XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) { ComputationBuilder builder(client_, TestName()); auto lhs = builder.ConstantR4FromArray4D(*alhs); auto rhs = builder.ConstantR4FromArray4D(*arhs); - builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); + auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(*alhs, *arhs, {1, 1}, Padding::kValid); - - ComputeAndCompareR4(&builder, *aexpected, {}, error_spec_); + ComputeAndCompare(&builder, conv, {}, error_spec_); } TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) { ComputationBuilder builder(client_, 
TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); - } - - Array4D input(1, 1, 1, 2); - input.FillWithYX(Array2D({ + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D input_data(1, 1, 1, 2); + input_data.FillWithYX(Array2D({ {1, 2}, })); - Array4D filter(1, 1, 1, 2); - filter.FillWithYX(Array2D({ + Array4D filter_data(1, 1, 1, 2); + filter_data.FillWithYX(Array2D({ {5, 6}, })); - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR4(&builder, *aexpected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + ComputeAndCompare(&builder, conv, + {*Literal::CreateFromArray(input_data), + *Literal::CreateFromArray(filter_data)}, + error_spec_); } // Tests valid padding for 2D convolution in raster space. 
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) { ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); - } + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); - Array4D input(1, 1, 4, 4); + Array4D input_data(1, 1, 4, 4); // clang-format off - input.FillWithYX(Array2D({ + input_data.FillWithYX(Array2D({ {1, 2, 3, 4 }, {5, 6, 7, 8 }, {9, 10, 11, 12}, {13, 14, 15, 16}, })); // clang-format on - Array4D filter(1, 1, 2, 2); + Array4D filter_data(1, 1, 2, 2); // clang-format off - filter.FillWithYX(Array2D({ + filter_data.FillWithYX(Array2D({ {5, 6}, {7, 8}, })); // clang-format on - - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR4(&builder, *aexpected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + ComputeAndCompare(&builder, conv, + {*Literal::CreateFromArray(input_data), + *Literal::CreateFromArray(filter_data)}, + error_spec_); } // Tests same padding for 2D convolution in raster space. 
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) { ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kSame); - } + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); - Array4D input(1, 1, 4, 4); + Array4D input_data(1, 1, 4, 4); // clang-format off - input.FillWithYX(Array2D({ + input_data.FillWithYX(Array2D({ {1, 2, 3, 4 }, {5, 6, 7, 8 }, {9, 10, 11, 12}, {13, 14, 15, 16}, })); // clang-format on - Array4D filter(1, 1, 2, 2); + Array4D filter_data(1, 1, 2, 2); // clang-format off - filter.FillWithYX(Array2D({ + filter_data.FillWithYX(Array2D({ {5, 6}, {7, 8}, })); // clang-format on - - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR4(&builder, *aexpected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + ComputeAndCompare(&builder, conv, + {*Literal::CreateFromArray(input_data), + *Literal::CreateFromArray(filter_data)}, + error_spec_); } // Tests same padding for 2D convolution in raster space with an odd sized // kernel. 
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) { ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kSame); - } + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); - Array4D input(1, 1, 4, 4); + Array4D input_data(1, 1, 4, 4); // clang-format off - input.FillWithYX(Array2D({ + input_data.FillWithYX(Array2D({ {1, 2, 3, 4 }, {5, 6, 7, 8 }, {9, 10, 11, 12}, {13, 14, 15, 16}, })); // clang-format on - Array4D filter(1, 1, 3, 3); + Array4D filter_data(1, 1, 3, 3); // clang-format off - filter.FillWithYX(Array2D({ + filter_data.FillWithYX(Array2D({ { 5, 6, 7}, { 8, 9, 10}, {11, 12, 13}, })); // clang-format on - - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR4(&builder, *aexpected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + ComputeAndCompare(&builder, conv, + {*Literal::CreateFromArray(input_data), + *Literal::CreateFromArray(filter_data)}, + error_spec_); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { @@ -420,9 +370,12 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(0); 
dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); dnums.set_input_feature_dimension(4); dnums.set_output_feature_dimension(4); dnums.add_kernel_spatial_dimensions(0); @@ -473,8 +426,10 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); @@ -508,6 +463,54 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) { error_spec_); } +// Test fixture to run convolution tests with and without convolution +// canonicalization enabled. 
+class ConvolveWithAndWithoutCanonicalization + : public ConvolutionTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, + DISABLED_ON_GPU(Convolve2D_NoSpatialDims)) { + if (GetParam()) { + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "convolution-canonicalization"); + } + ComputationBuilder builder(client_, TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10}); + + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + + ConvolutionDimensionNumbers dnums; + dnums.set_input_feature_dimension(0); + dnums.set_input_batch_dimension(1); + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.set_output_batch_dimension(0); + dnums.set_output_feature_dimension(1); + auto conv = builder.ConvWithGeneralDimensions(input, filter, {}, + Padding::kValid, dnums); + + Array2D param0(4, 29); + param0.FillUnique(); + + Array2D param1(4, 10); + param1.FillUnique(); + + Array2D expected_result(29, 10); + expected_result.Fill(0); + + ComputeAndCompare( + &builder, conv, + {*Literal::CreateFromArray(param0), *Literal::CreateFromArray(param1)}, + error_spec_); +} + +INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation, + ConvolveWithAndWithoutCanonicalization, + ::testing::Values(true, false)); + struct Convolve1DTestParam { int64 input_feature; int64 output_feature; @@ -540,7 +543,8 @@ XLA_TEST_P(Convolve1D1WindowTest, Convolve1D1Window) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); dnums.set_input_feature_dimension(2); dnums.set_output_feature_dimension(2); dnums.add_kernel_spatial_dimensions(0); diff --git 
a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 9b36e3722b8f8a5d01c426425fdfb0c4b9ae3a16..9c1145def8c11f1222c63adf006102887d49f00d 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -320,9 +320,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); - const Array4D filter_array(1, 1, 3, 3, {10000, 0, 1000, // row 0 - 0, 100, 0, // row 1 - 10, 0, 1}); // row 2 + const Array4D filter_array(1, 1, 3, 3, + {10000, 0, 1000, // row 0 + 0, 100, 0, // row 1 + 10, 0, 1}); // row 2 auto filter = builder.ConstantR4FromArray4D(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kSame); @@ -472,7 +473,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { builder.Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = { - 23, 33, 43, + 23, + 33, + 43, }; Array4D expected(bs, 1, 1, 1, expected_data); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -669,10 +672,11 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 1, 3, 4, input_data); - Array4D filter_array(1, 1, 4, 3, {100, 10, 1, // - 200, 20, 2, // - 300, 30, 3, // - 400, 40, 4}); + Array4D filter_array(1, 1, 4, 3, + {100, 10, 1, // + 200, 20, 2, // + 300, 30, 3, // + 400, 40, 4}); auto input = builder.ConstantR4FromArray4D(input_array); auto filter = builder.ConstantR4FromArray4D(filter_array); builder.ConvGeneralDilated( @@ -681,9 +685,10 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { /*rhs_dilation=*/{}, ComputationBuilder::CreateDefaultConvDimensionNumbers()); - Array4D expected(1, 1, 3, 5, {204, 40, 406, 60, 608, // - 1518, 180, 1821, 210, 2124, // - 4146, 460, 4651, 510, 5156}); + Array4D 
expected(1, 1, 3, 5, + {204, 40, 406, 60, 608, // + 1518, 180, 1821, 210, 2124, // + 4146, 460, 4651, 510, 5156}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } @@ -926,7 +931,8 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) { ComputeAndCompareR4(&builder, *expected, {}, error_spec_); } -XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x16x16_Filter16x16x16x16) { +XLA_TEST_F(ConvolutionVariantsTest, + RandomData_Input16x16x16x16_Filter16x16x16x16) { constexpr int bs = 16; constexpr int iz = 16; constexpr int oz = 16; @@ -976,8 +982,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { // NHWC input format. dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); @@ -1018,8 +1026,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { // NHWC input format. dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); @@ -1060,8 +1070,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { // NHWC input format. 
dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); @@ -1099,8 +1111,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { // NHWC input format. dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); @@ -1131,7 +1145,8 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { // Conv([1,2,3], Reverse([5,6]), padding_low=1) // into // BackwardInputConv([1,2,3], [5,6], padding_low=0, padding_high=1) -XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) { +XLA_TEST_F(ConvolutionVariantsTest, + BackwardInputLowPaddingLessThanHighPadding) { ComputationBuilder builder(client_, TestName()); auto gradients = builder.ConstantR4FromArray4D( @@ -1149,7 +1164,8 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) // Conv([1], Reverse([1,10,100]), padding_high=3, base_dilation=3) // into // BackwardInputConv([1], [1,10,100], stride=3, padding=(2,1)) -XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) { +XLA_TEST_F(ConvolutionVariantsTest, + BackwardInputLowPaddingGreaterThanHighPadding) { ComputationBuilder builder(client_, TestName()); auto gradients = builder.ConstantR4FromArray4D( @@ -1206,7 +1222,8 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { 
ComputeAndCompareR4(&builder, {{{{12, 23, 30, 0}}}}, {}, error_spec_); } -XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) { +XLA_TEST_F(ConvolutionVariantsTest, + BackwardFilterLowPaddingLessThanHighPadding) { ComputationBuilder builder(client_, TestName()); // activations: 1,2,3,4 ---pad--> 0,1,2,3,4,0,0 @@ -1230,7 +1247,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) } XLA_TEST_F(ConvolutionVariantsTest, - BackwardFilterLowPaddingGreaterThanHighPadding) { + BackwardFilterLowPaddingGreaterThanHighPadding) { ComputationBuilder builder(client_, TestName()); // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4 diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index cf089d748dcd4f5db637ff9087c5fbc504c82572..bb7af4c4b837198dccad116367bccbfb7e134901 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -53,6 +53,8 @@ class DotOperationTest : public ClientLibraryTestBase { bool rhs_row_major = false); void TestMatrixDot(int M, int K, int N, bool lhs_row_major = false, bool rhs_row_major = false); + void TestMatrixDotWithAdd(int M, int K, int N, bool dot_lhs_row_major, + bool dot_rhs_row_major, bool addend_row_major); }; XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) { @@ -177,15 +179,15 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + ->TransferToServer(*Literal::CreateR2WithLayout( {{1.0, 2.0}, {3.0, -4.0}}, - MinorToMajorForIsRowMajor(lhs_row_major))) + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + ->TransferToServer(*Literal::CreateR2WithLayout( {{1.0, 6.0}, {7.0, -4.0}}, - 
MinorToMajorForIsRowMajor(rhs_row_major))) + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -229,6 +231,54 @@ void DotOperationTest::TestMatrixDot(int M, int K, int N, bool lhs_row_major, ErrorSpec(0.3, 3e-3)); } +void DotOperationTest::TestMatrixDotWithAdd(int M, int K, int N, + bool dot_lhs_row_major, + bool dot_rhs_row_major, + bool addend_row_major) { + std::unique_ptr> dot_lhs_data = + MakeLinspaceArray2D(0.0, 1.0, M, K); + std::unique_ptr dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout( + *dot_lhs_data, + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(dot_lhs_row_major))); + auto dot_lhs_handle = + client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); + + std::unique_ptr> dot_rhs_data = + MakeLinspaceArray2D(0.0, 1.0, K, N); + std::unique_ptr dot_rhs_lit = Literal::CreateR2FromArray2DWithLayout( + *dot_rhs_data, + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(dot_rhs_row_major))); + auto dot_rhs_handle = + client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); + + std::unique_ptr> addend_data = + MakeLinspaceArray2D(0.0, 1.0, M, N); + std::unique_ptr addend_lit = Literal::CreateR2FromArray2DWithLayout( + *addend_data, + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(addend_row_major))); + auto addend_handle = + client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); + auto result = builder.Add( + builder.Dot(builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {M, K}), + "dot_lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {K, N}), + "dot_rhs")), + builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {M, N}), "addend")); + + std::unique_ptr> expected = ReferenceUtil::ApplyElementwise2D( + std::plus(), + *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data), + *addend_data); + + ComputeAndCompareR2( + 
&builder, *expected, + {dot_lhs_handle.get(), dot_rhs_handle.get(), addend_handle.get()}, + ErrorSpec(0.3, 3e-3)); +} + XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorTF) { TestMatrixDot(12, 117, 7, true, false); } @@ -277,10 +327,154 @@ XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFF) { TestMatrixDot(260, 3, 520, false, false); } +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x8) { + TestMatrixDot(1, 8, 8, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x130x8) { + TestMatrixDot(1, 130, 8, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x130) { + TestMatrixDot(1, 8, 130, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x290x130) { + TestMatrixDot(1, 290, 130, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_2x1x1) { + TestMatrixDot(2, 1, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_8x8x1) { + TestMatrixDot(8, 8, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x1x1) { + TestMatrixDot(16, 1, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x3x1) { + TestMatrixDot(16, 3, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_3x3x1) { + TestMatrixDot(3, 3, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_29x29x1) { + TestMatrixDot(29, 29, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x2) { + TestMatrixDot(1, 8, 2, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x2x8) { + TestMatrixDot(1, 2, 8, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1) { + TestMatrixDot(259, 258, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1_FT) { + TestMatrixDot(259, 258, 1, false, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_1x8x8) { + TestMatrixDotWithAdd(1, 8, 8, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, 
/*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_1x130x8) { + TestMatrixDotWithAdd(1, 130, 8, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_1x8x130) { + TestMatrixDotWithAdd(1, 8, 130, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_1x290x130) { + TestMatrixDotWithAdd(1, 290, 130, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_2x1x1) { + TestMatrixDotWithAdd(2, 1, 1, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_8x8x1) { + TestMatrixDotWithAdd(8, 8, 1, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_16x1x1) { + TestMatrixDotWithAdd(16, 1, 1, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_16x3x1) { + TestMatrixDotWithAdd(16, 3, 1, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_3x3x1) { + TestMatrixDotWithAdd(3, 3, 1, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_29x29x1) { + TestMatrixDotWithAdd(29, 29, 1, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_1x8x2) { + TestMatrixDotWithAdd(1, 8, 2, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_1x2x8) 
{ + TestMatrixDotWithAdd(1, 2, 8, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_259x258x1) { + TestMatrixDotWithAdd(259, 258, 1, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_259x258x1_FTT) { + TestMatrixDotWithAdd(259, 258, 1, /*dot_lhs_row_major=*/false, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_259x258x1_FFT) { + TestMatrixDotWithAdd(259, 258, 1, /*dot_lhs_row_major=*/false, + /*dot_rhs_row_major=*/false, /*addend_row_major=*/true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_259x258x1_FFF) { + TestMatrixDotWithAdd(259, 258, 1, /*dot_lhs_row_major=*/false, + /*dot_rhs_row_major=*/false, /*addend_row_major=*/false); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_259x258x1_TFF) { + TestMatrixDotWithAdd(259, 258, 1, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/false, /*addend_row_major=*/false); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotWithAddF32_259x258x1_TTF) { + TestMatrixDotWithAdd(259, 258, 1, /*dot_lhs_row_major=*/true, + /*dot_rhs_row_major=*/true, /*addend_row_major=*/false); +} + XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { - constexpr bool kLhsRowMajor = false; - constexpr bool kRhsRowMajor = false; - TestSquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestSquareMatrixDot(false, false); } XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { @@ -291,10 +485,24 @@ XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) { TestSquareMatrixDot(true, false); } -TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { - constexpr bool kLhsRowMajor = true; - constexpr bool kRhsRowMajor = true; - TestSquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +XLA_TEST_F(DotOperationTest, 
SquareMatrixDotF32MinorToMajorTT) { + TestSquareMatrixDot(true, true); +} + +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFF) { + TestSquareMatrixDot(false, false); +} + +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFT) { + TestSquareMatrixDot(false, true); +} + +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTF) { + TestSquareMatrixDot(true, false); +} + +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTT) { + TestSquareMatrixDot(true, true); } XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) { @@ -306,15 +514,15 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + ->TransferToServer(*Literal::CreateR2WithLayout( {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}, - MinorToMajorForIsRowMajor(lhs_row_major))) + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + ->TransferToServer(*Literal::CreateR2WithLayout( {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}, - MinorToMajorForIsRowMajor(rhs_row_major))) + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -330,35 +538,64 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) { - constexpr bool kLhsRowMajor = false; - constexpr bool kRhsRowMajor = false; - TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot(false, false); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) { - constexpr bool kLhsRowMajor = false; - constexpr bool kRhsRowMajor = true; - TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot(false, true); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) { - constexpr bool 
kLhsRowMajor = true; - constexpr bool kRhsRowMajor = false; - TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot(true, false); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { - constexpr bool kLhsRowMajor = true; - constexpr bool kRhsRowMajor = true; - TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot(true, true); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) { TestNonsquareMatrixDot(); } -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) { - TestNonsquareMatrixDot(); +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFF) { + TestNonsquareMatrixDot(false, false); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFT) { + TestNonsquareMatrixDot(false, true); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTF) { + TestNonsquareMatrixDot(true, false); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTT) { + TestNonsquareMatrixDot(true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorC64) { + auto lhs_handle = + client_ + ->TransferToServer(*Literal::CreateR2WithLayout( + {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) + .ConsumeValueOrDie(); + auto rhs_handle = + client_ + ->TransferToServer(*Literal::CreateR2WithLayout( + {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, + LayoutUtil::MakeLayout({1, 0}))) + .ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); + auto result = builder.Dot( + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); + + Array2D expected({{30.0, -2.0}}); + + ComputeAndCompareR2( + &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } XLA_TEST_F(DotOperationTest, ConcurrentMatMul) { @@ -476,5 +713,95 @@ TEST_F(DotOperationTest, TransposeFolding) { } } +TEST_F(DotOperationTest, 
DotOfConcatOptimizationWithConstLHS) { + auto prim_type = primitive_util::NativeToPrimitiveType(); + + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + + ComputationBuilder builder(client_, TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), + "rhs_arg_0"); + auto rhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), + "rhs_arg_1"); + auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}), + "rhs_arg_2"); + auto result = builder.Dot( + lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); + + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0, 2.0}, {3.0, 4.0}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})); + std::unique_ptr> arg_2_value_array( + new Array2D({{1.0, 2.0}})); + + TF_ASSERT_OK_AND_ASSIGN( + auto arg_0_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_1_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_2_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); + + Array2D expected({{53.0, 74.0}, {45.0, 66.0}}); + ComputeAndCompareR2( + &builder, expected, + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); +} + +TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstRHS) { + auto prim_type = primitive_util::NativeToPrimitiveType(); + + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + + ComputationBuilder builder(client_, TestName()); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto lhs_arg_0 = builder.Parameter(0, 
ShapeUtil::MakeShape(prim_type, {2, 2}), + "lhs_arg_0"); + auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 3}), + "lhs_arg_1"); + auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {2, 1}), + "lhs_arg_2"); + auto result = builder.Dot( + builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant); + + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0, 2.0}, {3.0, 4.0}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); + std::unique_ptr> arg_2_value_array( + new Array2D({{1.0}, {2.0}})); + + TF_ASSERT_OK_AND_ASSIGN( + auto arg_0_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_1_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_2_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); + + Array2D expected({{38.0, 36.0}, {93.0, 91.0}}); + ComputeAndCompareR2( + &builder, expected, + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index ab8047c7480f43ba1fd7ca3ad22448e0dd890089..8baaf39e3cf8fa7f6fa4a0224c1297f82e0d92aa 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -559,7 +559,11 @@ void BM_DynamicSlice(int num_iters) { auto computation = builder.Build().ConsumeValueOrDie(); // Initialize and transfer parameter buffer. 
- auto buffer = ScopedShapedBuffer::Allocate(start_indices_shape, &allocator, 0) + auto shape_size_fn = [client](const Shape& shape) { + return client->backend().transfer_manager()->GetByteSizeRequirement(shape); + }; + auto buffer = ScopedShapedBuffer::Allocate(start_indices_shape, &allocator, 0, + shape_size_fn) .ConsumeValueOrDie(); auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index a8f6488996087b57e3121ce2c7de918070950c72..2686afccc216095345dbb7b43e916fbbe7c8ea39 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -770,8 +770,6 @@ void BM_ParallelFusion(int num_iters) { auto client = ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie(); - auto* transfer_manager = - TransferManager::GetForPlatform(platform).ValueOrDie(); int device_ordinal = client->default_device_ordinal(); // Computation shape parameters. @@ -796,29 +794,23 @@ void BM_ParallelFusion(int num_iters) { auto computation = builder.Build().ConsumeValueOrDie(); // Transfer literals to device. 
- auto buffer0 = - ScopedShapedBuffer::Allocate(shape0, &allocator, /*device_ordinal=*/0) - .ConsumeValueOrDie(); auto param0_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *param0_literal, buffer0->mutable_buffer({}))); - - auto buffer1 = - ScopedShapedBuffer::Allocate(shape1, &allocator, /*device_ordinal=*/0) + std::unique_ptr buffer0 = + client->LiteralToShapedBuffer(*param0_literal, device_ordinal) .ConsumeValueOrDie(); + auto param1_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *param1_literal, buffer1->mutable_buffer({}))); - - auto buffer2 = - ScopedShapedBuffer::Allocate(shape2, &allocator, /*device_ordinal=*/0) + std::unique_ptr buffer1 = + client->LiteralToShapedBuffer(*param1_literal, device_ordinal) .ConsumeValueOrDie(); + auto param2_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *param2_literal, buffer2->mutable_buffer({}))); + std::unique_ptr buffer2 = + client->LiteralToShapedBuffer(*param2_literal, device_ordinal) + .ConsumeValueOrDie(); // Build executable. std::unique_ptr executable = @@ -828,7 +820,7 @@ void BM_ParallelFusion(int num_iters) { ExecutableBuildOptions()) .ConsumeValueOrDie(); - se::Stream stream(executors[client->default_device_ordinal()]); + se::Stream stream(executors[device_ordinal]); stream.Init(); // Initialize thread pool. diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index d73c05ff92578209143e0679558848160cae99bd..2b38f9c7192066a4124a366da9439e72c79f339e 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -15,13 +15,22 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include #include #include #include +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -30,18 +39,72 @@ namespace se = ::perftools::gputools; namespace xla { +namespace { + +using tensorflow::StringPiece; +using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::optional; + +constexpr char kInterpreter[] = "interpreter"; + +// Helper functions to get test and reference platforms. 
+se::Platform* GetReferencePlatform() { + auto result = PlatformUtil::GetPlatform(kInterpreter); + TF_CHECK_OK(result.status()) << "could not get interpreter platform"; + return result.ValueOrDie(); +} + +se::Platform* GetTestPlatform() { + auto result = PlatformUtil::GetDefaultPlatform(); + TF_CHECK_OK(result.status()) << "could not get test platform"; + return result.ValueOrDie(); +} + +bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) { + if (lhs.parameters_size() != rhs.parameters_size()) { + return false; + } + for (int i = 0; i < lhs.parameters_size(); i++) { + if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) { + return false; + } + } + return ShapeUtil::Equal(lhs.result(), rhs.result()); +} + +ProgramShape GetProgramShapeWithLayout(const HloModule& module) { + ProgramShape program_shape; + const auto* entry = module.entry_computation(); + for (const auto* param : entry->parameter_instructions()) { + *program_shape.add_parameters() = param->shape(); + *program_shape.add_parameter_names() = param->name(); + } + *program_shape.mutable_result() = entry->root_instruction()->shape(); + return program_shape; +} + +} // namespace + +HloTestBase::HloTestBase() + : HloTestBase(GetTestPlatform(), GetReferencePlatform()) {} + +HloTestBase::HloTestBase(se::Platform* test_platform, + se::Platform* reference_platform) + : test_runner_(test_platform), reference_runner_(reference_platform) {} + /* static */ std::unique_ptr HloTestBase::CreateNewModule() { HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + return MakeUnique(TestName(), VersionedComputationHandle(), + config); +} +/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. 
debug_options.add_xla_disable_hlo_passes("constant_folding"); - - config.set_debug_options(debug_options); - - return MakeUnique(TestName(), VersionedComputationHandle(), - config); + return debug_options; } StatusOr HloTestBase::Execute( @@ -49,25 +112,172 @@ StatusOr HloTestBase::Execute( tensorflow::gtl::ArraySlice arguments, Shape* result_shape) { - return runner_.Execute(std::move(module), arguments, result_shape); + return test_runner_.Execute(std::move(module), arguments, result_shape); } se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) { - return runner_.TransferToDevice(literal).ValueOrDie(); + return test_runner_.TransferToDevice(literal).ValueOrDie(); } std::unique_ptr HloTestBase::TransferFromDevice( const Shape& shape, se::DeviceMemoryBase device_base) { - return runner_.TransferFromDevice(shape, device_base).ValueOrDie(); + return test_runner_.TransferFromDevice(shape, device_base).ValueOrDie(); } std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments) { - return runner_.ExecuteAndTransfer(std::move(module), arguments).ValueOrDie(); + return test_runner_.ExecuteAndTransfer(std::move(module), arguments) + .ValueOrDie(); +} + +StatusOr> HloTestBase::MakeReferenceModule( + const HloModule& test_module, + const std::function& reference_preprocessor) { + std::unique_ptr reference_module = test_module.Clone(); + const auto& program_shape = GetProgramShapeWithLayout(test_module); + + if (reference_preprocessor != nullptr) { + reference_preprocessor(reference_module.get()); + if (!ProgramShapesEqual(program_shape, + GetProgramShapeWithLayout(*reference_module))) { + return InvalidArgument( + "reference preprocessor must not modify the program shape"); + } + } + TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(), + reference_module.get())); + return std::move(reference_module); +} + +template +StatusOr<::testing::AssertionResult> 
HloTestBase::RunAndCompareInternal( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, bool run_hlo_passes, + const std::function& reference_preprocessor) { + static_assert( + std::is_same::value || + std::is_same, LiteralPtr>::value, + "The LiteralPtr type only accepts Literal* or std::unique_ptr."); + TF_RETURN_IF_ERROR( + VerifyHloModule(*test_runner_.backend().platform(), module.get())); + TF_ASSIGN_OR_RETURN(auto reference_module, + MakeReferenceModule(*module, reference_preprocessor)); + + // Execute on two backends. + TF_ASSIGN_OR_RETURN( + auto test, + test_runner_.Execute(std::move(module), arguments, run_hlo_passes)); + TF_ASSIGN_OR_RETURN(auto reference, + reference_runner_.Execute(std::move(reference_module), + arguments, run_hlo_passes)); + return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test, + error); +} + +template +::testing::AssertionResult HloTestBase::RunAndCompare( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, + const std::function& reference_preprocessor) { + auto result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/true, reference_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return result.ValueOrDie(); +} + +template +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, + const std::function& reference_preprocessor) { + auto result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/false, reference_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return result.ValueOrDie(); +} + +::testing::AssertionResult HloTestBase::RunAndCompare( + std::unique_ptr module, const optional& error, + const std::function& reference_preprocessor) { + const auto& fake_arguments = + 
MakeFakeArguments(module.get()).ConsumeValueOrDie(); + return RunAndCompare>( + std::move(module), fake_arguments, error, reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + std::unique_ptr module, const optional& error, + const std::function& reference_preprocessor) { + const auto& fake_arguments = + MakeFakeArguments(module.get()).ConsumeValueOrDie(); + return RunAndCompareNoHloPasses>( + std::move(module), fake_arguments, error, reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompare( + const StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + return RunAndCompare(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::ReadModule(filename, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "failed reading hlo module from file"; + } + return RunAndCompare(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + const StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + 
return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::ReadModule(filename, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "failed reading hlo module from file"; + } + return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); } -Backend& HloTestBase::backend() { return runner_.backend(); } +Backend& HloTestBase::backend() { return test_runner_.backend(); } /* static */ string HloTestBase::TestName() { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 7f068dce36be3546298de2f06bf6d33446d07ca2..3cbbb7aa247dda3e5b6589a2a6aa74cf074babe7 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -24,31 +24,74 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" namespace xla { -// A base class for tests which build and run HLO code. 
This is a lower level of -// abstraction than using the client interface and enables, for one, explicitly -// building a graph of HLO instructions to run. +// A base class for tests which build and/or run HLO code. The class includes +// support for running an HLO module on two platforms and comparing the results. +// This is a lower level of abstraction than using the client interface and +// enables, for one, explicitly building a graph of HLO instructions to run. +// +// This can also be used to write text/file-based test cases. Note that the test +// target is responsible for linking the needed backends. A convenient way to do +// this is to make it an xla_test: it will generate test targets linking with +// the respective backends, which will be used as the test backend; the +// interpreter backend is already linked with hlo_test_base so it will be the +// default reference backend. For example, if you want to compare both cpu vs. +// interpreter, and gpu vs. interpreter, you can: +// +// xla_test ( +// name = "sample_text_test", +// srcs = ["sample_text_test.cc"], +// backends = [ +// "cpu", +// "gpu", +// ], +// deps = [ +// "//third_party/tensorflow/compiler/xla/tests:hlo_test_base", +// ... +// ], +// ) +// +// For a more detailed example, see "../tests/sample_text_test.cc". class HloTestBase : public ::testing::Test { protected: - HloTestBase() {} + // This uses the interpreter backend as the reference backend and + // automatically finds another supported backend as the test backend. If the + // interpreter is the only supported backend, it will be both the test backend + // and the reference backend. + HloTestBase(); + + // If your test doesn't use interpreter as the reference backend, you can use + // this constructor. Note that your test target is responsible for linking in + // both needed backends.
+ HloTestBase(::perftools::gputools::Platform* test_platform, + ::perftools::gputools::Platform* reference_platform); ~HloTestBase() override {} // Creates a new HLO module for a test. The module created will have // TestName() for its name; it will also automatically populate its debug - // options from command-line flags. It's recommended to use this method to - // create all HloModules for tests. + // options from command-line flags. If you want a fresh HloModule object and + // then add HloComputations to it, it's recommended to use this method in your + // tests. static std::unique_ptr CreateNewModule(); + // Populates debug options from command-line flags and adjusts the options for + // testing. It is recommended to use this when you need to pass in + // DebugOptions, e.g. when creating a module from a string or a file. + static DebugOptions GetDebugOptionsForTest(); + // Executes the given module and returns a global data handle. StatusOr Execute( std::unique_ptr module, @@ -71,6 +114,73 @@ class HloTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments); + // Executes the given hlo module on two backends and compares results. + // + // 'arguments': the input of the hlo module. The LiteralPtr type accepts + // Literal* or std::unique_ptr. + // + // 'error': if has value, expects the results to be near (within the error + // bound). Otherwise, expects the results to be equal. + // + // 'reference_preprocessor': the module should be ready to run on the test + // backend, but it might need to be tailored so that it is able to run on the + // reference backend. Note that the program shape of the module must not be + // modified. 
+ template + ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Same as above, except that the module will be executed without Hlo + // optimization. + template + ::testing::AssertionResult RunAndCompareNoHloPasses( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Executes an hlo module with fake inputs and compares the results. + ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Same as above, except that the module will be executed without Hlo + // optimization. + ::testing::AssertionResult RunAndCompareNoHloPasses( + std::unique_ptr module, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Convenient wrappers for executing and comparing an hlo module with fake + // input. Module can be passed in directly, or parsed from an hlo_string, + // or loaded from a file. 
+ ::testing::AssertionResult RunAndCompare( + const tensorflow::StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareNoHloPasses( + const tensorflow::StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareNoHloPassesFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + // Convenience method to force the layout of a given parameter in a module. // The layout of parameter number 'param_no' in the 'module' is set to // 'layout'. @@ -101,12 +211,31 @@ class HloTestBase : public ::testing::Test { static string TestName(); - // Returns the backend owned by the HloRunner. + // Returns the backend owned by the test runner. Backend& backend(); - HloRunner runner_; + HloRunner test_runner_; + HloRunner reference_runner_; ErrorSpec error_spec_{0.0001}; + + private: + // Given the test module, makes a reference module that is ready to run on the + // reference platform. This assumes that the given module is ready to run on + // the test platform. + StatusOr> MakeReferenceModule( + const HloModule& test_module, + const std::function& reference_preprocessor); + + // Runs the module on two platforms with or without running hlo passes and + // compares the results. Returns whether the results are near or equal. If any + // error happens before the results are computed, returns the error status. 
+ template + StatusOr<::testing::AssertionResult> RunAndCompareInternal( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, bool run_hlo_passes, + const std::function& reference_preprocessor); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/isolated_convolution.hlo b/tensorflow/compiler/xla/tests/isolated_convolution.hlo new file mode 100644 index 0000000000000000000000000000000000000000..9452780930efbb1ecc13b35cd4ab53678d36c37f --- /dev/null +++ b/tensorflow/compiler/xla/tests/isolated_convolution.hlo @@ -0,0 +1,8 @@ +HloModule convolution.167: + +ENTRY %convolution.167 (parameter.0: f32[16,28,28,128], parameter.1: f32[3,3,128,128]) -> f32[16,28,28,128] { + %parameter.0 = f32[16,28,28,128]{3,0,2,1} parameter(0) + %parameter.1 = f32[3,3,128,128]{3,2,1,0} parameter(1) + ROOT %convolution.167 = f32[16,28,28,128]{3,0,2,1} convolution(f32[16,28,28,128]{3,0,2,1} %parameter.0, f32[3,3,128,128]{3,2,1,0} %parameter.1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01oi->b01f +} + diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 95a52ecd2f5cfc97ec1ccba7d1b7ca6257a8267e..bf6631a4310d3504e4dfa8c46bf66125a94b9315 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -100,6 +100,58 @@ namespace xla { ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString()); } +/* static */ std::unique_ptr LiteralTestUtil::ConvertBF16ToF32( + const Literal& literal) { + if (ShapeUtil::IsTuple(literal.shape())) { + std::vector> converted_elements; + for (const auto& element : literal.tuple_literals()) { + converted_elements.push_back(ConvertBF16ToF32(element)); + } + return Literal::MakeTupleOwned(std::move(converted_elements)); + } + + if (literal.shape().element_type() != BF16) { + return MakeUnique(literal); + } + Shape converted_shape = literal.shape(); + 
converted_shape.set_element_type(F32); + auto converted = Literal::CreateFromShape(converted_shape); + if (!ShapeUtil::HasZeroElements(converted_shape)) { + std::vector index(converted_shape.dimensions_size(), 0); + do { + converted->Set(index, + static_cast(literal.Get(index))); + } while (IndexUtil::BumpIndices(converted_shape, &index)); + } + return converted; +} + +/* static */ std::unique_ptr LiteralTestUtil::ConvertF32ToBF16( + const Literal& literal) { + if (ShapeUtil::IsTuple(literal.shape())) { + std::vector> converted_elements; + for (const auto& element : literal.tuple_literals()) { + converted_elements.push_back(ConvertF32ToBF16(element)); + } + return Literal::MakeTupleOwned(std::move(converted_elements)); + } + + if (literal.shape().element_type() != F32) { + return MakeUnique(literal); + } + Shape converted_shape = literal.shape(); + converted_shape.set_element_type(BF16); + auto converted = Literal::CreateFromShape(converted_shape); + if (!ShapeUtil::HasZeroElements(converted_shape)) { + std::vector index(converted_shape.dimensions_size(), 0); + do { + converted->Set( + index, static_cast(literal.Get(index))); + } while (IndexUtil::BumpIndices(converted_shape, &index)); + } + return converted; +} + namespace { string Hostname() { @@ -116,16 +168,18 @@ template ::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { auto ulhs = tensorflow::bit_cast(lhs); auto urhs = tensorflow::bit_cast(rhs); + auto lhs_double = static_cast(lhs); + auto rhs_double = static_cast(rhs); if (ulhs != urhs) { return ::testing::AssertionFailure() << tensorflow::strings::Printf( "floating values are not bitwise-equal; and equality testing " "was requested: %s=%g=%a vs %s=%g=%a", tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs)) .c_str(), - lhs, lhs, + lhs_double, lhs_double, tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs)) .c_str(), - rhs, rhs); + rhs_double, rhs_double); } return ::testing::AssertionSuccess(); } @@ -149,6 +203,10 
@@ template // Specializations for floating types that do bitwise comparisons when equality // comparison is requested. template <> +::testing::AssertionResult CompareEqual(bfloat16 lhs, bfloat16 rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> ::testing::AssertionResult CompareEqual(float lhs, float rhs) { return CompareFloatsBitwiseEqual(lhs, rhs); } @@ -238,6 +296,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, case U64: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; + case BF16: + match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); + break; case F32: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; @@ -272,23 +333,37 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, return result; } -/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, - const Literal& actual) { +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualTuple( + const Literal& expected, const Literal& actual) { VLOG(1) << "expected: " << expected.ToString(); VLOG(1) << "actual: " << actual.ToString(); - ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape())); - ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape())); + if (!ShapeUtil::IsTuple(expected.shape()) || + !ShapeUtil::IsTuple(actual.shape())) { + return ::testing::AssertionFailure() + << "tuples expected shape = " << expected.shape().ShortDebugString() + << " actual shape = " << actual.shape().ShortDebugString(); + } AssertEqualShapes(expected.shape(), actual.shape()); for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) { const auto& expected_element = expected.tuple_literals(i); const auto& actual_element = actual.tuple_literals(i); if (ShapeUtil::IsTuple(expected_element.shape())) { - ExpectEqualTuple(expected_element, actual_element); + auto ret = EqualTuple(expected_element, actual_element); + if (!ret) { + return ret; + } } else { - ExpectEqual(expected_element, 
actual_element); + return Equal(expected_element, actual_element); } } + + return ::testing::AssertionSuccess(); +} + +/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, + const Literal& actual) { + EXPECT_TRUE(EqualTuple(expected, actual)); } namespace { @@ -331,6 +406,9 @@ class NearComparator { multi_index_.resize(expected.shape().dimensions_size(), 0); switch (expected.shape().element_type()) { + case BF16: + ExpectLiteralsNear(expected, actual, 0); + break; case F32: ExpectLiteralsNear(expected, actual, 0); break; @@ -516,6 +594,13 @@ void NearComparator::ExpectNear(complex64 expected, complex64 actual, << message; } +template <> +bool NearComparator::ExpectValuesNear(bfloat16 expected, + bfloat16 actual) { + return ExpectValuesNear(static_cast(expected), + static_cast(actual)); +} + } // namespace /* static */ ::testing::AssertionResult LiteralTestUtil::Near( @@ -544,8 +629,7 @@ void NearComparator::ExpectNear(complex64 expected, complex64 actual, if (!ShapeUtil::IsTuple(expected.shape()) || !ShapeUtil::IsTuple(actual.shape())) { return ::testing::AssertionFailure() - << "tuples expected expected shape = " - << expected.shape().ShortDebugString() + << "tuples expected shape = " << expected.shape().ShortDebugString() << " actual shape = " << actual.shape().ShortDebugString(); } AssertEqualShapes(expected.shape(), actual.shape()); @@ -579,6 +663,32 @@ void NearComparator::ExpectNear(complex64 expected, complex64 actual, EXPECT_TRUE(NearTuple(expected, actual, error)); } +/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) { + bool is_tuple = ShapeUtil::IsTuple(expected.shape()); + if (error.has_value()) { + if (is_tuple) { + VLOG(1) << "Expects near tuple"; + return NearTuple(expected, actual, *error); + } + VLOG(1) << "Expects near"; + return Near(expected, actual, *error); + } + if (is_tuple) { + VLOG(1) << "Expects equal 
tuple"; + return EqualTuple(expected, actual); + } + VLOG(1) << "Expects equal"; + return Equal(expected, actual); +} + +/*static*/ void LiteralTestUtil::ExpectNearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) { + EXPECT_TRUE(NearOrEqual(expected, actual, error)); +} + /* static */ string LiteralTestUtil::MultiIndexAsString( tensorflow::gtl::ArraySlice multi_index) { return tensorflow::strings::StrCat( diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 467d44b857b74d2a38e9b3f8a32a9b1d39a4a10d..f53553c70170bdcda717e72ffd791016effd0774 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -59,6 +60,16 @@ class LiteralTestUtil { static void AssertEqualShapesAndLayouts(const Shape& expected, const Shape& actual); + // If the given literal's data type is bfloat16, converts it to a float + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertBF16ToF32(const Literal& bf16_literal); + + // If the given literal's data type is float, converts it to a bfloat16 + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertF32ToBF16(const Literal& f32_literal); + // Asserts that the expected and actual literals are (bitwise) equal for all // elements in the literal. Also, asserts that the rank, dimensions sizes, and // primitive type are equal. 
@@ -100,6 +111,10 @@ class LiteralTestUtil { static void ExpectR4EqualArray4D(const Array4D& expected, const Literal& actual); + // Returns whether the two tuples are equal. + static ::testing::AssertionResult EqualTuple( + const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; + // Expects that the values of the elements in the expected and actual tuples // are equal. Tuples are matched recursively. static void ExpectEqualTuple(const Literal& expected, const Literal& actual); @@ -167,6 +182,19 @@ class LiteralTestUtil { static void ExpectNearTuple(const Literal& expected, const Literal& actual, const ErrorSpec& error); + // If the error spec is given, returns whether the expected and the actual are + // within the error bound; otherwise, returns whether they are equal. Tuples + // will be compared recursively. + static ::testing::AssertionResult NearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) TF_MUST_USE_RESULT; + + // If the error spec is given, expects the expected and the actual to be near; + // otherwise, expects them to be equal. Tuples will be compared recursively. + static void ExpectNearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error); + // Returns a multi-dimensional index as a string. For example: '{7, 8}' will // be returned for a 2-dimensional index with dimension 0 index equal to 7, // dimension 1 equal to 8. diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 458258e7ee1fee6964275c51ef38de5ff2ccd7b1..b5b95967ff9162301a092f3a57996e0f3f78658f 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -14,50 +14,147 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/stream_executor.h" namespace xla { namespace { -class LLVMCompilerTest : public HloTestBase {}; - -XLA_TEST_F(LLVMCompilerTest, CompilerHooks) { - int pre_opt_hook_call_count = 0; - int post_opt_hook_call_count = 0; - - auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module &) { - ++pre_opt_hook_call_count; - return Status::OK(); - }; - auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module &) { - ++post_opt_hook_call_count; - return Status::OK(); - }; - - // Create HLO module, and run the compiler. - auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - - auto hlo_module = CreateNewModule(); - hlo_module->AddEntryComputation(builder.Build()); - - auto compiler = static_cast(backend().compiler()); - compiler->SetPreOptimizationHook(pre_opt_hook); - compiler->SetPostOptimizationHook(post_opt_hook); - - ASSERT_TRUE( - compiler - ->Compile(std::move(hlo_module), backend().default_stream_executor()) - .ok()); - - // Test that hooks were called. 
- EXPECT_EQ(1, pre_opt_hook_call_count); - EXPECT_EQ(1, post_opt_hook_call_count); +class LLVMCompilerTest : public ::testing::Test { + public: + void SetUp() override { + Platform *platform = FindPlatform(); + ASSERT_NE(platform, nullptr); + + BackendOptions backend_options; + backend_options.set_platform(platform); + StatusOr> backend_or_status = + Backend::CreateBackend(backend_options); + ASSERT_IS_OK(backend_or_status.status()); + backend_ = backend_or_status.ConsumeValueOrDie(); + } + + ~LLVMCompilerTest() override {} + + protected: + using Platform = ::perftools::gputools::Platform; + + explicit LLVMCompilerTest(string platform_name) + : platform_name_(std::move(platform_name)) {} + + void TestCompilerHooks(LLVMCompiler *compiler) { + int pre_opt_hook_call_count = 0; + int post_opt_hook_call_count = 0; + + auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module &) { + ++pre_opt_hook_call_count; + return Status::OK(); + }; + auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module &) { + ++post_opt_hook_call_count; + return Status::OK(); + }; + + // Create HLO module, and run the compiler. + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + compiler->SetPreOptimizationHook(pre_opt_hook); + compiler->SetPostOptimizationHook(post_opt_hook); + + ASSERT_TRUE(compiler + ->RunBackend(std::move(hlo_module), + backend_->default_stream_executor()) + .ok()); + + // Test that hooks were called. 
+ EXPECT_EQ(1, pre_opt_hook_call_count); + EXPECT_EQ(1, post_opt_hook_call_count); + } + + void TestMultiModuleCompilation(LLVMCompiler *compiler) { + HloComputation::Builder builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + + std::unique_ptr hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + std::vector> modules; + modules.push_back(hlo_module->Clone()); + modules.push_back(std::move(hlo_module)); + + std::vector> executors; + executors.push_back({backend_->default_stream_executor()}); + executors.push_back({backend_->default_stream_executor()}); + + EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors))); + } + + private: + Platform *FindPlatform() { + for (Platform *platform : + PlatformUtil::GetSupportedPlatforms().ConsumeValueOrDie()) { + if (platform->Name() == platform_name_) { + return platform; + } + } + return nullptr; + } + + string platform_name_; + std::unique_ptr backend_; + + static string TestName() { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + static std::unique_ptr CreateNewModule() { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + return MakeUnique(TestName(), VersionedComputationHandle(), + config); + } +}; + +class CpuCompilerTest : public LLVMCompilerTest { + public: + CpuCompilerTest() : LLVMCompilerTest("Host") {} +}; + +class GpuCompilerTest : public LLVMCompilerTest { + public: + GpuCompilerTest() : LLVMCompilerTest("CUDA") {} +}; + +TEST_F(CpuCompilerTest, HooksTest) { + cpu::CpuCompiler compiler; + TestCompilerHooks(&compiler); +} + +TEST_F(GpuCompilerTest, HooksTest) { + gpu::GpuCompiler compiler; + TestCompilerHooks(&compiler); } +TEST_F(CpuCompilerTest, MultiModuleCompilation) { + cpu::CpuCompiler compiler; + TestMultiModuleCompilation(&compiler); +} + +TEST_F(GpuCompilerTest, MultModuleCompilation) { + gpu::GpuCompiler compiler; + 
TestMultiModuleCompilation(&compiler); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 329b53012f58c8d084cc05f9a567a8aa432c4a3a..ad71d40197fe48b4343ee5f5f7f71b282a05cbf5 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -136,16 +136,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. - auto x_array = LiteralToShapedBuffer( - *test_utils::CreateR2LiteralWithLayout({{1.0f, 2.0f}, {3.0f, 4.0f}}, - /*minor_to_major=*/{0, 1})); + auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. 
- auto y_array = LiteralToShapedBuffer( - *test_utils::CreateR2LiteralWithLayout({{10.0f, 20.0f}, {30.0f, 40.0f}}, - /*minor_to_major=*/{1, 0})); + auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( + {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(), LayoutUtil::MakeLayout({1, 0}))); @@ -876,11 +874,13 @@ XLA_TEST_F(LocalClientExecuteTest, tensorflow::ThreadOptions(), "execute_thread", [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); - ASSERT_IS_OK(local_client_->TransferToInfeed( - *Literal::CreateR1({-5.0, 123.0, 42.0}))); + ASSERT_IS_OK(local_client_->TransferToInfeedLocal( + *Literal::CreateR1({-5.0, 123.0, 42.0}), + local_client_->default_device_ordinal())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - local_client_->TransferFromOutfeed(&shape)); + local_client_->TransferFromOutfeedLocal( + shape, local_client_->default_device_ordinal())); LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); } @@ -906,9 +906,12 @@ void BM_LocalClientOverhead(int num_iters) { builder.Add(x, x); auto computation = builder.Build().ConsumeValueOrDie(); - auto buffer = - ScopedShapedBuffer::Allocate(shape, &allocator, /*device_ordinal=*/0) - .ConsumeValueOrDie(); + auto shape_size_fn = [client](const Shape& shape) { + return client->backend().transfer_manager()->GetByteSizeRequirement(shape); + }; + auto buffer = ScopedShapedBuffer::Allocate( + shape, &allocator, /*device_ordinal=*/0, shape_size_fn) + .ConsumeValueOrDie(); auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( executors[device_ordinal], *literal, buffer->mutable_buffer({}))); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index c11e1df0a7890a6c3aada5ff47494b42fdaf3b9d..062a9246e49598d5d03dce8c1f437138923449bf 100644 --- 
a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/map_util.h" @@ -136,29 +135,10 @@ std::unique_ptr LocalClientTestBase::LiteralToShapedBuffer( .ConsumeValueOrDie(); } -void LocalClientTestBase::CopyShapedBufferToLiteral( - const ShapedBuffer& shaped_buffer, ShapeIndex* index, Literal* literal) { - const Shape& shape = ShapeUtil::GetSubshape(shaped_buffer.shape(), *index); - if (ShapeUtil::IsTuple(shape)) { - *literal->mutable_shape() = shape; - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - Literal* element_literal = literal->add_tuple_literals(); - index->push_back(i); - CopyShapedBufferToLiteral(shaped_buffer, index, element_literal); - index->pop_back(); - } - } else { - ASSERT_IS_OK(transfer_manager_->TransferLiteralFromDevice( - stream_executor_, shaped_buffer.buffer(*index), shape, shape, literal)); - } -} - std::unique_ptr LocalClientTestBase::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { - auto literal = MakeUnique(); - ShapeIndex index; - CopyShapedBufferToLiteral(shaped_buffer, &index, literal.get()); - return literal; + return local_client_->ShapedBufferToLiteral(shaped_buffer) + .ConsumeValueOrDie(); } ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions() diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 
3edfcb656ed8278d403103f0cfd820a10892476a..f0c73f04f6eb67b2e9cb5e111eccdc3818059b2b 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -93,10 +93,6 @@ class LocalClientTestBase : public ::testing::Test { std::unique_ptr ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer); - // Helper for converting a ShapedBuffer into a literal. - void CopyShapedBufferToLiteral(const ShapedBuffer& shaped_buffer, - ShapeIndex* index, Literal* literal); - // Execute the given computation on the local client. With and without // options. StatusOr> ExecuteLocally( diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 2ef392508d14cf6dc14b2c979f07a79bc60d7426..2b0f7e6e80c48435ca55432a2afa3b6d69162625 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -405,13 +405,13 @@ TEST_F(MapTest, MapBinaryAdder) { // for Map that used to fail in shape inference (b/28989438). 
XLA_TEST_F(MapTest, AddWithMixedLayouts) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = - test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0}); + std::unique_ptr param0_literal = Literal::CreateR2WithLayout( + {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = - test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1}); + std::unique_ptr param1_literal = Literal::CreateR2WithLayout( + {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1})); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 22d2b917a1d55f4f453e21c2d8fea38e32ff796b..89fa6ed9f7fe590f3ac872cce48a329b2894048a 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -76,8 +76,11 @@ class MultiOutputFusionTest : public HloTestBase { elem_shape2, HloOpcode::kAdd, broadcast, param1)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( elem_shape2, HloOpcode::kSubtract, param1, broadcast)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(elem_shape2, HloOpcode::kDot, sub, add2)); + HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -133,8 +136,11 @@ class MultiOutputFusionTest : public HloTestBase { HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {size, 1}), add)); - HloInstruction* dot = 
builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1}), HloOpcode::kDot, sub, reshape)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index fda4389f479cdc7a659e4d7c8a2facba55e17e83..c260258d6e9af6dee6075c92cf35dac4ed46abed 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -252,8 +252,8 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { } // Only run the 3,000-parameter tests in opt mode to avoid test timeouts. -// Timeout last observed on 2017-09-12. -#ifndef NDEBUG +// Timeout last observed on 2017-11-20. +#ifdef NDEBUG // TODO(b/65525254) Fails on GPU on 2017-09-10 because we try to reserve too // much space in parameter memory for the kernel. @@ -334,6 +334,106 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); } +// Test large number of parameters flowing into a while-loop. +// Construct conceptually the following HLO graph: +// +// p0 = parameter(0) +// p1 = parameter(1) +// ... +// pN = parameter(N) +// result = while (false) { +// p0 += (1, 1); +// p1 += (1, 1); +// ... +// pN += (1, 1) +// } +// result = {p0, p1, ..., pN} +// +// TODO(b/70173746): Times out during compilation on GPU and CPU backends as of +// 2017-12-12. 
+XLA_TEST_F(ParamsTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(ManyParametersIntoWhileLoop))) { + ComputationBuilder builder(client_, TestName()); + + std::vector> param_data_owner; + constexpr int kParamCount = 1900; + std::vector params; + std::vector parameter_shapes; + for (int i = 0; i < kParamCount; ++i) { + std::unique_ptr literal = Literal::CreateR1({i, i}); + param_data_owner.push_back( + std::move(client_->TransferToServer(*literal)).ValueOrDie()); + ComputationDataHandle param = + builder.Parameter(i, literal->shape(), "param"); + params.push_back(param); + parameter_shapes.push_back(literal->shape()); + } + + // Add bool parameter for the loop condition. Use a parameter HLO instead of a + // constant because DCE may eliminate the while-body otherwise. + std::unique_ptr bool_literal = Literal::CreateR0(false); + param_data_owner.push_back( + std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); + ComputationDataHandle bool_param = + builder.Parameter(kParamCount, bool_literal->shape(), "bool_param"); + params.push_back(bool_param); + parameter_shapes.push_back(bool_literal->shape()); + + auto init = builder.Tuple(params); + + // Create a computation for the condition: while(bool_param). + Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes); + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto condition_parameter = + builder.Parameter(0, while_shape, "condition_parameter"); + builder.GetTupleElement(condition_parameter, kParamCount); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add {1, 1} to the each tuple element. 
+ Computation body; + { + ComputationBuilder builder(client_, "body"); + auto body_parameter = builder.Parameter(0, while_shape, "body_parameter"); + std::vector updates; + for (int i = 0; i < kParamCount; ++i) { + auto add = builder.Add(builder.GetTupleElement(body_parameter, i), + builder.ConstantR1({1, 1})); + updates.push_back(add); + } + // Add bool parameter. + updates.push_back(builder.GetTupleElement(body_parameter, kParamCount)); + + builder.Tuple(updates); + body = builder.Build().ConsumeValueOrDie(); + } + + auto loop = builder.While(condition, body, init); + + std::vector outputs; + for (int i = 0; i < kParamCount; ++i) { + outputs.push_back(builder.GetTupleElement(loop, i)); + } + builder.Tuple(outputs); + + std::vector param_data; + param_data.reserve(param_data_owner.size()); + for (const std::unique_ptr& data : param_data_owner) { + param_data.push_back(data.get()); + } + + std::vector> elements; + std::vector ptrs; + for (int i = 0; i < kParamCount; ++i) { + elements.push_back(Literal::CreateR1({i, i})); + ptrs.push_back(elements.back().get()); + } + ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); +} + #endif XLA_TEST_F(ParamsTest, diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 7bc3185c367f076c9a7d211c9799557e1a91d92f..b09ccdd679b6c8f628e40f78f58dbd1734926af6 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -352,15 +352,13 @@ XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) { XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) { RunR2ToR1Test(1024, 1024); } XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) { RunR2ToR1Test(1000, 1500); } -// TODO(b/34969189): Invalid CAS generated on GPU. 
-XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceAllOnesR1_10_Pred)) { +XLA_TEST_F(ReduceTest, AndReduceAllOnesR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count, 1); RunR1ToR0PredTest(/*and_reduce=*/true, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceOnesAndZerosR1_10_Pred)) { +XLA_TEST_F(ReduceTest, AndReduceOnesAndZerosR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count); for (int i = 0; i < element_count; ++i) { @@ -369,15 +367,13 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceOnesAndZerosR1_10_Pred)) { RunR1ToR0PredTest(/*and_reduce=*/true, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OrReduceAllOnesR1_10_Pred)) { +XLA_TEST_F(ReduceTest, OrReduceAllOnesR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count, 1); RunR1ToR0PredTest(/*and_reduce=*/false, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OrReduceOnesAndZerosR1_10_Pred)) { +XLA_TEST_F(ReduceTest, OrReduceOnesAndZerosR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count); for (int i = 0; i < element_count; ++i) { diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 6c9b62b48d8bb2ad93b2ce98839e5e52d8eaa8cc..b32df74312ed1b513bcdd161c1516c5a5a2f0faf 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -41,16 +41,40 @@ limitations under the License. namespace xla { namespace { -class ReduceWindowTest : public ClientLibraryTestBase { +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +// Tests both F32 and BF16. +static std::array use_bfloat16_params{false, true}; +#else +// Only tests F32. 
+static std::array use_bfloat16_params{false}; +#endif + +class ReduceWindowTestBase : public ClientLibraryTestBase { public: - ReduceWindowTest() : builder_(client_, TestName()) {} + ErrorSpec DefaultErrorSpec() const { + if (use_bfloat16()) { + return ErrorSpec(1e-1, 5e-2); + } else { + return ErrorSpec(1e-3, 1e-3); + } + } +}; + +class ReduceWindowTest : public ::testing::WithParamInterface, + public ReduceWindowTestBase { + public: + ReduceWindowTest() : builder_(client_, TestName()) { + set_use_bfloat16(GetParam()); + } void ReduceWindowAdd(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow(input, builder_.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder_), + auto init = + CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarAddComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); } @@ -58,30 +82,32 @@ class ReduceWindowTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow( - input, builder_.ConstantLiteral(Literal::MinValue(F32)), - CreateScalarMax(), window_dimensions, window_strides, padding); + auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_); + builder_.ReduceWindow(input, init, CreateScalarMax(), window_dimensions, + window_strides, padding); } void ReduceWindowMin(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow(input, - builder_.ConstantLiteral(Literal::MaxValue(F32)), - CreateScalarMinComputation(F32, &builder_), + auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarMinComputation(FloatType(), &builder_), 
window_dimensions, window_strides, padding); } ComputationBuilder builder_; }; -TEST_F(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { - const auto input = builder_.ConstantR1({1, 1, 1, 1}); - const auto init_value = builder_.ConstantR0(0); +TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1({1, 1, 1, 1}), &builder_); + const auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); builder_.ReduceWindow(input, init_value, - CreateScalarAddComputation(F32, &builder_), + CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{1, 2}, /*window_strides=*/{1}, Padding::kValid); ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT) @@ -90,79 +116,97 @@ TEST_F(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { ::testing::HasSubstr("Want input dimensions size")); } -TEST_F(ReduceWindowTest, Min3In5Stride2) { - const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); +// Regression test for b/68964348. 
+TEST_P(ReduceWindowTest, R0ReduceWindow) { + const auto input = + CreateConstantFromLiteral(*Literal::CreateR0(42.0), &builder_); + const auto init = + CreateConstantFromLiteral(*Literal::CreateR0(1.0), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarAddComputation(FloatType(), &builder_), + /*window_dimensions=*/{}, + /*window_strides=*/{}, Padding::kSame); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR0(43.0), {}, + ErrorSpec(0.00001)); +} + +TEST_P(ReduceWindowTest, Min3In5Stride2) { + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - ComputeAndCompareR1(&builder_, {100, 1}, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({100, 1}), {}, + ErrorSpec(0.00001)); } -XLA_TEST_F(ReduceWindowTest, ZeroElementSmall) { +XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { Array4D input_array(1, 0, 2, 1); - - const auto input = builder_.ConstantR4FromArray4D(input_array); + const auto input = CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, NonSquareSmall) { +TEST_P(ReduceWindowTest, NonSquareSmall) { Array4D input_array(1, 2, 2, 1); - input_array.FillRandom(2.f); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - 
ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, MiddleDimsSmall) { +TEST_P(ReduceWindowTest, MiddleDimsSmall) { Array4D input_array(1, 3, 3, 1); - input_array.FillRandom(2.f); - - const auto input = builder_.ConstantR4FromArray4D(input_array); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, Along2ndMinorDim) { +TEST_P(ReduceWindowTest, Along2ndMinorDim) { Array4D input_array(3, 6, 7, 32); - input_array.FillRandom(2.f); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); // The parameters of this reduction mimic feature norm (e.g. LRN). 
int lrn_diameter = 7; // diameter = 2*radius + 1 --> must be odd - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2Dims) { +TEST_P(ReduceWindowTest, AmongMajor2Dims) { Array4D input_array(4, 4, 6, 8); input_array.FillWithMinorDimNum(); + const auto input_data_handle = + CreateConstantFromArray(input_array, &builder_); int win_len = 3; int win_stride = 1; Padding padding = Padding::kSame; - const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); // Reduce only along the x and y dimensions, according to the win_len. ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); @@ -170,18 +214,20 @@ TEST_F(ReduceWindowTest, AmongMajor2Dims) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) { +TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { Array4D input_array(9, 12, 4, 89); - input_array.FillRandom(2.0f); + input_array.FillRandom(2.f, 2.f); int win_len = 3; int win_stride = 2; const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); + CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; // Reduce only along the x and y dimensions, according to the win_len. 
@@ -192,20 +238,21 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // TODO(b/32173947): Test support for arbitrary-sized padding. -TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { +TEST_P(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { Array4D input_array(9, 12, 4, 89); // simulate Dim0IsMinor layout - input_array.FillRandom(2.0f); + input_array.FillRandom(2.f, 2.f); int64 rank = 4; int win_len = 3; int win_stride = 2; const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); + CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; // Reduce only along the x and y dimensions, according to the win_len. @@ -222,26 +269,28 @@ TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x2) { +XLA_TEST_P(ReduceWindowTest, Add1x1x2In2x1x2) { Array3D input_array(2, 1, 2); input_array(0, 0, 0) = 1000; input_array(0, 0, 1) = 100; input_array(1, 0, 0) = 10; input_array(1, 0, 1) = 1; - auto input = builder_.ConstantR3FromArray3D(input_array); + const auto input = CreateConstantFromArray(input_array, &builder_); ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kValid); Array3D expected(2, 1, 1); expected(0, 0, 0) = 1100; expected(1, 0, 0) = 11; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(expected), {}, + 
DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) { +XLA_TEST_P(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) { Array3D input_array(2, 1, 3); input_array(0, 0, 0) = 100; input_array(0, 0, 1) = 10; @@ -249,17 +298,18 @@ XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) { input_array(1, 0, 0) = 500; input_array(1, 0, 1) = 50; input_array(1, 0, 2) = 5; - auto input = builder_.ConstantR3FromArray3D(input_array); + const auto input = CreateConstantFromArray(input_array, &builder_); ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 2}, Padding::kValid); Array3D expected(2, 1, 1); expected(0, 0, 0) = 110; expected(1, 0, 0) = 550; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(expected), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { +XLA_TEST_P(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { Array3D input_array(2, 1, 3); input_array(0, 0, 0) = 100; input_array(0, 0, 1) = 10; @@ -267,7 +317,7 @@ XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { input_array(1, 0, 0) = 500; input_array(1, 0, 1) = 50; input_array(1, 0, 2) = 5; - auto input = builder_.ConstantR3FromArray3D(input_array); + const auto input = CreateConstantFromArray(input_array, &builder_); ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kSame); @@ -278,30 +328,34 @@ XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { expected(1, 0, 0) = 550; expected(1, 0, 1) = 55; expected(1, 0, 2) = 5; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(expected), {}, + DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. 
-XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { +XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { Array4D input_array(1, 2, 2, 1); input_array(0, 0, 0, 0) = 1; input_array(0, 0, 1, 0) = 2; input_array(0, 1, 0, 0) = 3; input_array(0, 1, 1, 0) = 4; + const auto input = CreateConstantFromArray(input_array, &builder_); - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kValid; - - const Shape scalar = ShapeUtil::MakeShape(F32, {}); + const Shape scalar = ShapeUtil::MakeShape(FloatType(), {}); auto b = builder_.CreateSubBuilder("unusual"); auto lhs = b->Parameter(0, scalar, "lhs"); auto rhs = b->Parameter(1, scalar, "rhs"); - b->Min(b->Add(lhs, rhs), b->ConstantR0(8.0f)); + b->Min(b->Add(lhs, rhs), + CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); Computation reduce_fn = b->BuildAndNoteError(); - builder_.ReduceWindow(input, builder_.ConstantR0(3.0f), reduce_fn, - /*window_dimensions=*/{1, 1, 2, 1}, - /*window_strides=*/{1, 1, 1, 1}, padding); + builder_.ReduceWindow( + input, + CreateConstantFromLiteral(*Literal::CreateR0(3.0f), &builder_), + reduce_fn, + /*window_dimensions=*/{1, 1, 2, 1}, + /*window_strides=*/{1, 1, 1, 1}, padding); const auto reduce_func = [](float arg1, float arg2) { return std::min(arg1 + arg2, 8.0f); @@ -312,17 +366,19 @@ XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *expected, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*expected), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, R4UnitWindow) { +TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D input_array(13, 12, 8, 15); - input_array.Fill(1.0f); + input_array.FillRandom(2.f, 2.f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); - ComputationDataHandle input = - builder_.Parameter(0, 
input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); @@ -330,15 +386,11 @@ TEST_F(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { - auto b = HloComputation::Builder(TestName()); - +XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); @@ -348,56 +400,15 @@ XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { }; TF_EXPECT_OK(arg_literal->Populate(generator)); - auto input = - b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); - - auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); - - HloComputation::Builder add_computation("add"); - Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto param_lhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto param_rhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - add_computation.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - - auto module = CreateNewModule(); - auto add_func = module->AddEmbeddedComputation(add_computation.Build()); - - WindowDimension trivial_dim; - trivial_dim.set_size(1); - trivial_dim.set_stride(1); - trivial_dim.set_padding_low(0); - trivial_dim.set_padding_high(0); 
- trivial_dim.set_window_dilation(1); - trivial_dim.set_base_dilation(1); - - WindowDimension active_dim; - active_dim.set_size(3); - active_dim.set_stride(1); - active_dim.set_padding_low(0); - active_dim.set_padding_high(0); - active_dim.set_window_dilation(1); - active_dim.set_base_dilation(1); - - Window window; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - - // Non-monotonic output layout with minor dims trivial. + const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + + Padding padding = Padding::kValid; + ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); + std::vector output_layout = {1, 5, 3, 2, 0, 4}; std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - b.AddInstruction(HloInstruction::CreateReduceWindow( - result_shape, input, init_value, window, add_func)); - std::unique_ptr expected = Literal::CreateFromShape(result_shape); auto out_generator = [&](tensorflow::gtl::ArraySlice indexes) -> float { @@ -405,82 +416,37 @@ XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { }; TF_EXPECT_OK(expected->Populate(out_generator)); - module->AddEntryComputation(b.Build()); - auto actual = ExecuteAndTransfer(std::move(module), {}); - - LiteralTestUtil::ExpectNear(*actual, *expected, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(HloTestBase, R6Add) { - auto b = HloComputation::Builder(TestName()); - +XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector input_dims(6, 8); + auto shape = ShapeUtil::MakeShape(F32, input_dims); + std::unique_ptr arg_literal = Literal::CreateFullWithMonotonicDim0MajorLayout(input_dims, 1.0f); - auto input = - 
b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); - - auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); - - HloComputation::Builder add_computation("add"); - Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto param_lhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto param_rhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - add_computation.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - - auto module = CreateNewModule(); - auto add_func = module->AddEmbeddedComputation(add_computation.Build()); - - WindowDimension trivial_dim; - trivial_dim.set_size(1); - trivial_dim.set_stride(1); - trivial_dim.set_padding_low(0); - trivial_dim.set_padding_high(0); - trivial_dim.set_window_dilation(1); - trivial_dim.set_base_dilation(1); - - WindowDimension active_dim; - active_dim.set_size(3); - active_dim.set_stride(1); - active_dim.set_padding_low(0); - active_dim.set_padding_high(0); - active_dim.set_window_dilation(1); - active_dim.set_base_dilation(1); - - Window window; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - - Shape shape = ShapeUtil::MakeShape(F32, {8, 8, 6, 6, 8, 8}); - b.AddInstruction(HloInstruction::CreateReduceWindow(shape, input, init_value, - window, add_func)); + + const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + + Padding padding = Padding::kValid; + ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); std::vector output_dims = {8, 8, 6, 6, 8, 8}; std::unique_ptr expected = Literal::CreateFullWithMonotonicDim0MajorLayout(output_dims, 9.0f); - module->AddEntryComputation(b.Build()); - 
auto actual = ExecuteAndTransfer(std::move(module), {}); - - LiteralTestUtil::ExpectNear(*actual, *expected, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 1; int stride = 8; @@ -490,20 +456,19 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorUnitStride) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 3; int stride = 1; @@ -513,20 +478,19 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, 
stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorWin) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 8; int stride = 5; @@ -536,13 +500,11 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { +TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { Array4D input_array(6, 4, 10, 130); input_array.FillRandom(2.0f); @@ -551,7 +513,7 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { Padding padding = Padding::kSame; const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); + CreateConstantFromArray(input_array, &builder_); // Reduce only along the x and y dimensions, according to the win_len. 
ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); @@ -559,36 +521,42 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add24In1152_NoOverlap) { +XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { std::vector input_vector(128 * 9, 1); - const auto input = builder_.ConstantR1(input_vector); + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); - ComputeAndCompareR1(&builder_, {32, 32, 32, 32, 32, 32, 32, 32, 32}, - {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral( + &builder_, + *Literal::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add128In128Stride128) { - const auto input = builder_.ConstantR1( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); +XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { + std::vector input_vector{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 
2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareR1(&builder_, {1088}, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({1088}), {}, + DefaultErrorSpec()); } // Regression test for a bug that appeared in Inception (b/34784899). -TEST_F(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { +TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { Array2D input_array(14, 14, 1.0f); - ComputationDataHandle input = - builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {14, 14}); + const auto input = CreateConstantFromArray(input_array, &builder_); int win_len = 3; int stride = 1; @@ -598,13 +566,14 @@ TEST_F(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareR2(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {}, DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { +TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D input_array(6, 4, 1.0f); - ComputationDataHandle input = - builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {6, 4}); + ComputationDataHandle input = builder_.Broadcast( + CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4}); Padding padding = Padding::kSame; ReduceWindowAdd(input, {4, 2}, {3, 3}, padding); @@ -612,9 +581,13 @@ TEST_F(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareR2(&builder_, *res, 
{}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {}, DefaultErrorSpec()); } +INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, + ::testing::ValuesIn(use_bfloat16_params)); + enum Reducer { kAdd, kMax }; struct R4ReduceWindowTestData { @@ -628,30 +601,36 @@ struct R4ReduceWindowTestData { }; string R4ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__pad_low_", tensorflow::str_util::Join(data.param.pad_low, "x"), // - "__pad_high_", tensorflow::str_util::Join(data.param.pad_high, "x"), // - (data.param.reducer == kAdd) ? "add" : "max"); - CHECK(data.param.reducer == kAdd || data.param.reducer == kMax); + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), // + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), // + (param.reducer == kAdd) ? "add" : "max"); + CHECK(param.reducer == kAdd || param.reducer == kMax); // Test names are not allowed to contain the '-' character. 
std::replace(str.begin(), str.end(), '-', 'n'); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R4ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface { +class R4ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { protected: + R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + void DoIt() { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); const float kInitValue = 0.0f; @@ -660,23 +639,24 @@ class R4ReduceWindowTest input.FillIota(1); std::unique_ptr input_literal = Literal::CreateR4FromArray4D(input); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); std::vector> padding(4); for (int i = 0; i < 4; ++i) { padding[i] = {param.pad_low[i], param.pad_high[i]}; } - auto parameter = b.Parameter(0, input_literal->shape(), "p0"); - auto pad_value = b.ConstantR0(kInitValue); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto computation = param.reducer == kAdd - ? CreateScalarAddComputation(F32, &b) - : CreateScalarMaxComputation(F32, &b); + ? 
CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); b.ReduceWindowWithGeneralPadding( /*operand=*/parameter, - /*init_value=*/pad_value, + /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, @@ -694,8 +674,8 @@ class R4ReduceWindowTest /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareR4(&b, *expected, {input_arg.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + {input_arg.get()}, DefaultErrorSpec()); } }; @@ -824,9 +804,11 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*reducer=*/kAdd}, }; -INSTANTIATE_TEST_CASE_P(R4ReduceWindowTestInstantiation, R4ReduceWindowTest, - ::testing::ValuesIn(kR4ReduceWindowTestValues), - R4ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R4ReduceWindowTestInstantiation, R4ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues), + ::testing::ValuesIn(use_bfloat16_params)), + R4ReduceWindowTestDataToString); class R4ReduceWindowLargeTest : public R4ReduceWindowTest {}; @@ -849,10 +831,11 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { /*reducer=*/kAdd}, }; -INSTANTIATE_TEST_CASE_P(R4ReduceWindowLargeTestInstantiation, - R4ReduceWindowLargeTest, - ::testing::ValuesIn(kR4ReduceWindowLargeTestValues), - R4ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest, + ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues), + ::testing::ValuesIn(use_bfloat16_params)), + R4ReduceWindowTestDataToString); struct R2ReduceWindowTestData { int64 base_bounds[2]; @@ -900,26 +883,33 @@ struct R2ReduceWindowTestData { }; string R2ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& 
param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // - "__layout_", data.param.layout[0], "_", data.param.layout[1], // - "__reducer_", data.param.reducer == kAdd ? "add" : "max"); + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__padding_", param.padding == Padding::kSame ? "same" : "valid", // + "__layout_", param.layout[0], "_", param.layout[1], // + "__reducer_", param.reducer == kAdd ? "add" : "max"); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R2ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface {}; +class R2ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { + protected: + R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } +}; TEST_P(R2ReduceWindowTest, Add) { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; @@ -927,12 +917,15 @@ TEST_P(R2ReduceWindowTest, Add) { std::unique_ptr input_literal = Literal::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); - b.ReduceWindow(/*operand=*/ - b.Parameter(0, input_literal->shape(), "p0"), - /*init_value=*/b.ConstantR0(kInitValue), - /*computation=*/CreateScalarAddComputation(F32, 
&b), + + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindow(/*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/CreateScalarAddComputation(FloatType(), &b), /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, /*padding=*/param.padding); @@ -940,90 +933,145 @@ TEST_P(R2ReduceWindowTest, Add) { /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/param.padding); - ComputeAndCompareR2(&b, *expected, {input_arg.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + {input_arg.get()}, DefaultErrorSpec()); } -INSTANTIATE_TEST_CASE_P(R2ReduceWindowTestInstantiation, R2ReduceWindowTest, - ::testing::ValuesIn(kR2TestCases), - R2ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R2ReduceWindowTestInstantiation, R2ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR2TestCases), + ::testing::ValuesIn(use_bfloat16_params)), + R2ReduceWindowTestDataToString); struct R1ReduceWindowTestData { int64 base_bounds[1]; int64 window_bounds[1]; int64 strides[1]; - Padding padding; + int64 pad_low[1]; + int64 pad_high[1]; Reducer reducer; } kR1TestCases[] = { {/*base_bounds=*/{1}, /*window_bounds=*/{1}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{3}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second}, + 
/*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{2}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5}, /*window_bounds=*/{1}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{4}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{3}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, - {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{30}, + {/*base_bounds=*/{128 * 2}, + /*window_bounds=*/{30}, /*strides=*/{27}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - - {/*base_bounds=*/{128 * 17}, /*window_bounds=*/{7}, + /*pad_low=*/ + {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 17}, + /*window_bounds=*/{7}, /*strides=*/{64}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - - {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{32}, + /*pad_low=*/ + {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first}, + /*pad_high=*/ + 
{xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 2}, + /*window_bounds=*/{32}, /*strides=*/{56}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/ + {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{2}, /*strides=*/{1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5}, /*window_bounds=*/{3}, /*strides=*/{2}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{3}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{5}, + /*strides=*/{1}, + /*pad_low=*/{0}, + /*pad_high=*/{5}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{5}, + /*strides=*/{1}, + /*pad_low=*/{5}, + /*pad_high=*/{0}, + /*reducer=*/Reducer::kAdd}, }; string R1ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // - "__window_bounds_", - 
tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // - "__reducer_", data.param.reducer == kAdd ? "add" : "max"); + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), + "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), + "__strides_", tensorflow::str_util::Join(param.strides, "x"), + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), + "__reducer_", param.reducer == kAdd ? "add" : "max"); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R1ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface {}; +class R1ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { + protected: + R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } +}; TEST_P(R1ReduceWindowTest, DoIt) { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd || param.reducer == kMax); const float kInitValue = 0.0f; @@ -1031,18 +1079,24 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr input_literal = Literal::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + + std::vector> padding(1); + padding[0] = {param.pad_low[0], param.pad_high[0]}; auto computation = param.reducer == kAdd - ? 
CreateScalarAddComputation(F32, &b) - : CreateScalarMaxComputation(F32, &b); - b.ReduceWindow(/*operand=*/ - b.Parameter(0, input_literal->shape(), "p0"), - /*init_value=*/b.ConstantR0(kInitValue), - /*computation=*/computation, - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindowWithGeneralPadding( + /*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } @@ -1052,14 +1106,17 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*init=*/kInitValue, /*reduce_func=*/reduce_func, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); + /*stride=*/param.strides, + /*padding=*/padding); - ComputeAndCompareR1(&b, tensorflow::gtl::ArraySlice(*expected), - {input_arg.get()}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&b, *Literal::CreateR1(*expected), + {input_arg.get()}, DefaultErrorSpec()); } -INSTANTIATE_TEST_CASE_P(R1ReduceWindowTestInstantiation, R1ReduceWindowTest, - ::testing::ValuesIn(kR1TestCases), - R1ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R1ReduceWindowTestInstantiation, R1ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR1TestCases), + ::testing::ValuesIn(use_bfloat16_params)), + R1ReduceWindowTestDataToString); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 72c68f24a0a954deb0564e9a0e924edfaf5b5484..ddd50d7a5864d73de7916ce736bb7cd40c1c4dc9 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ 
b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -41,326 +41,467 @@ limitations under the License. namespace xla { namespace { -class ReshapeTest : public ClientLibraryTestBase { +// Use a bool parameter to indicate whether to use bfloat16. +class ReshapeTest : public ::testing::WithParamInterface, + public ClientLibraryTestBase { public: + ReshapeTest() { set_use_bfloat16(GetParam()); } + ErrorSpec zero_error_spec_{0.0}; }; // Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension. -XLA_TEST_F(ReshapeTest, CollapseTrivial1x1) { +XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0}}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + Array2D input_array(1, 1); + input_array.Fill(1.0f); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, CollapseTrivialR1EmptyDims) { +XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + 
zero_error_spec_); } -XLA_TEST_F(ReshapeTest, CollapseTrivialR1OnlyDim) { +XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{0}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses 2-dimensional pseudo-scalar (single-element array) to scalar. -XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { +XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0}}); - auto reshape = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, /*new_sizes=*/{}); + Array2D input_array(1, 1); + input_array.Fill(1.0f); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + auto reshape = builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); - ComputeAndCompareR0(&builder, 1.0f, {}, zero_error_spec_); + auto expected_literal = Literal::CreateR0(1.0f); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, ScalarToSingleElementArray) { +XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = Literal::CreateR0(1.0f); - std::unique_ptr param0_data = - 
client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - - auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); - a = builder.Neg(a); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + &builder, ¶meter); + auto a = builder.Neg(parameter); auto reshape = builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); - ComputeAndCompareR1(&builder, {-1.0f}, {param0_data.get()}, - zero_error_spec_); + auto expected_literal = Literal::CreateR1({-1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, Trivial0x3) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 3)); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); + Array2D input_array(0, 3); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // TODO(b/29185393): Make this work with the GPU backend. The GPU backend // does not handle zero-sized shapes correctly. Failed last on 2017-05-15 // with an incorrect result rank. 
-XLA_TEST_F(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = Literal::CreateR2FromArray2D(Array2D(0, 3)); - std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - - auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0, 3}), "param0"); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {param0_data.get()}, - zero_error_spec_); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, Trivial3x0) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(3, 0)); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); + Array2D input_array(3, 0); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses a 2-dimensional row vector to 1 dimension. 
-XLA_TEST_F(ReshapeTest, Trivial1x3) { +XLA_TEST_P(ReshapeTest, Trivial1x3) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f, 2.0f, 3.0f}, {}, - zero_error_spec_); + auto input_literal = Literal::CreateR2({{1.0f, 2.0f, 3.0f}}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses a 2-dimensional column vector to 1 dimension. -XLA_TEST_F(ReshapeTest, Trivial3x1) { +XLA_TEST_P(ReshapeTest, Trivial3x1) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0f}, {2.0f}, {3.0f}}); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f, 2.0f, 3.0f}, {}, - zero_error_spec_); + auto input_literal = Literal::CreateR2({{1.0f}, {2.0f}, {3.0f}}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Splits an empty vector into an empty matrix. 
-XLA_TEST_F(ReshapeTest, R1ToR2_0_To_2x0) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({}); - auto result = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); - ComputeAndCompareR2(&builder, Array2D(2, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateR1({}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 0}); + auto expected_literal = Literal::CreateR2({{}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Splits a vector into a matrix. -XLA_TEST_F(ReshapeTest, R1ToR2_6_To_2x3) { +XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - auto result = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); - Array2D expected_2x3({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - ComputeAndCompareR2(&builder, expected_2x3, {}, zero_error_spec_); + auto input_literal = + Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 3}); + auto expected_literal = + Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Transposes a 2x0 array to a 0x2 array. 
-XLA_TEST_F(ReshapeTest, Reshape0x2To2x0) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 0}); - - ComputeAndCompareR2(&builder, Array2D(2, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 0}); + auto expected_literal = Literal::CreateR2({{}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional row vector to a column vector. -XLA_TEST_F(ReshapeTest, ReshapeRowToCol) { +XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { ComputationBuilder builder(client_, TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); - auto a = builder.ConstantR2FromArray2D(*simple); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{3, 1}); + auto input_literal = Literal::CreateFromArray(*simple); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); - ComputeAndCompareR2(&builder, *expected, {}, zero_error_spec_); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional array. 
-XLA_TEST_F(ReshapeTest, TransposeAsReshape) { +XLA_TEST_P(ReshapeTest, TransposeAsReshape) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 4}); - - auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3); - ComputeAndCompareR2(&builder, *expected3x4, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 4}); + + auto expected = ReferenceUtil::TransposeArray2D(*a4x3); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Transposes a 0x4 array with ComputationBuilder::Trans. -XLA_TEST_F(ReshapeTest, Transpose0x4) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 4)); - auto result = builder.Transpose(a, {1, 0}); - - ComputeAndCompareR2(&builder, Array2D(4, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Transpose(parameter, {1, 0}); + auto expected_literal = Literal::CreateR2({{}, {}, {}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional array with ComputationBuilder::Trans. 
-XLA_TEST_F(ReshapeTest, Transpose4x3) { +XLA_TEST_P(ReshapeTest, Transpose4x3) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Transpose(a, {1, 0}); - - auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3); - ComputeAndCompareR2(&builder, *expected3x4, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Transpose(parameter, {1, 0}); + + auto expected = ReferenceUtil::TransposeArray2D(*a4x3); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). 
-XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(6, 0)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 3, 0, 0}); - - ComputeAndCompareR4(&builder, Array4D(2, 3, 0, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 3, 0, 0}); + auto expected_literal = Literal::CreateFromArray(Array4D(2, 3, 0, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, ReshapeR4ToR2ZeroElements) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. 
+XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR4FromArray4D(Array4D(2, 3, 4, 0)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{24, 0}); - - ComputeAndCompareR2(&builder, Array2D(24, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array4D(2, 3, 4, 0)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{24, 0}); + auto expected_literal = Literal::CreateFromArray(Array2D(24, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffle) { +XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 6}); - - auto expected2x6 = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); - ComputeAndCompareR2(&builder, *expected2x6, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 6}); + + auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -// Reshapes a 2-dimensional array with dimensions that are 
not just a -// rearrangement of the originals (split), and reorder the input (shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 6)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 0}); - - ComputeAndCompareR2(&builder, Array2D(3, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 6)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 0}); + auto expected_literal = Literal::CreateFromArray(Array2D(3, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), and reorder the input (shuffle). 
-XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffle) { +XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{2, 6}); - - Array2D expected2x6({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, - {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); - ComputeAndCompareR2(&builder, expected2x6, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{2, 6}); + Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, + {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); + auto expected_literal = Literal::CreateFromArray(expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // The following tests use the same input 3D array; they test the examples we // show for the Reshape operation in the operation_semantics document. // TODO(b/34503277): find a way to show this code in the documentation without // duplication on the TF documentation server. 
-Array3D v_array_for_doc_R3_tests({{{10, 11, 12}, {15, 16, 17}}, - {{20, 21, 22}, {25, 26, 27}}, - {{30, 31, 32}, {35, 36, 37}}, - {{40, 41, 42}, {45, 46, 47}}}); - -XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_012) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{24}); - ComputeAndCompareR1(&builder, - {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, - 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}, - {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{8, 3}); - Array2D expected({{10, 11, 12}, - {15, 16, 17}, - {20, 21, 22}, - {25, 26, 27}, - {30, 31, 32}, - {35, 36, 37}, - {40, 41, 42}, - {45, 46, 47}}); - ComputeAndCompareR2(&builder, expected, {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_120) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{24}); - ComputeAndCompareR1(&builder, - {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, - 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}, - {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{8, 3}); - Array2D expected({{10, 20, 30}, - {40, 11, 21}, - {31, 41, 12}, - {22, 32, 42}, - {15, 25, 35}, - {45, 16, 26}, - {36, 46, 17}, - {27, 37, 47}}); - ComputeAndCompareR2(&builder, expected, {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) 
{ - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{2, 6, 2}); - Array3D expected( +static Array3D ArrayForDocR3Tests() { + return Array3D({{{10, 11, 12}, {15, 16, 17}}, + {{20, 21, 22}, {25, 26, 27}}, + {{30, 31, 32}, {35, 36, 37}}, + {{40, 41, 42}, {45, 46, 47}}}); +} + +XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{24}); + auto expected_literal = Literal::CreateR1( + {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, + 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{8, 3}); + auto expected_literal = Literal::CreateR2({{10, 11, 12}, + {15, 16, 17}, + {20, 21, 22}, + {25, 26, 27}, + {30, 31, 32}, + {35, 36, 37}, + {40, 41, 42}, + {45, 46, 47}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = 
CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{24}); + auto expected_literal = Literal::CreateR1( + {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, + 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{8, 3}); + auto expected_literal = Literal::CreateR2({{10, 20, 30}, + {40, 11, 21}, + {31, 41, 12}, + {22, 32, 42}, + {15, 25, 35}, + {45, 16, 26}, + {36, 46, 17}, + {27, 37, 47}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{2, 6, 2}); + auto expected_literal = Literal::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); - ComputeAndCompareR3(&builder, expected, {}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses the low dimensions of a 4D tensor to get a 2D matrix, without @@ -378,23 +519,26 @@ XLA_TEST_F(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { // Then we collapse Z 
be collapsed so we just end up with planes: // // 1 2 3 4 5 6 1 2 3 4 5 6 -XLA_TEST_F(ReshapeTest, FullyConnectedCollapse) { +XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { ComputationBuilder builder(client_, TestName()); Array4D t2x2x2x3(2, 2, 2, 3); auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3); t2x2x2x3.FillWithYX(*filler2x3); - auto a = builder.ConstantR4FromArray4D(t2x2x2x3); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{1, 2, 3}); - - Array2D expected2x12( + auto input_literal = Literal::CreateFromArray(t2x2x2x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); + auto expected_literal = Literal::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); - ComputeAndCompareR2(&builder, expected2x12, {}, zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // As above, but uses reshape directly. 
-XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { +XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { ComputationBuilder builder(client_, TestName()); Array4D t(2, 1, 2, 2); t(0, 0, 0, 0) = 0; @@ -405,51 +549,67 @@ XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 0, 1) = 5; t(1, 0, 1, 0) = 6; t(1, 0, 1, 1) = 7; - auto a = builder.ConstantR4FromArray4D(t); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{2, 4}); - - Array2D expected({{0, 1, 2, 3}, {4, 5, 6, 7}}); - ComputeAndCompareR2(&builder, expected, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(t); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{2, 4}); + + auto expected_literal = + Literal::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshape various ranks to a scalar. -XLA_TEST_F(ReshapeTest, ToScalar) { +XLA_TEST_P(ReshapeTest, ToScalar) { for (int rank = 0; rank < 8; ++rank) { ComputationBuilder b(client_, TestName()); - auto input = Literal::CreateR1({83.0f}); + auto input_literal = Literal::CreateR1({83.0f}); std::vector ones(rank, 1); // this is {1, ..., 1}. 
std::vector dimensions(rank); std::iota(dimensions.begin(), dimensions.end(), 0); - *input->mutable_shape() = ShapeUtil::MakeShape(F32, ones); - b.Reshape(b.ConstantLiteral(*input), dimensions, {}); + *input_literal->mutable_shape() = ShapeUtil::MakeShape(F32, ones); + + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &b, ¶meter); + b.Reshape(parameter, dimensions, {}); - ComputeAndCompareR0(&b, 83.0f, {}, zero_error_spec_); + auto expected_literal = Literal::CreateR0(83.0f); + ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, + zero_error_spec_); } } -XLA_TEST_F(ReshapeTest, BadDimensions) { +XLA_TEST_P(ReshapeTest, BadDimensions) { ComputationBuilder b(client_, TestName()); - b.Reshape(b.ConstantR1({1}), {}, {}); - EXPECT_THAT(ExecuteToString(&b, {}), - ::testing::HasSubstr("dimensions not a permutation")); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + ¶meter); + b.Reshape(parameter, {}, {}); + EXPECT_THAT( + ExecuteToString(&b, {}), + ::testing::HasSubstr("not a permutation of the operand dimensions")); } -XLA_TEST_F(ReshapeTest, BadNewSizes) { +XLA_TEST_P(ReshapeTest, BadNewSizes) { ComputationBuilder b(client_, TestName()); - b.Reshape(b.ConstantR1({1, 2}), {1}, {}); + auto input_literal = Literal::CreateR1({1.0f, 2.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + ¶meter); + b.Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), ::testing::HasSubstr("mismatched element counts")); } -XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { - const Shape parameter_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); +XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, parameter_shape, 
"a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); - // clang-format off - auto literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ + auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ { { {0, 1}, @@ -473,8 +633,12 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { }, LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on - std::unique_ptr input = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); + Array2D expected_array({ {0, 1, 2, 3, 100, 101, 102, 103}, {222, 333, 444, 555, 666, 777, 888, 999}, @@ -483,72 +647,75 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { Computation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithLayout(F32, {2, 8}, {1, 0}); + ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? 
BF16 : F32, {2, 8}, + {1, 0}); std::unique_ptr actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); std::unique_ptr expected = Literal::CreateR2FromArray2D(expected_array); + if (use_bfloat16()) { + expected = LiteralTestUtil::ConvertF32ToBF16(*expected); + } LiteralTestUtil::ExpectEqual(*expected, *actual); } -XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { - std::unique_ptr input = Literal::CreateR2({ +XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - std::unique_ptr input_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - Array4D expected = { + auto expected_literal = Literal::CreateR4({ {{{0, 1, 2, 3}}, {{4, 5, 6, 7}}}, {{{100, 101, 102, 103}}, {{104, 105, 106, 107}}}, {{{200, 201, 202, 203}}, {{204, 205, 206, 207}}} - }; + }); // clang-format on - ComputeAndCompareR4(&builder, expected, {input_data.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Tests R2->R4 reshape with the reshape dimensions {1, 0}. 
-XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { - std::unique_ptr input = Literal::CreateR2({ +XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - std::unique_ptr input_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - Array4D expected = { + auto expected_literal = Literal::CreateR4({ {{{0, 100, 200, 1}}, {{101, 201, 2, 102}}}, {{{202, 3, 103, 203}}, {{4, 104, 204, 5}}}, {{{105, 205, 6, 106}}, {{206, 7, 107, 207}}} - }; + }); // clang-format on - ComputeAndCompareR4(&builder, expected, {input_data.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { +XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 1, 1); @@ -558,12 +725,10 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, 
/*new_sizes=*/{2, 1}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal); @@ -571,7 +736,8 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { zero_error_spec_); } -XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { +XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 4, 1); @@ -581,12 +747,10 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal); @@ -595,7 +759,8 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { } // Tests R4->R2 reshape with the reshape dimensions {0, 2, 1, 3}. 
-XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { +XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(5, 10, 2, 3); @@ -605,12 +770,11 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, + /*new_sizes=*/{5, 60}); Array2D expected_array(5, 60); input.Each([&](tensorflow::gtl::ArraySlice indices, float* cell) { @@ -618,10 +782,12 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { *cell; }); auto expected = Literal::CreateR2FromArray2D(expected_array); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}); + ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, NoopReshape) { +XLA_TEST_P(ReshapeTest, NoopReshape) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input_array(2, 3, 5, 7); @@ -631,18 +797,17 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto input = 
builder.Parameter(0, input_literal->shape(), "input"); - builder.Reshape(input, /*dimensions=*/{3, 0, 1, 2}, + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); Computation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithLayout(F32, {7, 2, 3, 5}, {2, 3, 0, 1}); + ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5}, + {2, 3, 0, 1}); std::unique_ptr output_literal = client_ ->ExecuteAndTransfer(computation, {input_data.get()}, @@ -651,35 +816,45 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. - EXPECT_EQ(tensorflow::gtl::ArraySlice(input_literal->f32s()), - tensorflow::gtl::ArraySlice(output_literal->f32s())); + if (use_bfloat16()) { + auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal); + EXPECT_EQ(tensorflow::gtl::ArraySlice(expected->bf16s()), + tensorflow::gtl::ArraySlice(output_literal->bf16s())); + } else { + EXPECT_EQ(tensorflow::gtl::ArraySlice(input_literal->f32s()), + tensorflow::gtl::ArraySlice(output_literal->f32s())); + } } -XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) { - auto literal_1x2x3x4 = Literal::CreateR4( +XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { + ComputationBuilder builder(client_, TestName()); + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - ComputationBuilder builder(client_, TestName()); - auto input = builder.ConstantLiteral(*literal_1x2x3x4); - builder.Reshape(input, /*dimensions=*/{0, 1, 2, 3}, + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, 
"input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {}); + ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()}); } -XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { - auto literal_1x2x3x4 = Literal::CreateR4( +XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); ComputationBuilder builder(client_, TestName()); - auto input = builder.ConstantLiteral(*literal_1x2x3x4); - builder.Reshape(input, /*dimensions=*/{1, 3, 2, 0}, + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, /*new_sizes=*/{2, 4, 3, 1}); // clang-format off - auto expected_2x4x3x1 = Literal::CreateR4( + auto expected_2x4x3x1 = Literal::CreateR4( {{{{1}, {5}, {9}}, {{2}, {6}, {10}}, {{3}, {7}, {11}}, @@ -690,10 +865,10 @@ XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { {{16}, {20}, {24}}}}); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {}); + ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()}); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {2, 2, 2, 2}; @@ -705,12 +880,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, 
/*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -722,7 +897,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {1, 1, 250, 300}; @@ -734,12 +909,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -751,7 +926,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {5, 5, 1, 10}; @@ -763,12 +938,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { 
std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -780,7 +955,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::mt19937 rng; std::uniform_real_distribution distribution; // This happens in NN-Builder MNIST. 
@@ -793,12 +968,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -810,7 +985,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {3, 3, 1, 3}; @@ -822,12 +997,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal) @@ -839,5 +1014,12 @@ 
XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { zero_error_spec_, &expected->shape()); } +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, ::testing::Bool()); +#else +INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, + ::testing::ValuesIn(std::vector{false})); +#endif + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/sample_file_test.cc b/tensorflow/compiler/xla/tests/sample_file_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..31b104f4e37f77d47f56ff8183ee1de1cc22e44d --- /dev/null +++ b/tensorflow/compiler/xla/tests/sample_file_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This demonstrates how to use hlo_test_base to create a file based testcase +// and compare results on gpu and cpu. 
+ +#include +#include + +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class SampleFileTest : public HloTestBase { + protected: + SampleFileTest() + : HloTestBase( + /*test_platform=*/PlatformUtil::GetPlatform("gpu").ValueOrDie(), + /*reference_platform=*/PlatformUtil::GetPlatform("cpu") + .ValueOrDie()) {} +}; + +TEST_F(SampleFileTest, Convolution) { + const string& filename = "compiler/xla/tests/isolated_convolution.hlo"; + string test_srcdir = tensorflow::testing::TensorFlowSrcRoot(); + EXPECT_TRUE(RunAndCompareFromFile( + tensorflow::io::JoinPath(test_srcdir, filename), ErrorSpec{0.01})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/sample_text_test.cc b/tensorflow/compiler/xla/tests/sample_text_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4f2b74e3dc9e80f50454b28eb6f2502cef3e681 --- /dev/null +++ b/tensorflow/compiler/xla/tests/sample_text_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This demonstrates how to use hlo_test_base to create textual IR based +// testcases. + +#include +#include + +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class SampleTextTest : public HloTestBase {}; + +TEST_F(SampleTextTest, Axpy) { + const string& hlo_string = R"( +HloModule axpy_module: +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0.0001})); +} + +TEST_F(SampleTextTest, Tuple) { + const string& hlo_string = R"( +HloModule TupleCreate_module: +ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { + %v1 = f32[] parameter(0) + %v2 = f32[3]{0} parameter(1) + %v3 = f32[2,3]{1,0} parameter(2) + ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3) +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, nullopt)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index c21124750ad512cad69b1483e708613ee2857ac0..4db566f7841829359ea06fe25408048418c547ad 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ 
b/tensorflow/compiler/xla/tests/slice_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -211,6 +212,13 @@ class SliceR1Test : public ClientLibraryTestBase, } }; +string SliceR1TestDataToString(const ::testing::TestParamInfo& data) { + const R1Spec& spec = data.param; + return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0, + spec.slice_start, spec.slice_limit, + spec.slice_stride); +} + XLA_TEST_P(SliceR1Test, DoIt_F32) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_F64) { Run(GetParam()); } @@ -223,30 +231,66 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_S64) { Run(GetParam()); } -INSTANTIATE_TEST_CASE_P( // - SliceR1TestInstantiation, // - SliceR1Test, // - ::testing::Values( // - R1Spec{10, 0, 0, 1}, // - R1Spec{10, 7, 7, 1}, // - R1Spec{10, 2, 4, 1}, // - R1Spec{10, 2, 4, 2}, // - R1Spec{10, 0, 10, 1}, // - R1Spec{1024, 1024 - 4, 1024, 1}, // - R1Spec{4096, 7, 7 + 1024, 1}, // - R1Spec{10, 0, 10, 2}, // - R1Spec{10, 0, 10, 3}, // - R1Spec{10, 0, 10, 4}, // - R1Spec{10, 0, 10, 5}, // - R1Spec{10, 0, 10, 10}, // - R1Spec{500, 200, 400, 7}, // - R1Spec{4096, 1, 4095, 3}, // - R1Spec{2047, 1024 - 24, 1024 + 160, 31}, // - R1Spec{2047, 1, 2046, 3 * 128}, // - R1Spec{4096, 1024 + 3, 4095, 500}, // - R1Spec{8192, 0, 8192, 1024 * 3 + 400} // - ) // +// Tests for R1 slice ops. +// The format for each testcase is {input size, start, limit, stride}. 
+// clang-format off +INSTANTIATE_TEST_CASE_P( + SliceR1TestInstantiation, + SliceR1Test, + ::testing::Values( + R1Spec{10, 0, 0, 1}, + R1Spec{10, 7, 7, 1}, + R1Spec{10, 0, 5, 1}, + R1Spec{10, 3, 5, 1}, + R1Spec{10, 0, 10, 1}, + R1Spec{1024, 0, 5, 1}, + R1Spec{1024, 3, 5, 1}, + R1Spec{1024 + 17, 0, 5, 1}, + R1Spec{1024 + 17, 3, 5, 1}, + R1Spec{1024 + 17, 1024, 1024 + 6, 1}, + R1Spec{1024 + 17, 1024 + 1, 1024 + 6, 1}, + R1Spec{1024, 1024 - 4, 1024, 1}, + R1Spec{4 * 1024, 7, 7 + 1024, 1}, + R1Spec{4 * 1024, 0, 4 * 1024, 1}, + R1Spec{4 * 1024, 1, 4 * 1024 - 1, 1}, + R1Spec{4 * 1024, 1024, 3 * 1024, 1}, + R1Spec{4 * 1024, 1024 + 1, 3 * 1024 - 1, 1}, + R1Spec{16 * 1024, 0, 5, 1}, + R1Spec{16 * 1024, 3, 5, 1}, + R1Spec{16 * 1024 + 17, 0, 5, 1}, + R1Spec{16 * 1024 + 17, 3, 5, 1}, + R1Spec{16 * 1024 + 17, 16 * 1024, 16 * 1024 + 6, 1}, + R1Spec{16 * 1024 + 17, 16 * 1024 + 1, 16 * 1024 + 6, 1}, + R1Spec{16 * 1024, 4 * 1024 - 17, 8 * 1024 - 18, 1}, + R1Spec{64 * 1024, 0, 64 * 1024, 1}, + R1Spec{64 * 1024, 1, 64 * 1024 - 1, 1}, + R1Spec{64 * 1024, 1024, 63 * 1024, 1}, + R1Spec{64 * 1024, 1024 + 1, 63 * 1024 - 1, 1}, + R1Spec{64 * 1024, 32 * 1024, 33 * 1024, 1}, + R1Spec{64 * 1024, 32 * 1024 + 1, 33 * 1024 - 1, 1}, + R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1}, +// TODO(b/69425338): This uses too much memory on GPU. 
+#ifndef XLA_TEST_BACKEND_GPU + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}, +#endif + R1Spec{10, 2, 4, 2}, + R1Spec{10, 0, 10, 2}, + R1Spec{10, 0, 10, 3}, + R1Spec{10, 0, 10, 4}, + R1Spec{10, 0, 10, 5}, + R1Spec{10, 0, 10, 10}, + R1Spec{500, 200, 400, 7}, + R1Spec{4096, 1, 4095, 3}, + R1Spec{2047, 1024 - 24, 1024 + 160, 31}, + R1Spec{2047, 1, 2046, 3 * 128}, + R1Spec{4096, 1024 + 3, 4095, 500}, + R1Spec{8192, 0, 8192, 1024 * 3 + 400} + ), + SliceR1TestDataToString ); +// clang-format on struct R2Spec { int64 input_dim0; diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index 173fb1b0008c9e6edaa1902a5eb3ca5f054a2a67..978a669bcab720bddec5c4bcd0144810ba3c8477 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -21,12 +21,13 @@ limitations under the License. #include #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/regexp.h" namespace xla { namespace { // Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is -// disabled. +// disabled - a sequence of regexps. using ManifestT = std::unordered_map>; ManifestT ReadManifest() { @@ -66,9 +67,6 @@ ManifestT ReadManifest() { string PrependDisabledIfIndicated(const string& test_case_name, const string& test_name) { - // TODO(leary): this code reads the manifest for every test case instantiated - // in every file. Consider switching to a singleton or using a compile-time - // genrule instead. ManifestT manifest = ReadManifest(); // First try full match: test_case_name.test_name @@ -83,11 +81,13 @@ string PrependDisabledIfIndicated(const string& test_case_name, } } + // Expect a full match vs. one of the platform regexps to disable the test. 
const std::vector& disabled_platforms = it->second; string platform_string = XLA_PLATFORM; - if (std::find(disabled_platforms.begin(), disabled_platforms.end(), - platform_string) != disabled_platforms.end()) { - return "DISABLED_" + test_name; + for (const auto& s : disabled_platforms) { + if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) { + return "DISABLED_" + test_name; + } } // We didn't hit in the disabled manifest entries, so don't disable it. diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 3878ac1013ef1459cbe3c92a48fc6149b6a4948e..28a2d0198a707cec1aa5e0fbed341ee9b2a927f7 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -66,8 +66,10 @@ limitations under the License. namespace xla { -// Reads a disabled manifest file (and retains it as a singleton) to resolve -// whether test cases should be disabled on a particular platform. +// Reads a disabled manifest file to resolve whether test cases should be +// disabled on a particular platform. For a test that should be disabled, +// returns DISABLED_ prepended to its name; otherwise returns the test name +// unmodified. 
string PrependDisabledIfIndicated(const string& test_case_name, const string& test_name); @@ -96,7 +98,8 @@ string PrependDisabledIfIndicated(const string& test_case_name, test_name)::test_info_ = \ ::testing::internal::MakeAndRegisterTestInfo( \ #test_case_name, \ - PrependDisabledIfIndicated(#test_case_name, #test_name).c_str(), \ + ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ + .c_str(), \ nullptr, nullptr, \ ::testing::internal::CodeLocation(__FILE__, __LINE__), (parent_id), \ parent_class::SetUpTestCase, parent_class::TearDownTestCase, \ @@ -135,7 +138,8 @@ string PrependDisabledIfIndicated(const string& test_case_name, ::testing::internal::CodeLocation(__FILE__, __LINE__)) \ ->AddTestPattern( \ #test_case_name, \ - PrependDisabledIfIndicated(#test_case_name, #test_name).c_str(), \ + ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ + .c_str(), \ new ::testing::internal::TestMetaFactory()); \ return 0; \ diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..56859542a94d82aeb783c06b9c4eecf2bde5bade --- /dev/null +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -0,0 +1,277 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" + +namespace xla { + +namespace { + +template +void PopulateWithRandomFloatingPointData(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + std::minstd_rand0 engine; + std::uniform_real_distribution generator(0.0f, 1.0f); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(engine); + })); +} + +// The standard library does not have a case for bfloat16, unsurprisingly, so we +// handle that one specially. +template <> +void PopulateWithRandomFloatingPointData(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), BF16); + std::minstd_rand0 engine; + std::uniform_real_distribution generator(0.0f, 1.0f); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return static_cast(generator(engine)); + })); +} + +template +void PopulateWithRandomIntegralData(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + std::minstd_rand0 engine; + std::uniform_int_distribution generator( + std::numeric_limits::lowest(), std::numeric_limits::max()); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(engine); + })); +} + +// Matches binary addition computations. 
+bool LooksLikeSum(const HloComputation& computation) { + const HloInstruction* const root = computation.root_instruction(); + return root->opcode() == HloOpcode::kAdd && + computation.num_parameters() == 2 && + root->operand(0)->opcode() == HloOpcode::kParameter && + root->operand(1)->opcode() == HloOpcode::kParameter && + root->operand(0) != root->operand(1); +} + +// Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition, +// which requires an init_value of 0 rather than a random value. +bool NeedsZeroInitValue(const HloUse& use) { + const HloInstruction* const instruction = use.instruction; + const HloOpcode opcode = instruction->opcode(); + const int64 op_num = use.operand_number; + return ( + ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && + op_num == 1 && LooksLikeSum(*instruction->to_apply())) || + (opcode == HloOpcode::kSelectAndScatter && op_num == 2 && + LooksLikeSum(*instruction->scatter()))); +} + +// Generate random values that are constrained to the input_shape minus the +// output_shape so as not to produce wrapping slices, for instance. +std::unique_ptr MakeRandomNonwrappingSliceIndex( + const Shape& input_shape, const Shape& slice_shape) { + const int64 rank = ShapeUtil::Rank(input_shape); + std::vector start_indices(rank); + std::minstd_rand0 engine; + for (int i = 0; i < rank; ++i) { + const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); + std::uniform_int_distribution generator(0, upper_bound); + start_indices[i] = generator(engine); + } + return Literal::CreateR1(start_indices); +} + +// Use dataflow analysis on each parameter to see if there are uses that would +// be problematic when generating input data. Returns the list of instructions +// that correspond to their uses. +// +// Should be paired with the CreateLiteralForConstrainedUses() function below. 
+std::vector FindConstrainedUses( + const HloDataflowAnalysis& dataflow, const HloInstruction& param) { + std::vector constrained_uses; + for (const auto& pair : dataflow.GetInstructionValueSet(&param)) { + const HloValue& value = dataflow.GetUniqueValueAt(&param, pair.first); + for (const HloUse& use : value.uses()) { + HloInstruction* instruction = use.instruction; + const HloOpcode opcode = instruction->opcode(); + const int64 op_num = use.operand_number; + if ((opcode == HloOpcode::kDynamicSlice && op_num == 1) || + (opcode == HloOpcode::kDynamicUpdateSlice && op_num == 2)) { + constrained_uses.push_back(instruction); + } else if (opcode == HloOpcode::kFusion) { + const HloInstruction* const to_analyze = + instruction->fused_parameter(op_num); + auto fused_uses = FindConstrainedUses(dataflow, *to_analyze); + constrained_uses.insert(constrained_uses.end(), fused_uses.begin(), + fused_uses.end()); + } else if (NeedsZeroInitValue(use)) { + constrained_uses.push_back(instruction); + } + } + } + return constrained_uses; +} + +// Given a parameter, generate a random Literal to use as input if there exist +// no constrained uses in the dataflow graph. If such constraints exist, +// generate a constrained literal (either bounded in the case of indices, or +// zero in the case of init_values for reductions). 
+StatusOr> CreateLiteralForConstrainedUses( + const tensorflow::gtl::ArraySlice constrained_uses, + const HloInstruction& param) { + HloInstruction* needs_index = nullptr; + HloInstruction* needs_zero = nullptr; + for (HloInstruction* use : constrained_uses) { + switch (use->opcode()) { + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + TF_RET_CHECK(ShapeUtil::Equal(param.shape(), use->operand(0)->shape())); + if (needs_index != nullptr && + !ShapeUtil::Equal(needs_index->shape(), use->shape())) { + return Unimplemented( + "Conflicting operand generation slice index constraints\n"); + } + needs_index = use; + break; + + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + needs_zero = use; + break; + + default: + return Unimplemented( + "Constrained operand generation not implemented for %s.", + use->ToString().c_str()); + } + } + if (needs_index != nullptr && needs_zero != nullptr) { + return Unimplemented( + "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " + "zero: %s\n", + needs_index->ToString().c_str(), needs_zero->ToString().c_str()); + } + if (needs_index != nullptr) { + return MakeRandomNonwrappingSliceIndex(param.shape(), needs_index->shape()); + } else if (needs_zero != nullptr) { + return Literal::CreateFromShape(param.shape()); + } else { + return MakeFakeLiteral(param.shape()); + } +} + +// Given a module entry parameter, use the dataflow analysis to see if a +// special case literal must be created, or if we can generate fake data. 
+StatusOr> MakeConstrainedArgument( + const HloDataflowAnalysis& dataflow, const HloInstruction& param) { + const auto constrained_uses = FindConstrainedUses(dataflow, param); + return CreateLiteralForConstrainedUses(constrained_uses, param); +} + +} // namespace + +StatusOr> MakeFakeLiteral(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + std::vector> elements; + for (const Shape& element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr element, + MakeFakeLiteral(element_shape)); + elements.push_back(std::move(element)); + } + return Literal::MakeTupleOwned(std::move(elements)); + } + std::unique_ptr literal = Literal::CreateFromShape(shape); + switch (shape.element_type()) { + case BF16: + PopulateWithRandomFloatingPointData(literal.get()); + break; + case F32: + PopulateWithRandomFloatingPointData(literal.get()); + break; + case F64: + PopulateWithRandomFloatingPointData(literal.get()); + break; + case S8: + PopulateWithRandomIntegralData(literal.get()); + break; + case U8: + PopulateWithRandomIntegralData(literal.get()); + break; + case S16: + PopulateWithRandomIntegralData(literal.get()); + break; + case U16: + PopulateWithRandomIntegralData(literal.get()); + break; + case S32: + PopulateWithRandomIntegralData(literal.get()); + break; + case U32: + PopulateWithRandomIntegralData(literal.get()); + break; + case S64: + PopulateWithRandomIntegralData(literal.get()); + break; + case U64: + PopulateWithRandomIntegralData(literal.get()); + break; + case PRED: { + std::uniform_int_distribution generator(0, 1); + std::minstd_rand0 engine; + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(engine); + })); + break; + } + default: + return Unimplemented("Unsupported type for fake literal generation: %s", + ShapeUtil::HumanString(shape).c_str()); + } + return std::move(literal); +} + +StatusOr>> MakeFakeArguments( + HloModule* const module) { + TF_ASSIGN_OR_RETURN(auto dataflow, 
HloDataflowAnalysis::Run(module)); + const auto params = module->entry_computation()->parameter_instructions(); + std::vector> arguments(params.size()); + for (int i = 0; i < params.size(); ++i) { + TF_ASSIGN_OR_RETURN(arguments[i], + MakeConstrainedArgument(*dataflow, *params[i])); + } + return std::move(arguments); +} + +Status VerifyHloModule(const perftools::gputools::Platform& platform, + HloModule* const module) { + return HloVerifier( + std::bind( + &TransferManager::GetByteSizeRequirement, + TransferManager::GetForPlatform(&platform).ConsumeValueOrDie(), + std::placeholders::_1)) + .Run(module) + .status(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index f3a522b05ebae4f1f86d6d7ddbac6e1749d3e286..0fb024ffb074f1c90b75022bc7f5a8b58b03c0c2 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -23,12 +23,13 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/platform.h" namespace xla { -namespace test_utils { // A class which generates pseudorandom numbers of a given type within a given // range. Not cryptographically secure and likely not perfectly evenly @@ -53,63 +54,23 @@ class PseudorandomGenerator { std::mt19937 generator_; }; -// Convenience function for creating a rank-2 array with arbitrary layout. 
-template -std::unique_ptr CreateR2LiteralWithLayout( - std::initializer_list> values, - tensorflow::gtl::ArraySlice minor_to_major) { - auto literal = MakeUnique(); - const int64 d0 = values.size(); - const int64 d1 = values.begin()->size(); - literal.get()->PopulateWithValue(0, {d0, d1}); - *literal->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout(minor_to_major); - TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); +// Generates fake data in a literal of the given shape, or returns an error +// status if the element type is currently unhandled for fake data generation. +StatusOr> MakeFakeLiteral(const Shape& shape); - int64 dim0 = 0; - for (auto inner_list : values) { - int64 dim1 = 0; - for (auto value : inner_list) { - literal.get()->Set({dim0, dim1}, value); - ++dim1; - } - ++dim0; - } - return literal; -} +// Generates a vector of arguments containing fake data. The number, shape and +// layout of the arguments is appropriate for given HLO module. +// +// Will handle special cases such as making sure that indices used for dynamic +// slices are bounded, reduces that call adds use 0 as an init value, etc. +StatusOr>> MakeFakeArguments( + HloModule* const module); -// Convenience function for creating a rank-3 array with arbitrary layout. 
-template -std::unique_ptr CreateR3LiteralWithLayout( - std::initializer_list>> - values, - tensorflow::gtl::ArraySlice minor_to_major) { - auto literal = MakeUnique(); - const int64 d0 = values.size(); - const int64 d1 = values.begin()->size(); - const int64 d2 = values.begin()->begin()->size(); - literal.get()->PopulateWithValue(0, {d0, d1, d2}); - *literal->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout(minor_to_major); - TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); - - int64 dim0 = 0; - for (auto inner_list : values) { - int64 dim1 = 0; - for (auto inner_inner_list : inner_list) { - int64 dim2 = 0; - for (auto value : inner_inner_list) { - literal.get()->Set({dim0, dim1, dim2}, value); - ++dim2; - } - ++dim1; - } - ++dim0; - } - return literal; -} +// Check that a given module satisfies various constraints before trying to +// execute it. +Status VerifyHloModule(const perftools::gputools::Platform& platform, + HloModule* const module); -} // namespace test_utils } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f2a64749482e5f5a8c5d72034fb7a4eee07baf48 --- /dev/null +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -0,0 +1,215 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/generic_transfer_manager.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/local_client_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class TransferManagerTest : public LocalClientTestBase { + protected: + TransferManagerTest() + : shape_size_fn_([this](const Shape& shape) { + return transfer_manager_->GetByteSizeRequirement(shape); + }) {} + + ~TransferManagerTest() override = default; + + std::unique_ptr AllocateDeviceBuffer(const Shape& shape) { + return ScopedShapedBuffer::Allocate( + shape, GetOrCreateAllocator(local_client_->platform()), + /*device_ordinal=*/0, shape_size_fn_) + .ValueOrDie(); + } + + private: + std::function shape_size_fn_; +}; + +XLA_TEST_F(TransferManagerTest, TransferR0U32) { + std::unique_ptr literal = Literal::CreateR0(42); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. 
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectR0Equal(42, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferR1F32) { + std::unique_ptr literal = + Literal::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, + *result); +} + +XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { + std::vector test_vector(1024 * 1024); + std::iota(test_vector.begin(), test_vector.end(), 0); + std::unique_ptr literal = Literal::CreateR1(test_vector); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectR1Equal(test_vector, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferR1U8) { + const char* test_string = "0123456789abcdef"; + std::unique_ptr literal = Literal::CreateR1U8(test_string); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. 
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + EXPECT_EQ(result->u8s_string(), test_string); +} + +XLA_TEST_F(TransferManagerTest, TransferR2F32) { + std::unique_ptr literal = + Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); +} + +XLA_TEST_F(TransferManagerTest, + TransferR2F32AndChangeLayoutTransferringToDevice) { + std::unique_ptr literal = Literal::CreateR2WithLayout( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1})); + const Shape ondevice_shape = + ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); + auto device_buffer = AllocateDeviceBuffer(ondevice_shape); + + // Round trip literal through device. Set the on-device layout to something + // different than the literal layout. 
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + EXPECT_FALSE( + LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferTuple) { + std::unique_ptr literal = Literal::MakeTuple( + {Literal::CreateR0(123.0f).get(), + Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { + std::unique_ptr literal = Literal::MakeTuple({}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. 
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { + std::unique_ptr literal = Literal::MakeTuple( + {Literal::CreateR0(123.0f).get(), + Literal::MakeTuple( + {Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) + .get(), + Literal::CreateR1({-10.0f, 123.0f}).get()}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 4920f17a7ed21d587c15b8deac550d5e5bb566c9..65489cfff19c8fecbdead8a7e295bf9cca56038f 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -180,7 +180,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } -XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { +// TODO(b/68395210): GPU does not tolerate ambiguous top-level buffers. 
+XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenPredTuples)) { ComputationBuilder b(client_, TestName()); ComputationDataHandle v1, v2; @@ -444,5 +445,61 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { ComputeAndCompareR1(&builder, expected, arguments, ErrorSpec(1e-5)); } +XLA_TEST_F(TupleTest, ComplexTuples) { + ComputationBuilder builder(client_, TestName()); + { + Shape c64r0 = ShapeUtil::MakeShape(C64, {}); + Shape c64r1 = ShapeUtil::MakeShape(C64, {2}); + Shape c64r2 = ShapeUtil::MakeShape(C64, {3, 2}); + Shape arg0_shape = ShapeUtil::MakeTupleShape( + {c64r0, ShapeUtil::MakeTupleShape({c64r1, c64r2})}); + auto input0 = builder.Parameter(0, arg0_shape, "input0"); + auto t0 = builder.GetTupleElement(input0, 0); + auto t1 = builder.GetTupleElement(input0, 1); + auto t10 = builder.GetTupleElement(t1, 0); + auto t11 = builder.GetTupleElement(t1, 1); + auto sum = builder.Add(builder.Add(t10, t11, {1}), t0); + auto input1 = builder.Parameter(1, c64r1, "input1"); + auto prod = builder.Mul(input1, sum, {1}); + builder.Tuple({builder.Tuple({prod, sum}), + builder.ConstantR0({123, 456})}); + } + + std::unique_ptr arg0 = + client_ + ->TransferToServer(*Literal::MakeTuple( + {Literal::CreateR0({1, 2}).get(), + Literal::MakeTuple( + {Literal::CreateR1({{10, 20}, {30, 40}}).get(), + Literal::CreateR2( + {{{100, 200}, {300, 400}}, + {{1000, 2000}, {3000, 4000}}, + {{10000, 20000}, {30000, 40000}}}) + .get()}) + .get()})) + .ConsumeValueOrDie(); + std::unique_ptr arg1 = + client_ + ->TransferToServer(*Literal::CreateR1({{1, 2}, {1, -2}})) + .ConsumeValueOrDie(); + auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, + {{1011, 2022}, {3031, 4042}}, + {{10011, 20022}, {30031, 40042}}}); + auto prod = Literal::CreateFromShape(sum->shape()); + ASSERT_TRUE(prod->Populate( + [&sum](tensorflow::gtl::ArraySlice indexes) { + return sum->Get(indexes) * + (indexes[indexes.size() - 1] == 0 + ? 
complex64(1, 2) + : complex64(1, -2)); + }) + .ok()); + auto expected = + Literal::MakeTuple({Literal::MakeTuple({prod.get(), sum.get()}).get(), + Literal::CreateR0({123, 456}).get()}); + ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, + error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 49f673f5f0bf9b844ab4030383784208b4e2c58a..0b3430ee1ee515c2c98c64a947b7a7021c04f22b 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -357,8 +357,7 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } -// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. -TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) { +TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -411,8 +410,7 @@ TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } -// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. -TEST_F(WhileTest, DISABLED_WhileWithPermutationAndVectorResult) { +TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -913,8 +911,7 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { } } -// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11. 
-TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithTupleElement)) { +TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); ComputationBuilder outer(client_, "outer"); @@ -950,8 +947,7 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithTupleElement)) { ErrorSpec(1e-6)); } -// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11. -TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithBroadcast)) { +TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); ComputationBuilder outer(client_, "outer"); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 759921dce5acf3cd23a121776f3ab0731c9bb623..091fa0c3ec807a66449eca0bfbb141285b8eb532 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -88,6 +88,7 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD index ce936af6c3376387c1ed9fa48da23b8af537f6e5..97aacf6b39f83978e732060817cd93ede81ca782 100644 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -34,9 +34,9 @@ cc_library( deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", ], diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md index 
2c864d77a20207bab7c72b207b31c9b886441e9b..6232967f5f04cbf316d985357ae84c28335531e2 100644 --- a/tensorflow/compiler/xla/tools/parser/README.md +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -43,14 +43,22 @@ operand : shape name ; -extra_attributes +attributes : /*empty*/ - | ',' extra_attribute - | ',' extra_attribute extra_attributes + | ',' attribute + | ',' attribute attributes ; -extra_attribute +attribute : attribute_name attribute_value ; +attribute_value + : kInt + | kName + | [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} /*dim_labels_pattern*/ + | [0-9]+(x[0-9]+)+ /*dxd_pattern*/ + | [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* /*pad_pattern*/ + | '{' sub_attributes '}' + ; param_list : '(' param_list1 ')' diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc index d104ff34601216bbaf5d5c068e00a7191a9b3b17..459d511e90d87537f3a3404b82df7b28b1fe08bd 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" namespace xla { @@ -122,7 +123,7 @@ TokKind HloLexer::LexToken() { current_ptr_++; return TokKind::kArrow; } - return LexDigitOrNegative(); + return LexNumberOrPattern(); case '=': return TokKind::kEqual; case ',': @@ -145,16 +146,21 @@ TokKind HloLexer::LexToken() { return TokKind::kRparen; case '/': return LexComment(); + case '"': + return LexString(); } } } -// Lex a shape, name, keyword, or opcode. +// Lex a shape, name, keyword, attribute name, the dim labels pattern, and +// other identifiers. +// // shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: // keyword ::= HloModule, ENTRY, ... 
-// opcode ::= add, greater-than, ... // attribute_name ::= condition, body, dimensions, ... +// dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} +// identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexIdentifier() { { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); @@ -213,15 +219,19 @@ TokKind HloLexer::LexIdentifier() { #undef KEYWORD - // See if this is an opcode. - auto opcode = StringToHloOpcode(identifier.ToString()); - if (opcode.ok()) { - opcode_val_ = opcode.ValueOrDie(); - return TokKind::kOpcode; + { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 dim_labels_pattern = { + R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; + if (RE2::Consume(&consumable, *dim_labels_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kDimLabels; + } } - current_ptr_ = token_start_ + 1; - return TokKind::kError; + str_val_ = identifier.ToString(); + return TokKind::kIdent; } // Lex names after a % character. @@ -240,15 +250,20 @@ TokKind HloLexer::LexPercent() { return TokKind::kError; } -// Lex integer and floating-point values, and -inf. -// int [-]?[0-9]+ -// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) -// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) -// negative inf -inf -TokKind HloLexer::LexDigitOrNegative() { +// Lex integer and floating-point values, -inf, and patterns for dim labels, +// dxd (e.g. 1x2x3), and pad. 
+// +// fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) +// fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) +// dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} +// dxd_pattern ::= [0-9]+(x[0-9]+)+ +// pad_pattern ::= [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* +// int ::= [-]?[0-9]+ +// negative inf ::= '-inf' +TokKind HloLexer::LexNumberOrPattern() { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); static LazyRE2 float_pattern = { - R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"}; + R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), @@ -256,6 +271,30 @@ TokKind HloLexer::LexDigitOrNegative() { return TokKind::kDecimal; } + static LazyRE2 dim_labels_pattern = { + R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; + static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"}; + static LazyRE2 pad_pattern = { + R"([0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*)"}; + + if (RE2::Consume(&consumable, *dim_labels_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kDimLabels; + } + + if (RE2::Consume(&consumable, *dxd_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kDxD; + } + + if (RE2::Consume(&consumable, *pad_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kPad; + } + static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); @@ -273,18 +312,43 @@ TokKind HloLexer::LexDigitOrNegative() { return TokKind::kError; } -StringPiece HloLexer::GetCurrentLine() const { - const char* start = token_start_; - const char* end = 
current_ptr_; - if (!CanDereference(start) || !CanDereference(end)) { - return "LINE OUT OF RANGE"; +std::pair HloLexer::GetLineAndColumn(LocTy location) const { + unsigned line_no = 1; + const char* start = buf_.begin(); + const char* ptr = start; + if (line_no_cache_.last_query && CanDereference(line_no_cache_.last_query) && + line_no_cache_.last_query <= location) { + ptr = line_no_cache_.last_query; + line_no = line_no_cache_.line_no_of_query; + } + for (; ptr != location; ptr++) { + if (*ptr == '\n') { + line_no++; + } } - while (start > buf_.begin() && *start != '\n') { - start--; + + // Update the line number cache. + line_no_cache_.last_query = ptr; + line_no_cache_.line_no_of_query = line_no; + size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); + if (line_offset == StringPiece::npos) { + line_offset = 0; } - while (end < buf_.end() && *end != '\n') { - end++; + return {line_no, ptr - start - line_offset}; +} + +StringPiece HloLexer::GetLine(LocTy loc) const { + if (!CanDereference(loc)) { + return "LINE OUT OF RANGE"; } + size_t line_start = + StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); + const char* start = line_start == StringPiece::npos + ? buf_.begin() + : buf_.begin() + line_start + 1; + size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); + const char* end = line_end == StringPiece::npos ? buf_.end() : loc + line_end; + return StringPieceFromPointers(start, end); } @@ -298,6 +362,25 @@ TokKind HloLexer::LexComment() { return TokKind::kError; } +// Lexes quoted string with escaping characters. If matched, the quoted string +// will be unescaped and stored to str_val_. 
+TokKind HloLexer::LexString() { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; + if (RE2::Consume(&consumable, *escaping_pattern)) { + current_ptr_ = consumable.begin(); + StringPiece raw = + StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); + string error; + if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { + LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error; + return TokKind::kError; + } + return TokKind::kString; + } + return TokKind::kError; +} + string TokKindToString(TokKind kind) { switch (kind) { case TokKind::kEof: @@ -350,10 +433,18 @@ string TokKindToString(TokKind kind) { return "kName"; case TokKind::kAttributeName: return "kAttributeName"; + case TokKind::kDimLabels: + return "kDimLabels"; + case TokKind::kDxD: + return "kDxD"; + case TokKind::kPad: + return "kPad"; + case TokKind::kIdent: + return "kIdent"; + case TokKind::kString: + return "kString"; case TokKind::kShape: return "kShape"; - case TokKind::kOpcode: - return "kOpcode"; case TokKind::kInt: return "kInt"; case TokKind::kDecimal: diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h index 3b9efcb92d074a234868a12b8f4dc5db867ea1ec..27880b9b8afbfa58abfedc3b2cecd5236b78a6d6 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -18,8 +18,8 @@ limitations under the License. 
#include -#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" @@ -37,11 +37,17 @@ class HloLexer { } TokKind Lex() { return current_kind_ = LexToken(); } + TokKind GetKind() const { return current_kind_; } string GetStrVal() const { switch (GetKind()) { case TokKind::kName: case TokKind::kAttributeName: + case TokKind::kDimLabels: + case TokKind::kDxD: + case TokKind::kPad: + case TokKind::kString: + case TokKind::kIdent: return str_val_; default: LOG(FATAL) << "This token does not have string value"; @@ -51,10 +57,6 @@ class HloLexer { CHECK(GetKind() == TokKind::kShape); return shape_val_; } - HloOpcode GetOpcodeVal() const { - CHECK(GetKind() == TokKind::kOpcode); - return opcode_val_; - } int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; @@ -64,8 +66,16 @@ class HloLexer { return decimal_val_; } - // Returns the line of text that is currently being lexed. - tensorflow::StringPiece GetCurrentLine() const; + typedef const char* LocTy; + + // Returns the location of the current token. + LocTy GetLoc() const { return token_start_; } + + // Returns the line and column of a location in the buffer. + std::pair GetLineAndColumn(LocTy location) const; + + // Returns the whole line given the location. + tensorflow::StringPiece GetLine(LocTy loc) const; private: // Returns the current character. 
If it's neither the end of input buffer nor @@ -92,8 +102,9 @@ class HloLexer { TokKind LexPercent(); TokKind LexShape(); TokKind LexConstant(); - TokKind LexDigitOrNegative(); + TokKind LexNumberOrPattern(); TokKind LexComment(); + TokKind LexString(); const tensorflow::StringPiece buf_; const char* current_ptr_; @@ -103,9 +114,15 @@ class HloLexer { TokKind current_kind_; string str_val_; Shape shape_val_; - HloOpcode opcode_val_; int64 int64_val_; double decimal_val_; + + struct LineNoCacheTy { + const char* last_query; + unsigned line_no_of_query; + }; + // This caches the line number of the previous query. + mutable LineNoCacheTy line_no_cache_{nullptr, 0}; }; } // namespace tools diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 6c2e37e3b5cdd73157279fb171d3332aa9854184..4f67ed23801f9b8eb50b7c959f0796ca4e6c578d 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -28,6 +29,9 @@ namespace tools { namespace { using tensorflow::StringPiece; +using tensorflow::gtl::optional; +using tensorflow::str_util::Split; +using tensorflow::str_util::SplitAndParseAsInts; using tensorflow::strings::Printf; using tensorflow::strings::StrAppend; using tensorflow::strings::StrCat; @@ -37,6 +41,8 @@ const double kF16max = 65504; // Parser for the HloModule::ToString() format text. 
class HloParser { public: + using LocTy = HloLexer::LocTy; + explicit HloParser(StringPiece str, const HloModuleConfig& config) : lexer_(str), config_(config) {} @@ -57,7 +63,6 @@ class HloParser { bool ParseInstructionList(HloComputation::Builder* builder, string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); - bool ParseSharding(HloInstruction* instruction); bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& shape); @@ -78,16 +83,102 @@ class HloParser { bool ParseOperands(std::vector* operands, const int expected_size); - template - bool ParseExtraAttribute(T* value, const string& expected_attribute); - template - bool ParseAttributeValue(T* value); + // Describes the start, limit, and stride on every dimension of the operand + // being sliced. + struct SliceRanges { + std::vector starts; + std::vector limits; + std::vector strides; + }; + + // Types of attributes. + enum class AttrTy { + kInt64, + kInt32, + kFloat, + kString, + kBracedInt64List, + kHloComputation, + kWindow, + kConvolutionDimensionNumbers, + kSharding, + kInstructionList, + kSliceRanges, + kPaddingConfig, + kMetadata, + kFusionKind, + kDistribution, + }; + + struct AttrConfig { + bool required; // whether it's required or optional + AttrTy attr_type; // what type it is + void* result; // where to store the parsed result. + }; + + // attributes ::= (',' attribute)* + // + // Parses attributes given names and configs of the attributes. Each parsed + // result is passed back through the result pointer in corresponding + // AttrConfig. Note that the result pointer must point to an optional typed + // variable which outlives this function. Returns false on error. You should + // not use any of the results if this function failed.
+ // + // Example usage: + // + // std::unordered_map attrs; + // optional foo; + // attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo}; + // optional bar; + // attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar}; + // if (!ParseAttributes(attrs)) { + // return false; // Do not use 'foo' 'bar' if failed. + // } + // // Do something with 'bar'. + // if (foo) { // If attr foo is seen, do something with 'foo'. } + // + bool ParseAttributes(const std::unordered_map& attrs); + + // sub_attributes ::= '{' (','? attribute)* '}' + // + // Usage is the same as ParseAttributes. See immediately above. + bool ParseSubAttributes(const std::unordered_map& attrs); + + // Parses one attribute. If it has already been seen, return error. Returns + // true and adds to seen_attrs on success. + // + // Do not call this except in ParseAttributes or ParseSubAttributes. + bool ParseAttributeHelper(const std::unordered_map& attrs, + std::unordered_set* seen_attrs); + + // Parses a name and finds the corresponding hlo computation. + bool ParseComputationName(HloComputation** value); + // Parses a list of names and finds the corresponding hlo instructions. + bool ParseInstructionNames(std::vector* instructions); + bool ParseWindow(Window* window); + bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); + bool ParsePaddingConfig(PaddingConfig* padding); + bool ParseMetadata(OpMetadata* metadata); + bool ParseSharding(OpSharding* sharding); + bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + + // Parses a sub-attribute of the window attribute, e.g., size=1x2x3. + bool ParseDxD(const string& name, std::vector* result); + // Parses window's pad sub-attribute, e.g., pad=0_0x3x3.
+ bool ParseWindowPad(std::vector>* pad); + + bool ParseSliceRanges(SliceRanges* result); + bool ParseInt64List(const TokKind start, const TokKind end, + const TokKind delim, std::vector* result); bool ParseParamList(); bool ParseName(string* result); bool ParseAttributeName(string* result); + bool ParseString(string* result); bool ParseShape(Shape* result); bool ParseOpcode(HloOpcode* result); + bool ParseFusionKind(HloInstruction::FusionKind* result); + bool ParseRandomDistribution(RandomDistribution* result); bool ParseInt64(int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); @@ -95,6 +186,7 @@ class HloParser { // Logs the current parsing line and the given message. Always returns false. bool TokenError(StringPiece msg); + bool Error(LocTy loc, StringPiece msg); // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. @@ -105,10 +197,12 @@ class HloParser { // Adds the instruction to the pool. Returns false and emits an error if the // instruction already exists. - bool AddInstruction(const string& name, HloInstruction* instruction); + bool AddInstruction(const string& name, HloInstruction* instruction, + LocTy name_loc); // Adds the computation to the pool. Returns false and emits an error if the // computation already exists. - bool AddComputation(const string& name, HloComputation* computation); + bool AddComputation(const string& name, HloComputation* computation, + LocTy name_loc); // The map from the instruction name to the instruction. This does not own the // instructions. 
@@ -121,15 +215,25 @@ class HloParser { std::vector error_; }; -bool HloParser::TokenError(StringPiece msg) { - const string error = - StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; token ", - TokKindToString(lexer_.GetKind()), "; ", msg); - VLOG(1) << "TokenError: " << error; - error_.push_back(error); +bool HloParser::Error(LocTy loc, StringPiece msg) { + auto line_col = lexer_.GetLineAndColumn(loc); + const unsigned line = line_col.first; + const unsigned col = line_col.second; + std::vector error_lines; + error_lines.push_back( + StrCat("was parsing ", line, ":", col, ": error: ", msg)); + error_lines.push_back(lexer_.GetLine(loc).ToString()); + error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); + + error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); + VLOG(1) << "Error: " << error_.back(); return false; } +bool HloParser::TokenError(StringPiece msg) { + return Error(lexer_.GetLoc(), msg); +} + bool HloParser::Run() { lexer_.Lex(); return ParseHloModule(); @@ -167,6 +271,7 @@ bool HloParser::ParseComputations() { bool HloParser::ParseComputation() { const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY); string name; + LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name)) { return false; } @@ -187,6 +292,7 @@ bool HloParser::ParseComputation() { LOG(FATAL) << "instruction " << root_name << " was marked as ROOT but the parser has not seen it before"; } + // Now root can be either an existing instruction or a nullptr. If it's a // nullptr, the implementation of Builder will set the last instruction as // root instruction. @@ -194,7 +300,26 @@ bool HloParser::ParseComputation() { is_entry_computation ? module_->AddEntryComputation(builder->Build(root)) : module_->AddEmbeddedComputation(builder->Build(root)); - return AddComputation(name, computation); + + // The parameters and result layouts were set to default layout. Here we set + // the layouts to what the hlo text says. 
+ if (is_entry_computation) { + for (int i = 0; i < computation->num_parameters(); i++) { + const Shape& param_shape = computation->parameter_instruction(i)->shape(); + if (param_shape.has_layout()) { + module_->mutable_entry_computation_layout() + ->mutable_parameter_layout(i) + ->ResetLayout(param_shape.layout()); + } + } + const Shape& result_shape = computation->root_instruction()->shape(); + if (result_shape.has_layout()) { + module_->mutable_entry_computation_layout() + ->mutable_result_layout() + ->ResetLayout(result_shape.layout()); + } + } + return AddComputation(name, computation, name_loc); } // instruction_list ::= '{' instruction_list1 '}' @@ -214,7 +339,7 @@ bool HloParser::ParseInstructionList(HloComputation::Builder* builder, "expects '}' at the end of instruction list."); } -// instruction ::= ('ROOT')? name '=' shape opcode operands (extra_attribute)* +// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)* bool HloParser::ParseInstruction(HloComputation::Builder* builder, string* root_name) { string name; @@ -222,6 +347,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloOpcode opcode; std::vector operands; bool is_root = EatIfPresent(TokKind::kw_ROOT); + + const LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name) || !ParseToken(TokKind::kEqual, "expects '=' in instruction") || !ParseShape(&shape) || !ParseOpcode(&opcode)) { @@ -230,6 +357,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (is_root) { *root_name = name; } + + // Add optional attributes. 
+ std::unordered_map attrs; + optional sharding; + attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; + optional> predecessors; + attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList, + &predecessors}; + optional metadata; + attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -237,7 +375,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || !ParseInt64(&parameter_number) || - !ParseToken(TokKind::kRparen, "expects ')' after parameter number")) { + !ParseToken(TokKind::kRparen, "expects ')' after parameter number") || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -249,7 +388,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!ParseToken(TokKind::kLparen, "expects '(' before constant literal") || !ParseLiteral(&literal, shape) || - !ParseToken(TokKind::kRparen, "expects ')' after constant literal")) { + !ParseToken(TokKind::kRparen, "expects ')' after constant literal") || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -275,7 +415,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kSin: case HloOpcode::kSort: case HloOpcode::kTanh: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -305,7 +446,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: { - if (!ParseOperands(&operands, /*expected_size=*/2)) { + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { return false; } instruction =
builder->AddInstruction(HloInstruction::CreateBinary( @@ -315,7 +457,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, // Ternary ops. case HloOpcode::kClamp: case HloOpcode::kSelect: { - if (!ParseOperands(&operands, /*expected_size=*/3)) { + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateTernary( @@ -324,23 +467,34 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } // Other supported ops. case HloOpcode::kConvert: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( HloInstruction::CreateConvert(shape, operands[0])); break; } + case HloOpcode::kBitcastConvert: { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateBitcastConvert(shape, operands[0])); + break; + } case HloOpcode::kCrossReplicaSum: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateCrossReplicaSum(shape, operands[0])); + HloInstruction::CreateCrossReplicaSum(shape, operands)); break; } case HloOpcode::kReshape: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -348,7 +502,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kTuple: { - if (!ParseOperands(&operands)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = @@ -356,130 +510,478 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, 
break; } case HloOpcode::kWhile: { - HloComputation* condition; - HloComputation* body; + optional condition; + optional body; + attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation, + &condition}; + attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body}; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseExtraAttribute(&condition, - /*expected_attribute=*/"condition") || - !ParseExtraAttribute(&body, /*expected_attribute=*/"body")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateWhile( - shape, condition, body, /*init=*/operands[0])); + shape, *condition, *body, /*init=*/operands[0])); break; } case HloOpcode::kRecv: { - int64 channel_id; + optional channel_id; + attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/0) || - !ParseExtraAttribute(&channel_id, - /*expected_attribute=*/"channel_id")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateRecv(shape, channel_id)); + HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id)); + break; + } + case HloOpcode::kRecvDone: { + optional channel_id; + attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + if (channel_id != operands[0]->channel_id()) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0])); break; } case HloOpcode::kSend: { - int64 channel_id; + optional channel_id; + attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseExtraAttribute(&channel_id, - /*expected_attribute=*/"channel_id")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateSend(operands[0], channel_id)); 
+ HloInstruction::CreateSend(operands[0], *channel_id)); + break; + } + case HloOpcode::kSendDone: { + optional channel_id; + attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + if (channel_id != operands[0]->channel_id()) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateSendDone(operands[0])); break; } case HloOpcode::kGetTupleElement: { - int64 index; + optional index; + attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseExtraAttribute(&index, /*expected_attribute=*/"index")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateGetTupleElement(shape, operands[0], index)); + HloInstruction::CreateGetTupleElement(shape, operands[0], *index)); break; } case HloOpcode::kCall: { - HloComputation* to_apply; - if (!ParseOperands(&operands) || - !ParseExtraAttribute(&to_apply, - /*expected_attribute=*/"to_apply")) { + optional to_apply; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateCall(shape, operands, to_apply)); + HloInstruction::CreateCall(shape, operands, *to_apply)); + break; + } + case HloOpcode::kReduceWindow: { + optional reduce_computation; + optional window; + attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &reduce_computation}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + if (!window) { + window.emplace(); + } + instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow( + shape, /*operand=*/operands[0], /*init_value=*/operands[1], 
*window, + *reduce_computation)); + break; + } + case HloOpcode::kConvolution: { + optional window; + optional dnums; + attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; + attrs["dim_labels"] = {/*required=*/true, + AttrTy::kConvolutionDimensionNumbers, &dnums}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + if (!window) { + window.emplace(); + } + instruction = builder->AddInstruction(HloInstruction::CreateConvolve( + shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums)); + break; + } + case HloOpcode::kBroadcast: { + optional> broadcast_dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &broadcast_dimensions}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateBroadcast( + shape, operands[0], *broadcast_dimensions)); + break; + } + case HloOpcode::kConcatenate: { + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs) || + dimensions->size() != 1) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateConcatenate( + shape, operands, dimensions->at(0))); + break; + } + case HloOpcode::kMap: { + optional to_apply; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateMap(shape, operands, *to_apply)); + break; + } + case HloOpcode::kReduce: { + optional reduce_computation; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &reduce_computation}; + optional> dimensions_to_reduce; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions_to_reduce}; + if (!ParseOperands(&operands, 
/*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateReduce( + shape, /*operand=*/operands[0], /*init_value=*/operands[1], + *dimensions_to_reduce, *reduce_computation)); + break; + } + case HloOpcode::kReverse: { + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateReverse(shape, operands[0], *dimensions)); + break; + } + case HloOpcode::kSelectAndScatter: { + optional select; + attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select}; + optional scatter; + attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter}; + optional window; + attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + if (!window) { + window.emplace(); + } + instruction = + builder->AddInstruction(HloInstruction::CreateSelectAndScatter( + shape, /*operand=*/operands[0], *select, *window, + /*source=*/operands[1], /*init_value=*/operands[2], *scatter)); + break; + } + case HloOpcode::kSlice: { + optional slice_ranges; + attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateSlice( + shape, operands[0], slice_ranges->starts, slice_ranges->limits, + slice_ranges->strides)); + break; + } + case HloOpcode::kDynamicSlice: { + optional> dynamic_slice_sizes; + attrs["dynamic_slice_sizes"] = { + /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + 
instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice( + shape, /*operand=*/operands[0], /*start_indices=*/operands[1], + *dynamic_slice_sizes)); + break; + } + case HloOpcode::kDynamicUpdateSlice: { + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + shape, /*operand=*/operands[0], /*update=*/operands[1], + /*start_indices=*/operands[2])); + break; + } + case HloOpcode::kTranspose: { + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateTranspose(shape, operands[0], *dimensions)); + break; + } + case HloOpcode::kBatchNormTraining: { + optional epsilon; + attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; + optional feature_index; + attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, + &feature_index}; + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateBatchNormTraining( + shape, /*operand=*/operands[0], /*scale=*/operands[1], + /*offset=*/operands[2], *epsilon, *feature_index)); + break; + } + case HloOpcode::kBatchNormInference: { + optional epsilon; + attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; + optional feature_index; + attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, + &feature_index}; + if (!ParseOperands(&operands, /*expected_size=*/5) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateBatchNormInference( + shape, /*operand=*/operands[0], /*scale=*/operands[1], + /*offset=*/operands[2], /*mean=*/operands[3], + /*variance=*/operands[4], 
*epsilon, *feature_index)); + break; + } + case HloOpcode::kBatchNormGrad: { + optional epsilon; + attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; + optional feature_index; + attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, + &feature_index}; + if (!ParseOperands(&operands, /*expected_size=*/5) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad( + shape, /*operand=*/operands[0], /*scale=*/operands[1], + /*mean=*/operands[2], /*variance=*/operands[3], + /*grad_output=*/operands[4], *epsilon, *feature_index)); + break; + } + case HloOpcode::kPad: { + optional padding; + attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreatePad( + shape, operands[0], /*padding_value=*/operands[1], *padding)); + break; + } + case HloOpcode::kFusion: { + optional fusion_computation; + attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation, + &fusion_computation}; + optional fusion_kind; + attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateFusion( + shape, *fusion_kind, operands, *fusion_computation)); + break; + } + case HloOpcode::kInfeed: { + optional config; + attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; + if (!ParseOperands(&operands, /*expected_size=*/0) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateInfeed(shape, config ? 
*config : "")); + break; + } + case HloOpcode::kOutfeed: { + optional config; + attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( + shape, operands[0], config ? *config : "")); + break; + } + case HloOpcode::kRng: { + optional distribution; + attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution, + &distribution}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateRng(shape, *distribution, operands)); + break; + } + case HloOpcode::kReducePrecision: { + optional exponent_bits; + optional mantissa_bits; + attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, + &exponent_bits}; + attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, + &mantissa_bits}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateReducePrecision( + shape, operands[0], static_cast(*exponent_bits), + static_cast(*mantissa_bits))); + break; + } + case HloOpcode::kConditional: { + optional true_computation; + optional false_computation; + attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation, + &true_computation}; + attrs["false_computation"] = {/*required=*/true, AttrTy::kHloComputation, + &false_computation}; + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateConditional( + shape, /*pred=*/operands[0], + /*true_computation_arg=*/operands[1], *true_computation, + /*false_computation_arg=*/operands[2], *false_computation)); + break; + } + case HloOpcode::kCustomCall: { + optional custom_call_target; + attrs["custom_call_target"] = 
{/*required=*/true, AttrTy::kString, + &custom_call_target}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target)); break; } - case HloOpcode::kBroadcast: - case HloOpcode::kCustomCall: - case HloOpcode::kConcatenate: - case HloOpcode::kReducePrecision: - case HloOpcode::kConvolution: - case HloOpcode::kMap: - case HloOpcode::kPad: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - case HloOpcode::kReverse: - case HloOpcode::kRng: - case HloOpcode::kSlice: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kTranspose: - case HloOpcode::kFusion: - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kBatchNormGrad: case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); } - bool has_sharding = false; - bool has_control = false; - while (EatIfPresent(TokKind::kComma)) { - string attribute_name; - if (!ParseAttributeName(&attribute_name)) { - return TokenError("expects ', sharding=' or ', control-predecessors='"); + // Add common attrs (sharding, control predecessors) to the instruction, if + // they were seen. 
+ if (sharding) { + instruction->set_sharding( + HloSharding::FromProto(sharding.value()).ValueOrDie()); + } + if (predecessors) { + for (auto* pre : *predecessors) { + Status status = pre->AddControlDependencyTo(instruction); + if (!status.ok()) { + return Error(name_loc, StrCat("error adding control dependency for: ", + name, " status: ", status.ToString())); + } } + } + if (metadata) { + instruction->set_metadata(*metadata); + } + return AddInstruction(name, instruction, name_loc); +} // NOLINT(readability/fn_size) + +// ::= '{' (single_sharding | tuple_sharding) '}' +// +// tuple_sharding ::= single_sharding* (',' single_sharding)* +bool HloParser::ParseSharding(OpSharding* sharding) { + // A single sharding starts with '{' and is not followed by '{'. + // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for + // an empty tuple. + if (!ParseToken(TokKind::kLbrace, + "expected '{' to start sharding attribute")) { + return false; + } - if (attribute_name == "sharding") { - // Parse "sharding=". - if (has_sharding) { - return TokenError("expects at most 1 'sharding='"); - } - has_sharding = true; - if (!ParseSharding(instruction)) { - return false; - } - } else if (attribute_name == "control-predecessors") { - // Parse "control-predecessors" - if (has_control) { - return TokenError("expects at most 1 'control-predecessors='"); - } - has_control = true; - if (!ParseControlPredecessors(instruction)) { + if (lexer_.GetKind() != TokKind::kLbrace && + lexer_.GetKind() != TokKind::kRbrace) { + return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true); + } + + // Tuple sharding. + // Allow empty tuple shardings. 
+ if (lexer_.GetKind() != TokKind::kRbrace) { + do { + if (!ParseSingleSharding(sharding->add_tuple_shardings(), + /*lbrace_pre_lexed=*/false)) { return false; } - } else { - return TokenError(StrCat("unexpected attribute: ", attribute_name)); - } + } while (EatIfPresent(TokKind::kComma)); } + sharding->set_type(OpSharding::Type::OpSharding_Type_TUPLE); - return AddInstruction(name, instruction); + return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute"); } -// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('[' -// dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list -bool HloParser::ParseSharding(HloInstruction* instruction) { - if (!ParseToken(TokKind::kLbrace, +// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? +// ('devices=' ('[' dims ']')* device_list)? '}' +// dims ::= int_list device_list ::= int_list +bool HloParser::ParseSingleSharding(OpSharding* sharding, + bool lbrace_pre_lexed) { + if (!lbrace_pre_lexed && + !ParseToken(TokKind::kLbrace, "expected '{' to start sharding attribute")) { return false; } + LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; std::vector devices; @@ -545,81 +1047,78 @@ bool HloParser::ParseSharding(HloInstruction* instruction) { } } - OpSharding sharding; if (replicated) { if (!devices.empty()) { - return TokenError( - "replicated shardings should not have any devices assigned"); + return Error(loc, + "replicated shardings should not have any devices assigned"); } if (!ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError( - "replicated shardings should not have any tile shape set"); + return Error(loc, + "replicated shardings should not have any tile shape set"); } - sharding.set_type(OpSharding::Type::OpSharding_Type_REPLICATED); + sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); } else if (maximal) { if (devices.size() != 1) { - return TokenError( - "maximal shardings should have exactly one device 
assigned"); + return Error(loc, + "maximal shardings should have exactly one device assigned"); } if (!ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError("maximal shardings should not have any tile shape set"); + return Error(loc, "maximal shardings should not have any tile shape set"); } - sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); - sharding.add_tile_assignment_devices(devices[0]); + sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); + sharding->add_tile_assignment_devices(devices[0]); } else { if (devices.size() <= 1) { - return TokenError( - "non-maximal shardings must have more than one device assigned"); + return Error( + loc, "non-maximal shardings must have more than one device assigned"); } if (ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError("non-maximal shardings should have a tile shape set"); + return Error(loc, "non-maximal shardings should have a tile shape set"); } if (tile_assignment_dimensions.empty()) { - return TokenError( + return Error( + loc, "non-maximal shardings must have a tile assignment list including " "dimensions"); } - sharding.set_type(OpSharding::Type::OpSharding_Type_OTHER); - *sharding.mutable_tile_shape() = tile_shape; + sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); + *sharding->mutable_tile_shape() = tile_shape; for (int64 dim : tile_assignment_dimensions) { - sharding.add_tile_assignment_dimensions(dim); + sharding->add_tile_assignment_dimensions(dim); } for (int64 device : devices) { - sharding.add_tile_assignment_devices(device); + sharding->add_tile_assignment_devices(device); } } - instruction->set_sharding(HloSharding::FromProto(sharding).ValueOrDie()); lexer_.Lex(); return true; } // '{' name+ '}' -bool HloParser::ParseControlPredecessors(HloInstruction* instruction) { +bool HloParser::ParseInstructionNames( + std::vector* instructions) { if (!ParseToken(TokKind::kLbrace, - "expects '{' at the beginning of control predecessors")) { + "expects '{' at the 
beginning of instruction name list")) { return false; } + LocTy loc = lexer_.GetLoc(); do { string name; if (!ParseName(&name)) { - return TokenError("expects a control predecessor"); + return Error(loc, "expects a instruction name"); } - HloInstruction* pre = + HloInstruction* instr = tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); - if (!pre) { + if (!instr) { return TokenError( - StrCat("control predecessor ", name, " is not defined: ")); - } - Status status = pre->AddControlDependencyTo(instruction); - if (!status.ok()) { - return TokenError(StrCat("error adding control dependency for: ", name, - " status: ", status.ToString())); + Printf("instruction '%s' is not defined", name.c_str())); } + instructions->push_back(instr); } while (EatIfPresent(TokKind::kComma)); return ParseToken(TokKind::kRbrace, - "expects '}' at the end of control predecessors"); + "expects '}' at the end of instruction name list"); } bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, @@ -654,6 +1153,8 @@ bool HloParser::SetValueInLiteral(double value, int64 linear_index, switch (shape.element_type()) { case F16: return SetValueInLiteralHelper(value, linear_index, literal); + case BF16: + return SetValueInLiteralHelper(value, linear_index, literal); case F32: return SetValueInLiteralHelper(value, linear_index, literal); case F64: @@ -692,7 +1193,8 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, (std::numeric_limits::infinity() == value || -std::numeric_limits::infinity() == value))) { // Skip range checking for non-finite value. 
- } else if (literal->shape().element_type() == F16) { + } else if (literal->shape().element_type() == F16 || + literal->shape().element_type() == BF16) { if (value > kF16max || value < -kF16max) { return TokenError(StrCat( "value ", value, " is out of range for literal's primitive type ", @@ -778,12 +1280,6 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, // rank2345 ::= shape nested_array bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, const Shape& shape) { - const int64 size = ShapeUtil::ElementsIn(shape); - if (size == 0) { - *literal = Literal::CreateFromShape(shape); - return true; - } - const int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; @@ -884,20 +1380,22 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, } lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { + LocTy loc = lexer_.GetLoc(); int64 value; if (!ParseInt64(&value)) { - return TokenError(StrCat("expects integer for primitive type: ", + return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } if (!SetValueInLiteral(value, linear_index++, literal->get())) { return false; } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { + LocTy loc = lexer_.GetLoc(); double value; if (!ParseDouble(&value)) { - return TokenError( - StrCat("expect floating point value for primitive type: ", - PrimitiveType_Name(shape.element_type()))); + return Error( + loc, StrCat("expect floating point value for primitive type: ", + PrimitiveType_Name(shape.element_type()))); } if (!SetValueInLiteral(value, linear_index++, literal->get())) { return false; @@ -929,6 +1427,7 @@ bool HloParser::ParseOperands(std::vector* operands) { // empty } else { do { + LocTy loc = lexer_.GetLoc(); Shape shape; string name; if (!ParseShape(&shape) || !ParseName(&name)) { @@ -937,7 +1436,7 @@ bool HloParser::ParseOperands(std::vector* 
operands) { HloInstruction* instruction = tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); if (!instruction) { - return TokenError(StrCat("instruction does not exist: ", name)); + return Error(loc, StrCat("instruction does not exist: ", name)); } operands->push_back(instruction); } while (EatIfPresent(TokKind::kComma)); @@ -947,52 +1446,513 @@ bool HloParser::ParseOperands(std::vector* operands) { bool HloParser::ParseOperands(std::vector* operands, const int expected_size) { + LocTy loc = lexer_.GetLoc(); if (!ParseOperands(operands)) { return false; } if (expected_size != operands->size()) { - return TokenError(StrCat("expects ", expected_size, " operands, but has ", + return Error(loc, StrCat("expects ", expected_size, " operands, but has ", operands->size(), " operands")); } return true; } -// extra_attribute ::= ',' attribute_name value -template -bool HloParser::ParseExtraAttribute(T* value, - const string& expected_attribute) { - if (!ParseToken(TokKind::kComma, - "expects ',' in front of an extra attribute")) { +// sub_attributes ::= '{' (','? attribute)* '}' +bool HloParser::ParseSubAttributes( + const std::unordered_map& attrs) { + LocTy loc = lexer_.GetLoc(); + if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) { return false; } - string attribute_name; - if (!ParseAttributeName(&attribute_name) && - attribute_name != expected_attribute) { - return TokenError(StrCat("expects attribute name: ", expected_attribute)); + std::unordered_set seen_attrs; + if (lexer_.GetKind() == TokKind::kRbrace) { + // empty + } else { + do { + EatIfPresent(TokKind::kComma); + if (!ParseAttributeHelper(attrs, &seen_attrs)) { + return false; + } + } while (lexer_.GetKind() != TokKind::kRbrace); } - if (!ParseAttributeValue(value)) { - return TokenError( - StrCat("expects value for attribute: ", expected_attribute)); + // Check that all required attrs were seen. 
+ for (const auto& attr_it : attrs) { + if (attr_it.second.required && + seen_attrs.find(attr_it.first) == seen_attrs.end()) { + return Error(loc, Printf("sub-attribute %s is expected but not seen", + attr_it.first.c_str())); + } + } + return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes"); +} + +// attributes ::= (',' attribute)* +bool HloParser::ParseAttributes( + const std::unordered_map& attrs) { + LocTy loc = lexer_.GetLoc(); + std::unordered_set seen_attrs; + while (EatIfPresent(TokKind::kComma)) { + if (!ParseAttributeHelper(attrs, &seen_attrs)) { + return false; + } + } + // Check that all required attrs were seen. + for (const auto& attr_it : attrs) { + if (attr_it.second.required && + seen_attrs.find(attr_it.first) == seen_attrs.end()) { + return Error(loc, Printf("attribute %s is expected but not seen", + attr_it.first.c_str())); + } + } + return true; +} + +bool HloParser::ParseAttributeHelper( + const std::unordered_map& attrs, + std::unordered_set* seen_attrs) { + LocTy loc = lexer_.GetLoc(); + string name; + if (!ParseAttributeName(&name)) { + return Error(loc, "error parsing attributes"); + } + VLOG(1) << "Parsing attribute " << name; + if (!seen_attrs->insert(name).second) { + return Error(loc, Printf("attribute %s already exists", name.c_str())); + } + auto attr_it = attrs.find(name); + if (attr_it == attrs.end()) { + return Error(loc, Printf("unexpected attribute %s", name.c_str())); + } + AttrTy attr_type = attr_it->second.attr_type; + void* attr_out_ptr = attr_it->second.result; + bool success = [&] { + LocTy attr_loc = lexer_.GetLoc(); + switch (attr_type) { + case AttrTy::kInt64: { + int64 result; + if (!ParseInt64(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kInt32: { + int64 result; + if (!ParseInt64(&result)) { + return false; + } + if (result != static_cast(result)) { + return Error(attr_loc, "value out of range for int32"); + } + 
static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); + return true; + } + case AttrTy::kFloat: { + double result; + if (!ParseDouble(&result)) { + return false; + } + if (result > std::numeric_limits::max() || + result < std::numeric_limits::lowest()) { + return Error(attr_loc, "value out of range for float"); + } + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); + return true; + } + case AttrTy::kHloComputation: { + HloComputation* result; + if (!ParseComputationName(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kWindow: { + Window result; + if (!ParseWindow(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kConvolutionDimensionNumbers: { + ConvolutionDimensionNumbers result; + if (!ParseConvolutionDimensionNumbers(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kSharding: { + OpSharding sharding; + if (!ParseSharding(&sharding)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(sharding); + return true; + } + case AttrTy::kInstructionList: { + std::vector result; + if (!ParseInstructionNames(&result)) { + return false; + } + static_cast>*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kFusionKind: { + HloInstruction::FusionKind result; + if (!ParseFusionKind(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kBracedInt64List: { + std::vector result; + if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + &result)) { + return false; + } + static_cast>*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kSliceRanges: { + SliceRanges result; + if (!ParseSliceRanges(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kPaddingConfig: { + 
PaddingConfig result; + if (!ParsePaddingConfig(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kString: { + string result; + if (!ParseString(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kMetadata: { + OpMetadata result; + if (!ParseMetadata(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kDistribution: { + RandomDistribution result; + if (!ParseRandomDistribution(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } + } + }(); + if (!success) { + return Error(loc, Printf("error parsing attribute %s", name.c_str())); } return true; } -template <> -bool HloParser::ParseAttributeValue(HloComputation** value) { +bool HloParser::ParseComputationName(HloComputation** value) { string name; + LocTy loc = lexer_.GetLoc(); if (!ParseName(&name)) { - return TokenError("expects computation name"); + return Error(loc, "expects computation name"); } *value = tensorflow::gtl::FindPtrOrNull(computation_pool_, name); if (*value == nullptr) { - return TokenError(StrCat("computation does not exist: ", name)); + return Error(loc, StrCat("computation does not exist: ", name)); + } + return true; +} + +// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' +// The subattributes can appear in any order. 'size=' is required, others are +// optional. 
+bool HloParser::ParseWindow(Window* window) { + LocTy loc = lexer_.GetLoc(); + if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { + return false; + } + + std::vector size; + std::vector stride; + std::vector> pad; + std::vector lhs_dilate; + std::vector rhs_dilate; + std::vector rhs_reversal; + while (lexer_.GetKind() != TokKind::kRbrace) { + LocTy attr_loc = lexer_.GetLoc(); + string field_name; + if (!ParseAttributeName(&field_name)) { + return Error(attr_loc, "expects sub-attributes in window"); + } + bool ok = [&] { + if (field_name == "size") { + return ParseDxD("size", &size); + } + if (field_name == "stride") { + return ParseDxD("stride", &stride); + } + if (field_name == "lhs_dilate") { + return ParseDxD("lhs_dilate", &lhs_dilate); + } + if (field_name == "rhs_dilate") { + return ParseDxD("rls_dilate", &rhs_dilate); + } + if (field_name == "pad") { + return ParseWindowPad(&pad); + } + if (field_name == "rhs_reversal") { + return ParseDxD("rhs_reversal", &rhs_reversal); + } + return Error(loc, StrCat("unexpected attribute name: ", field_name)); + }(); + if (!ok) { + return false; + } + } + + if (size.empty()) { + return Error(loc, + "sub-attribute 'size=' is required in the window attribute"); + } + if (!stride.empty() && stride.size() != size.size()) { + return Error(loc, "expects 'stride=' has the same size as 'size='"); + } + if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) { + return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='"); + } + if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) { + return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='"); + } + if (!pad.empty() && pad.size() != size.size()) { + return Error(loc, "expects 'pad=' has the same size as 'size='"); + } + + for (int i = 0; i < size.size(); i++) { + window->add_dimensions()->set_size(size[i]); + if (!pad.empty()) { + window->mutable_dimensions(i)->set_padding_low(pad[i][0]); + 
window->mutable_dimensions(i)->set_padding_high(pad[i][1]); + } + // If some field is not present, it has the default value. + window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]); + window->mutable_dimensions(i)->set_base_dilation( + lhs_dilate.empty() ? 1 : lhs_dilate[i]); + window->mutable_dimensions(i)->set_window_dilation( + rhs_dilate.empty() ? 1 : rhs_dilate[i]); + window->mutable_dimensions(i)->set_window_reversal( + rhs_reversal.empty() ? false : (rhs_reversal[i] == 1)); + } + return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); +} + +// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. +// The string looks like "dim_labels=0bf_0io->0bf". +bool HloParser::ParseConvolutionDimensionNumbers( + ConvolutionDimensionNumbers* dnums) { + if (lexer_.GetKind() != TokKind::kDimLabels) { + return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'"); + } + string str = lexer_.GetStrVal(); + + // The str is expected to have 3 items, lhs, rhs, out, and it must looks like + // lhs_rhs->out, that is, the first separator is "_" and the second is "->". + // So we replace the "->" with "_" and then split on "_". 
+ str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->", + /*newsub=*/"_", + /*replace_all=*/false); + std::vector lhs_rhs_out = Split(str, "_"); + if (lhs_rhs_out.size() != 3) { + LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " + << str; + } + + const int64 rank = lhs_rhs_out[0].length(); + if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { + return TokenError( + "convolution lhs, rhs, and output must have the same rank"); + } + if (rank < 2) { + return TokenError("convolution rank must >=2"); + } + + auto is_unique = [](string str) -> bool { + std::sort(str.begin(), str.end()); + return std::unique(str.begin(), str.end()) == str.end(); + }; + + // lhs + { + const string& lhs = lhs_rhs_out[0]; + if (!is_unique(lhs)) { + return TokenError( + StrCat("expects unique lhs dimension numbers, but sees ", lhs)); + } + for (int i = 0; i < rank - 2; i++) { + dnums->add_input_spatial_dimensions(-1); + } + for (int i = 0; i < rank; i++) { + char c = lhs[i]; + if (c == 'b') { + dnums->set_input_batch_dimension(i); + } else if (c == 'f') { + dnums->set_input_feature_dimension(i); + } else if (c < '0' + rank && c >= '0') { + dnums->set_input_spatial_dimensions(c - '0', i); + } else { + return TokenError( + Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1)); + } + } } + // rhs + { + const string& rhs = lhs_rhs_out[1]; + if (!is_unique(rhs)) { + return TokenError( + StrCat("expects unique rhs dimension numbers, but sees ", rhs)); + } + for (int i = 0; i < rank - 2; i++) { + dnums->add_kernel_spatial_dimensions(-1); + } + for (int i = 0; i < rank; i++) { + char c = rhs[i]; + if (c == 'i') { + dnums->set_kernel_input_feature_dimension(i); + } else if (c == 'o') { + dnums->set_kernel_output_feature_dimension(i); + } else if (c < '0' + rank && c >= '0') { + dnums->set_kernel_spatial_dimensions(c - '0', i); + } else { + return TokenError( + Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1)); + } + } + } + 
// output + { + const string& out = lhs_rhs_out[2]; + if (!is_unique(out)) { + return TokenError( + StrCat("expects unique output dimension numbers, but sees ", out)); + } + for (int i = 0; i < rank - 2; i++) { + dnums->add_output_spatial_dimensions(-1); + } + for (int i = 0; i < rank; i++) { + char c = out[i]; + if (c == 'b') { + dnums->set_output_batch_dimension(i); + } else if (c == 'f') { + dnums->set_output_feature_dimension(i); + } else if (c < '0' + rank && c >= '0') { + dnums->set_output_spatial_dimensions(c - '0', i); + } else { + return TokenError( + Printf("expects [0-%lldbf] in output dimension numbers", rank - 1)); + } + } + } + + lexer_.Lex(); return true; } -template <> -bool HloParser::ParseAttributeValue(int64* value) { - return ParseInt64(value); +// ::= '{' ranges '}' +// ::= /*empty*/ +// ::= range (',' range)* +// range ::= '[' start ':' limit (':' stride)? ']' +// +// The slice ranges are printed as: +// +// {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...} +// +// This function extracts the starts, limits, and strides as 3 vectors to the +// result. If stride is not present, stride is 1. 
For example, if the slice +// ranges is printed as: +// +// {[2:3:4], [5:6:7], [8:9]} +// +// The the parsed result will be: +// +// {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}} +// +bool HloParser::ParseSliceRanges(SliceRanges* result) { + if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { + return false; + } + std::vector> ranges; + if (lexer_.GetKind() == TokKind::kRbrace) { + // empty + return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); + } + do { + LocTy loc = lexer_.GetLoc(); + ranges.emplace_back(); + if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon, + &ranges.back())) { + return false; + } + const auto& range = ranges.back(); + if (range.size() != 2 && range.size() != 3) { + return Error(loc, Printf("expects [start:limit:step] or [start:limit], " + "but sees %ld elements.", + range.size())); + } + } while (EatIfPresent(TokKind::kComma)); + + for (const auto& range : ranges) { + result->starts.push_back(range[0]); + result->limits.push_back(range[1]); + result->strides.push_back(range.size() == 3 ? 
range[2] : 1); + } + return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); +} + +// int64list ::= start int64_elements end +// int64_elements +// ::= /*empty*/ +// ::= int64_val (delim int64_val)* +bool HloParser::ParseInt64List(const TokKind start, const TokKind end, + const TokKind delim, + std::vector* result) { + if (!ParseToken(start, StrCat("expects an int64 list starting with ", + TokKindToString(start)))) { + return false; + } + if (lexer_.GetKind() == end) { + // empty + } else { + do { + int64 i; + if (!ParseInt64(&i)) { + return false; + } + result->push_back(i); + } while (EatIfPresent(delim)); + } + return ParseToken( + end, StrCat("expects an int64 list to end with ", TokKindToString(end))); } // param_list ::= '(' param_list1 ')' @@ -1070,12 +2030,171 @@ bool HloParser::ParseAttributeName(string* result) { return true; } +bool HloParser::ParseString(string* result) { + VLOG(1) << "ParseString"; + if (lexer_.GetKind() != TokKind::kString) { + return TokenError("expects string"); + } + *result = lexer_.GetStrVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseDxD(const string& name, std::vector* result) { + LocTy loc = lexer_.GetLoc(); + if (!result->empty()) { + return Error(loc, + Printf("sub-attribute '%s=' already exists", name.c_str())); + } + // 1D + if (lexer_.GetKind() == TokKind::kInt) { + int64 number; + if (!ParseInt64(&number)) { + return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); + } + result->push_back(number); + return true; + } + // 2D or higher. 
+ if (lexer_.GetKind() == TokKind::kDxD) { + string str = lexer_.GetStrVal(); + if (!SplitAndParseAsInts(str, 'x', result)) { + return Error(loc, + Printf("expects sub-attribute '%s=ixj...'", name.c_str())); + } + lexer_.Lex(); + return true; + } + return TokenError("expects token type kInt or kDxD"); +} + +bool HloParser::ParseWindowPad(std::vector>* pad) { + LocTy loc = lexer_.GetLoc(); + if (!pad->empty()) { + return Error(loc, "sub-attribute 'pad=' already exists"); + } + if (lexer_.GetKind() != TokKind::kPad) { + return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); + } + string str = lexer_.GetStrVal(); + std::vector padding_str = Split(str, 'x'); + for (int i = 0; i < padding_str.size(); i++) { + std::vector low_high; + if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || + low_high.size() != 2) { + return Error(loc, + "expects padding_low and padding_high separated by '_'"); + } + pad->push_back(low_high); + } + lexer_.Lex(); + return true; +} + +// This is the inverse xla::ToString(PaddingConfig). The padding config string +// looks like "0_0_0x3_3_1". The string is first separated by 'x', each +// substring represents one PaddingConfigDimension. The substring is 3 (or 2) +// numbers joined by '_'. 
+bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { + if (lexer_.GetKind() != TokKind::kPad) { + return TokenError("expects padding config, e.g., '0_0_0x3_3_1'"); + } + LocTy loc = lexer_.GetLoc(); + string str = lexer_.GetStrVal(); + std::vector padding_str = Split(str, 'x'); + for (const auto& padding_dim_str : padding_str) { + std::vector padding_dim; + if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || + (padding_dim.size() != 2 && padding_dim.size() != 3)) { + return Error(loc, + "expects padding config pattern like 'low_high_interior' or " + "'low_high'"); + } + auto* dim = padding->add_dimensions(); + dim->set_edge_padding_low(padding_dim[0]); + dim->set_edge_padding_high(padding_dim[1]); + dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0); + } + lexer_.Lex(); + return true; +} + +// '{' metadata_string '}' +bool HloParser::ParseMetadata(OpMetadata* metadata) { + std::unordered_map attrs; + optional op_type; + optional op_name; + optional source_file; + optional source_line; + attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; + attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; + attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; + attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line}; + if (!ParseSubAttributes(attrs)) { + return false; + } + if (op_type) { + metadata->set_op_type(*op_type); + } + if (op_name) { + metadata->set_op_name(*op_name); + } + if (source_file) { + metadata->set_source_file(*source_file); + } + if (source_line) { + metadata->set_source_line(*source_line); + } + return true; +} + bool HloParser::ParseOpcode(HloOpcode* result) { VLOG(1) << "ParseOpcode"; - if (lexer_.GetKind() != TokKind::kOpcode) { + if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects opcode"); } - *result = lexer_.GetOpcodeVal(); + string val = lexer_.GetStrVal(); + auto status_or_result = StringToHloOpcode(val); + if 
(!status_or_result.ok()) { + return TokenError( + Printf("expects opcode but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { + VLOG(1) << "ParseFusionKind"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects fusion kind"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToFusionKind(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects fusion kind but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseRandomDistribution(RandomDistribution* result) { + VLOG(1) << "ParseRandomDistribution"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects random distribution"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToRandomDistribution(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects random distribution but sees: %s, error: %s", + val.c_str(), status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); lexer_.Lex(); return true; } @@ -1141,20 +2260,20 @@ bool HloParser::EatIfPresent(TokKind kind) { return true; } -bool HloParser::AddInstruction(const string& name, - HloInstruction* instruction) { +bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, + LocTy name_loc) { auto result = instruction_pool_.insert({name, instruction}); if (!result.second) { - return TokenError(StrCat("instruction already exists: ", name)); + return Error(name_loc, StrCat("instruction already exists: ", name)); } return true; } -bool HloParser::AddComputation(const string& name, - HloComputation* computation) { +bool HloParser::AddComputation(const string& 
name, HloComputation* computation, + LocTy name_loc) { auto result = computation_pool_.insert({name, computation}); if (!result.second) { - return TokenError(StrCat("computation already exists: ", name)); + return Error(name_loc, StrCat("computation already exists: ", name)); } return true; } @@ -1165,7 +2284,7 @@ StatusOr> Parse(StringPiece str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { - return InvalidArgument("Syntax error: %s", parser.GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); } return parser.ConsumeHloModule(); } diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 359256f0646367f8af13439b30067624defcd44c..61d8902855f47a11716f8a60b082c6c25ea9b8af 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -25,6 +25,7 @@ namespace tools { namespace { using tensorflow::StringPiece; +using tensorflow::strings::StrCat; struct TestData { string test_name; @@ -35,6 +36,10 @@ string TestDataToString(const ::testing::TestParamInfo& data) { return data.param.test_name; } +// For each string below, we check that: +// - we parse it to an HloModule successfully, and +// - the stringification of the resulting HloModule is equal to our original +// string. 
std::vector CreateTestCases() { // clang-format off return std::vector({ @@ -43,10 +48,11 @@ std::vector CreateTestCases() { "AxpyParam", R"(HloModule axpy_module: -ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { - %alpha = f32[2,4]{1,0} parameter(0) +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} %x = f32[2,4]{1,0} parameter(1) - %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) %y = f32[2,4]{1,0} parameter(2) ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) } @@ -59,7 +65,7 @@ ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { R"(HloModule constant_pred_module: ENTRY %constant_pred () -> pred[] { - ROOT %constant = pred[] constant(true) + ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} } )" @@ -77,12 +83,35 @@ ENTRY %constant_s32 () -> s32[] { }, // f32 constant, but the value is not a decimal { -"ConstantF32", R"(HloModule ConstantF32_module: +"ConstantF32", +R"(HloModule ConstantF32_module: ENTRY %ConstantF32.v4 () -> f32[] { ROOT %constant = f32[] constant(42) } +)" +}, +// f32 constant, rank 1 empty array. +{ +"ConstantF32R1Empty", +R"(HloModule ConstantF32Empty_module: + +ENTRY %ConstantF32Empty.v4 () -> f32[0] { + ROOT %constant = f32[0]{0} constant({}) +} + +)" +}, +// f32 constant, rank 4 empty array. 
+{ +"ConstantF32R4Empty", +R"(HloModule ConstantF32R4Empty_module: + +ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { + ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant(f32[2,0,4,3] { { /*i0=0*/ }, { /*i0=1*/ } }) +} + )" }, // constant 4D @@ -117,6 +146,17 @@ ENTRY %ConstantF16.v4 () -> f16[] { ROOT %constant = f16[] constant(500) } +)" +}, +// bf16 +{ +"BF16", +R"(HloModule BF16: + +ENTRY %BF16.v4 () -> bf16[] { + ROOT %constant = bf16[] constant(500) +} + )" }, // constant + constant @@ -151,7 +191,7 @@ ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f3 %v1 = f32[4]{0} parameter(0), sharding={maximal device=1} %v2 = f32[4]{0} parameter(1), sharding={maximal device=1} %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated} - ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2) + ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={} } )" @@ -179,6 +219,19 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3) } +)" +}, +{ +"ShardedTupleCreate", +R"(HloModule ShardedTupleCreate_module: + +ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { + %v1 = f32[] parameter(0) + %v2 = f32[3]{0} parameter(1) + %v3 = f32[2,3]{1,0} parameter(2) + ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}} +} + )" }, // int32 result = 0; @@ -212,9 +265,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { R"(HloModule TwoSendRecvBothWayRecvFist_module: ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = f32[] recv(), channel_id=15, sharding={maximal device=1} - ROOT %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = () send(f32[] 
%constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1} + ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1} + %constant = f32[] constant(2.1), sharding={maximal device=0} + %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0} } )" @@ -248,7 +303,445 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] { } )" +}, +// reduce window +{ +"ReduceWindow", +R"(HloModule R4UnitWindow_module: + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] { + %operand = f32[13,12,8,15]{0,3,2,1} parameter(0) + %constant = f32[] constant(0) + ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3 +} + +)" +}, +// reduce window on scalar +{ +"ReduceWindowScalar", +R"(HloModule reduce_window_scalar: + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %R4UnitWindowScalar () -> f32[] { + %constant = f32[] constant(42) + %constant.1 = f32[] constant(1) + ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3 +} + +)" +}, +// convolution +{ +"Convolution", +R"(HloModule Convolve1D1Window_0_module: + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { + %input = f32[1,2,1]{2,1,0} parameter(0) + %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) + %filter = f32[1,1,1]{2,1,0} 
parameter(1) + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f +} + +)" +}, +// convolution rank 2 +{ +"ConvolutionR2", +R"(HloModule ConvolveR2_module: + +ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { + %input = f32[1,2]{1,0} parameter(0) + %filter = f32[1,1]{1,0} parameter(1) + ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf +} + +)" +}, +// convolution backward +{ +"ConvolutionBackward", +R"(HloModule ConvolveBackward_module: + +ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] { + %input = f32[128,7,7,512]{0,3,2,1} parameter(0) + %filter = f32[3,3,512,512]{3,2,1,0} parameter(1) + ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f +} + +)" +}, +// reverse(constant) +{ +"Reverse4D", +R"(HloModule Reverse4DFloatArrayOnDim01_module: + +ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { + %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) + ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1} +} + +)" +}, +// concat +{ +"Concat", +R"(HloModule Concat2x3With2x5_module: + +ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { + %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } }) + %constant.1 = f32[2,5]{1,0} constant(f32[2,5] { { 64, 65, 66, 
67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) + ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1} +} + +)" +}, +// map +{ +"Map", +R"(HloModule MapBinaryAdder_module: + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %MapBinaryAdder.v3 (param0: f32[4], param1: f32[4]) -> f32[4] { + %param0 = f32[4]{0} parameter(0) + %param1 = f32[4]{0} parameter(1) + ROOT %map = f32[4]{0} map(f32[4]{0} %param0, f32[4]{0} %param1), to_apply=%add_F32.v3 +} + +)" +}, +// reduce +{ +"Reduce", +R"(HloModule ReduceR3ToR2_module: + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) } + +ENTRY %ReduceR3ToR2.v3 (input: f32[8,16,256]) -> f32[8,16] { + %input = f32[8,16,256]{2,1,0} parameter(0) + %constant = f32[] constant(0) + ROOT %reduce = f32[8,16]{1,0} reduce(f32[8,16,256]{2,1,0} %input, f32[] %constant), dimensions={2}, to_apply=%add_F32.v3 +} + +)" +}, +// select and scatter +{ +"SelectAndScatter", +R"(HloModule R4F32OverlapSmall_module: + +%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) +} + +%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { + %lhs.1 = f32[] parameter(0) + %rhs.1 = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1) +} + +ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { + %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ 
{6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) + %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) + %constant.2 = f32[] constant(0) + ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3 +} + +)" +}, +// select and scatter on scalar +{ +"SelectAndScatterScalar", +R"(HloModule select_and_scatter_scalar: + +%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) +} + +%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { + %lhs.1 = f32[] parameter(0) + %rhs.1 = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1) +} + +ENTRY %SelectAndScatterScalar () -> f32[] { + %constant = f32[] constant(42) + %constant.1 = f32[] constant(1) + %constant.2 = f32[] constant(2) + ROOT %select-and-scatter = f32[] select-and-scatter(f32[] %constant, f32[] %constant.1, f32[] %constant.2), select=%ge_F32.v3, scatter=%add_F32.v3 +} + +)" +}, +// slice +{ +"Slice", +R"(HloModule slice_module: + +ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { + %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) + ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]} +} + +)" +}, +// slice, no stride +{ +"SliceNoStride", +R"(HloModule Slice3x3x3_To_1x3x3_F32_module: + +ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { + %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) + ROOT %slice = 
f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]} +} + +)" +}, +// slice R0 +{ +"SliceR0", +R"(HloModule SliceR0_module: + +ENTRY %SliceR0.v2 () -> s32[] { + %constant = s32[] constant(1) + ROOT %slice = s32[] slice(s32[] %constant), slice={} +} + +)" +}, +// transpose +{ +"Transpose", +R"(HloModule Transpose_module: + +ENTRY %Transpose.v2 () -> s32[1,2,3] { + %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } }) + ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2} +} + +)" +}, +// Dynamic slice +{ +"DynamicSlice", +R"(HloModule DynamicSlice_module: + +ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] { + %original_parameter = s32[2,2,258]{2,1,0} parameter(0) + %constant = s32[1]{0} constant({0}) + %start_index = s32[1]{0} parameter(1) + %concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0} + ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258} +} + +)" +}, +// Dynamic update slice +{ +"DynamicUpdateSlice", +R"(HloModule DynamicUpdateSlice_module: + +ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] { + %input = s32[1,1,25,1]{3,2,1,0} parameter(0) + %update = s32[1,1,2,1]{3,2,1,0} parameter(1) + %start_indices = s32[4]{0} parameter(2) + ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices) +} + +)" +}, +// batch norm training +{ +"BatchNormTraining", +R"(HloModule BasicTraining_module: + +ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { + %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 
6} }, { /*i1=1*/ {7, 8} } } }) + %constant.1 = f32[2]{0} constant({2, 3}) + %constant.2 = f32[2]{0} constant({1, 2}) + ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3 +} + +)" +}, +// batch norm inference +{ +"BatchNormInference", +R"(HloModule BatchNormInference_module: + +ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] { + %input = f32[2,2,2,2]{3,2,1,0} parameter(0) + %offset = f32[2]{0} parameter(1) + %scale = f32[2]{0} parameter(2) + %mean = f32[2]{0} parameter(3) + %variance = f32[2]{0} parameter(4) + ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0 +} + +)" +}, +// batch norm grad +{ +"BatchNormGrad", +R"(HloModule BatchNormGrad_module: + +ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) { + %input = f32[2,2,2,2]{3,2,1,0} parameter(0) + %scale = f32[2]{0} parameter(1) + %mean = f32[2]{0} parameter(2) + %variance = f32[2]{0} parameter(3) + %grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4) + ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0 +} + +)" +}, +// pad +{ +"Pad", +R"(HloModule Pad1DS3Array_module: + +ENTRY %Pad1DS3Array.v3 () -> f32[8] { + %constant = f32[3]{0} constant({1, 2, 3}) + %constant.1 = f32[] constant(0.1) + ROOT %pad = f32[8]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1 +} + +)" +}, +// pad has interior +{ +"PadHasInterior", +R"(HloModule 
PadHasInterior_module: + +ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] { + %input = f32[1,25,7,7]{3,2,1,0} parameter(0) + %constant = f32[] constant(-5.123) + ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0 +} + +)" +}, +// fusion +{ +"Fusion", +R"(HloModule fusion_module: + +%fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] { + %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0) + %constant.1.param_1 = f32[2]{0} parameter(1) + %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1} + ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast) +} + +ENTRY %fusion.v3 () -> f32[3,2,1,1] { + %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + %constant.1 = f32[2]{0} constant({3.14, 4.25}) + ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation +} + +)" +}, +// infeed/outfeed +{ +"InfeedOutfeed", +R"(HloModule outfeed_module: + +ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) { + %infeed = (u32[3]{0}, pred[]) infeed() + %outfeed = () outfeed((u32[3]{0}, pred[]) %infeed) + ROOT %infeed.1 = (u32[3]{0}, pred[]) infeed() + %outfeed.1 = () outfeed((u32[3]{0}, pred[]) %infeed.1) +} + +)" +}, +// Rng +{ +"Rng", +R"(HloModule rng_module: + +ENTRY %Rng () -> f32[8] { + %constant = f32[] constant(0) + %constant.1 = f32[] constant(1) + ROOT %rng = f32[8]{0} rng(f32[] %constant, f32[] %constant.1), distribution=rng_uniform +} + +)" +}, +// Reduce precision +{ +"ReducePrevison", +R"(HloModule reduce_precision: + +ENTRY %ReducePrecision () -> f32[1] { + %constant = f32[1]{0} constant({3.14159}) + ROOT 
%reduce-precision = f32[1]{0} reduce-precision(f32[1]{0} %constant), exponent_bits=8, mantissa_bits=10 +} + +)" +}, +// Conditional +{ +"Conditional", +R"(HloModule conditional: + +%Negate (x: f32[]) -> f32[] { + %x = f32[] parameter(0) + ROOT %negate = f32[] negate(f32[] %x) +} + +%Identity (y: f32[]) -> f32[] { + %y = f32[] parameter(0) + ROOT %copy = f32[] copy(f32[] %y) +} + +ENTRY %Parameters1.v4 () -> f32[] { + %constant = pred[] constant(true) + %constant.1 = f32[] constant(56) + %constant.2 = f32[] constant(12) + ROOT %conditional = f32[] conditional(pred[] %constant, f32[] %constant.1, f32[] %constant.2), true_computation=%Negate, false_computation=%Identity +} + +)" +}, + +// CustomCall +{ +"CustomCall", +R"(HloModule custom_call: + +ENTRY %CustomCall () -> f32[1,2,3] { + %constant = f32[1]{0} constant({12345}) + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar" +} + +)" +}, }); // clang-format on } @@ -261,16 +754,19 @@ class HloParserTest : public ::testing::Test, << "'" << s << "' does not contain '" << expected << "'"; } - void ExpectSuccess() { + // Expects "ToString(Parse(string)) == string", that is, parses the string, + // asserts that it succeeded, stringifies the parsed module, and checks that + // the it equals the original string. + void ExpectEqual() { const string& original = GetParam().module_string; auto result = Parse(original); - TF_EXPECT_OK(result.status()); + TF_ASSERT_OK(result.status()); EXPECT_EQ(original, result.ValueOrDie()->ToString(/*include_large_constants=*/true)); } }; -TEST_P(HloParserTest, Run) { ExpectSuccess(); } +TEST_P(HloParserTest, Run) { ExpectEqual(); } INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, ::testing::ValuesIn(CreateTestCases()), @@ -427,6 +923,130 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { // printed as "300". 
} +TEST_F(HloParserTest, AttibutesAnyOrder) { + const string original = R"(HloModule any_order_module: + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { + %input = f32[1,2,1]{2,1,0} parameter(0) + %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) + %filter = f32[1,1,1]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} +} + +)"; + TF_EXPECT_OK(Parse(original).status()); +} + +TEST_F(HloParserTest, InvalidDimLabels) { + string prefix = R"(HloModule invalid_dim_labels_module: + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { + %input = f32[1,2,1]{2,1,0} parameter(0) + %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) + %filter = f32[1,1,1]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )"; + string suffix = R"( +} + +)"; + + ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); + + ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=010_1100->010", suffix)) + .status() + .error_message(), + "must have the same rank"); +} + +TEST_F(HloParserTest, UnexpectedAttribute) { + const string original = R"(HloModule unexpected_attr_module: + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { + %recv = (f32[], u32[]) recv(), channel_id=15 + %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + ROOT %constant = f32[] constant(2.1) + %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv + %send-done = () send-done((f32[], u32[]) %send), channel_id=16 +} + +)"; + ExpectHasSubstr(Parse(original).status().error_message(), + "unexpected attribute calls"); +} + +TEST_F(HloParserTest, MissingAttribute) { + const string original = R"(HloModule 
missing_attr_module: + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { + %recv = (f32[], u32[]) recv(), channel_id=15 + %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + ROOT %constant = f32[] constant(-2.1) + %send = (f32[], u32[]) send(f32[] %constant) + %send-done = () send-done((f32[], u32[]) %send), channel_id=16 +} + +)"; + ExpectHasSubstr(Parse(original).status().error_message(), + "attribute channel_id is expected but not seen"); +} + +TEST_F(HloParserTest, PredecessorUndefined) { + const string original = R"(HloModule pre_not_found_module: + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { + %recv = (f32[], u32[]) recv(), channel_id=15 + %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + ROOT %constant = f32[] constant(2.1) + %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done} + %send-done = () send-done((f32[], u32[]) %send), channel_id=16 +} + +)"; + ExpectHasSubstr(Parse(original).status().error_message(), + "'done' is not defined"); +} + +TEST_F(HloParserTest, SliceAllowOmitStride1) { + const string original = R"(HloModule slice_module: + +ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { + %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) + ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]} +} + +)"; + TF_EXPECT_OK(Parse(original).status()); +} + +TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { + const string original = R"(HloModule window_pad_module: + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { + %input = f32[1,2,1]{2,1,0} parameter(0) + %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) + %filter = f32[1,1,1]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1} +} + +)"; + ExpectHasSubstr(Parse(original).status().error_message(), + "expects 
padding_low and padding_high separated by '_'"); +} + +TEST_F(HloParserTest, CommaBetweenSubAttributes) { + const string original = R"(HloModule test_comma_module: + +ENTRY %test_comma.v4 () -> f32[] { + ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"} +} + +)"; + TF_EXPECT_OK(Parse(original).status()); +} + } // namespace } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h index 9c2069e7568e46e89afc0fd43d0ff3d8492991fb..7928bee5c2097f353b182095a555c334d7b69c95 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -18,6 +18,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" + namespace xla { namespace tools { @@ -57,8 +60,12 @@ enum class TokKind { // Typed tokens. kName, // %foo kAttributeName, // dimensions= + kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} + kDxD, // [0-9]+(x[0-9]+)+ + kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kIdent, // other identifiers + kString, // "abcd\"\n" kShape, // f32[2,3]{1,0} - kOpcode, // add kInt, // 42 kDecimal, // 4.2 }; diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 89b26b8916b67eeb38852c9e91314187fc8a7d48..a7dc5862057047f7c56faeb211cc0b13992caec7 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -45,6 +45,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -58,18 +59,26 @@ namespace xla { namespace tools { namespace { +// Command-line opts to this tool. See main() for descriptions of these +// fields. +struct Options { + string fake_infeed_shape; + bool use_fake_data = false; + bool print_result = true; + int num_runs = 1; +}; + // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // // Similarly, infeeds fake data of shape fake_infeed_shape if it is provided; // otherwise, no infeed is performed. StatusOr> ReplayComputation( - const SessionModule& module, tensorflow::StringPiece fake_infeed_shape, - bool use_fake_data, Client* client) { + const SessionModule& module, Client* client, const Options& opts) { TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module)); std::vector> arguments; - if (use_fake_data) { + if (opts.use_fake_data) { arguments = MakeFakeArgumentsOrDie(computation, client); } else { // use recorded data if available for (const auto& proto : module.arguments()) { @@ -84,12 +93,12 @@ StatusOr> ReplayComputation( // concurrent infeed occur via the fake_infeed_shape. 
tensorflow::gtl::optional pool; - if (!fake_infeed_shape.empty()) { + if (!opts.fake_infeed_shape.empty()) { pool.emplace(tensorflow::Env::Default(), "infeed", /*num_threads=*/1); - pool->Schedule([fake_infeed_shape, client]() { + pool->Schedule([opts, client]() { StatusOr shape_status = - ShapeUtil::ParseShapeString(fake_infeed_shape); + ShapeUtil::ParseShapeString(opts.fake_infeed_shape); TF_CHECK_OK(shape_status.status()); Shape shape = std::move(shape_status).ValueOrDie(); StatusOr> data_status = MakeFakeLiteral(shape); @@ -106,11 +115,32 @@ StatusOr> ReplayComputation( for (auto& argument : arguments) { execute_arguments.push_back(argument.get()); } - return client->ExecuteAndTransfer(computation, execute_arguments); + + // Run the computation num_runs times, and return the result from the last + // execution. + std::unique_ptr result; + for (int i = 0; i < opts.num_runs; ++i) { + ExecutionProfile profile; + if (opts.print_result) { + TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer( + computation, execute_arguments, + /*execution_options=*/nullptr, &profile)); + } else { + // If we're not printing the result, execute the computation but don't + // bother retrieving the result. This can be a significant speedup. 
+ TF_RETURN_IF_ERROR(client + ->Execute(computation, execute_arguments, + /*execution_options=*/nullptr, &profile) + .status()); + } + LOG(INFO) << "Execution took " + << static_cast(profile.compute_time_ns()) / 1e9 << "s"; + } + + return std::move(result); } -int RealMain(tensorflow::gtl::ArraySlice args, - tensorflow::StringPiece fake_infeed_shape, bool use_fake_data) { +int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { Client* client = ClientLibrary::LocalClientOrDie(); tensorflow::Env* env = tensorflow::Env::Default(); int exit_status = EXIT_SUCCESS; @@ -118,21 +148,24 @@ int RealMain(tensorflow::gtl::ArraySlice args, SessionModule module; TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module)); StatusOr> result_status = - ReplayComputation(module, fake_infeed_shape, use_fake_data, client); + ReplayComputation(module, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, result_status.status().ToString().c_str()); exit_status = EXIT_FAILURE; continue; } + std::unique_ptr result = result_status.ConsumeValueOrDie(); - fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), - ShapeUtil::HumanString(result->shape()).c_str(), - result->ToString().c_str()); - if (module.has_result()) { - fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(module.result().shape()).c_str(), - Literal(module.result()).ToString().c_str()); + if (result != nullptr) { + fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), + ShapeUtil::HumanString(result->shape()).c_str(), + result->ToString().c_str()); + if (module.has_result()) { + fprintf(stdout, "was %s:%s\n", + ShapeUtil::HumanString(module.result().shape()).c_str(), + Literal(module.result()).ToString().c_str()); + } } } return exit_status; @@ -143,13 +176,15 @@ int RealMain(tensorflow::gtl::ArraySlice args, } // namespace xla int main(int argc, char** argv) { - // Flags - xla::string fake_infeed_shape; - bool use_fake_data = false; + 
xla::tools::Options opts; const std::vector flag_list = { - tensorflow::Flag("use_fake_data", &use_fake_data, + tensorflow::Flag("use_fake_data", &opts.use_fake_data, "Replay computation using fake data"), - tensorflow::Flag("fake_infeed_shape", &fake_infeed_shape, + tensorflow::Flag("print_result", &opts.print_result, + "Print the result of the computation to stdout"), + tensorflow::Flag("num_runs", &opts.num_runs, + "Number of times to run each computation"), + tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); @@ -161,5 +196,5 @@ int main(int argc, char** argv) { tensorflow::gtl::ArraySlice args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] - return xla::tools::RealMain(args, fake_infeed_shape, use_fake_data); + return xla::tools::RealMain(args, opts); } diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index 3b19ca321cad35aad18f7f498e08fd744ffbc371..9fa4297523bab0748863479be52dff1b7b523a8b 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "third_party/eigen3/Eigen/Core" +#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/platform/types.h" #include @@ -32,6 +33,8 @@ using ::tensorflow::int16; using ::tensorflow::int32; using ::tensorflow::int64; +using ::tensorflow::bfloat16; + using ::tensorflow::uint8; using ::tensorflow::uint16; using ::tensorflow::uint32; diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 2624ef0252fd9482a600fe3aec07f7f328a86d69..fe5d29a6b655a89d559eb1214c2b8dd54d34094c 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -42,15 +42,15 @@ Status WithLogBacktrace(const Status& status) { } // namespace -ScopedLoggingTimer::ScopedLoggingTimer(const string& label, int32 vlog_level) - : label(label), vlog_level(vlog_level) { - if (VLOG_IS_ON(vlog_level)) { +ScopedLoggingTimer::ScopedLoggingTimer(const string& label, bool enabled) + : enabled(enabled), label(label) { + if (enabled) { start_micros = tensorflow::Env::Default()->NowMicros(); } } ScopedLoggingTimer::~ScopedLoggingTimer() { - if (VLOG_IS_ON(vlog_level)) { + if (enabled) { uint64 end_micros = tensorflow::Env::Default()->NowMicros(); double secs = (end_micros - start_micros) / 1000000.0; @@ -191,9 +191,9 @@ std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, return output; } -bool IsIdentityPermutation(tensorflow::gtl::ArraySlice p) { - for (int64 i = 0; i < p.size(); ++i) { - if (p[i] != i) { +bool IsIdentityPermutation(tensorflow::gtl::ArraySlice permutation) { + for (int64 i = 0; i < permutation.size(); ++i) { + if (permutation[i] != i) { return false; } } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index f58f57b44396c90a3820835a3d0ecc792aaa7cd0..b722095d1f38bf8a984c3ce9092a65f8e0baa911 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -50,13 +50,43 @@ using DimensionVector = tensorflow::gtl::InlinedVector; // RAII timer that logs 
with a given label the wall clock time duration in human // readable form. This differs from base's ElapsedTimer primarily in that it // spits out the human-readable duration form. +// +// By default, the timing traces are only printed at VLOG(1) and above: +// +// XLA_SCOPED_LOGGING_TIMER("fooing bar"); // nop if !VLOG_IS_ON(1). +// +// but you can control this via: +// +// XLA_SCOPED_LOGGING_TIMER_LEVEL("fooing bar", 2); // nop if !VLOG_IS_ON(2) +// +#define XLA_SCOPED_LOGGING_TIMER(label) \ + XLA_SCOPED_LOGGING_TIMER_HELPER(label, 1, __COUNTER__) +#define XLA_SCOPED_LOGGING_TIMER_LEVEL(label, level) \ + XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, __COUNTER__) + +// Helper for implementing macros above. Do not use directly. +// +// Forces the evaluation of "counter", which we expect is equal to __COUNTER__. +#define XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, counter) \ + XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter) + +// Helper for macros above. Don't use directly. +#define XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter) \ + ::xla::ScopedLoggingTimer XLA_ScopedLoggingTimerInstance##counter( \ + label, VLOG_IS_ON(level)) + +// RAII timer for XLA_SCOPED_LOGGING_TIMER and XLA_SCOPED_LOGGING_TIMER_LEVEL +// macros above. Recommended usage is via the macros so you don't have to give +// the timer a name or worry about calling VLOG_IS_ON yourself. struct ScopedLoggingTimer { - explicit ScopedLoggingTimer(const string& label, int32 vlog_level = 1); + // The timer does nothing if enabled is false. This lets you pass in your + // file's VLOG_IS_ON value. 
+ ScopedLoggingTimer(const string& label, bool enabled); ~ScopedLoggingTimer(); - uint64 start_micros; + bool enabled; string label; - int32 vlog_level; + uint64 start_micros; }; // Given a vector, returns a MutableArraySlice that points at its diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 23161873a0b722dfbea34507fefc38a7a02c023d..293f0781a203d092a7996d5548de1dbf5bf32e4c 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -26,8 +26,8 @@ namespace xla { namespace window_util { /* static */ string ToString(const WindowDimension& dim) { - using tensorflow::strings::StrCat; using tensorflow::strings::StrAppend; + using tensorflow::strings::StrCat; string str = StrCat("(size=", dim.size()); if (dim.stride() != 1) { StrAppend(&str, ",stride=", dim.stride()); @@ -44,27 +44,30 @@ namespace window_util { if (dim.window_dilation() != 1) { StrAppend(&str, ",window_dilation=", dim.window_dilation()); } + if (dim.window_reversal()) { + StrAppend(&str, ",window_reversal"); + } StrAppend(&str, ")"); return str; } string ToString(const Window& window) { - using tensorflow::strings::StrCat; using tensorflow::strings::StrAppend; + using tensorflow::strings::StrCat; string str; - const auto add_field = [&]( - const char* heading, - std::function format) { - StrAppend(&str, heading, "="); - const char* prefix = ""; - for (const auto& window_dimension : window.dimensions()) { - StrAppend(&str, prefix, format(window_dimension)); - prefix = "x"; - } - }; - - add_field("window", + const auto add_field = + [&](const char* heading, + std::function format) { + StrAppend(&str, heading, "="); + const char* prefix = ""; + for (const auto& window_dimension : window.dimensions()) { + StrAppend(&str, prefix, format(window_dimension)); + prefix = "x"; + } + }; + + add_field("size", [](const WindowDimension& dim) { return StrCat(dim.size()); }); if (HasStride(window)) { add_field(" stride", @@ -85,6 
+88,11 @@ string ToString(const Window& window) { return StrCat(dim.window_dilation()); }); } + if (HasWindowReversal(window)) { + add_field(" rhs_reversal", [](const WindowDimension& dim) { + return StrCat(dim.window_reversal() ? 1 : 0); + }); + } return str; } @@ -138,6 +146,15 @@ bool HasWindowDilation(const Window& window) { return false; } +bool HasWindowReversal(const Window& window) { + for (const auto& dim : window.dimensions()) { + if (dim.window_reversal()) { + return true; + } + } + return false; +} + bool HasDilation(const Window& window) { return HasBaseDilation(window) || HasWindowDilation(window); } diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h index 235cb2d59d451a25dc4f824ab488f8cef6b03bfb..125900dac0c5ab478b834c315b4a438c9238ef6d 100644 --- a/tensorflow/compiler/xla/window_util.h +++ b/tensorflow/compiler/xla/window_util.h @@ -39,6 +39,8 @@ bool HasBaseDilation(const Window& window); bool HasWindowDilation(const Window& window); bool HasDilation(const Window& window); +bool HasWindowReversal(const Window& window); + // Returns the new bound after dilation. 
// // If a window with the given bound in some dimension is dilated with the given diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 3fa5bcc1df4f0294582b6c74735fef08c87433eb..6b136d333bbf079efd314833f46fe3b98743fbac 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -17,3 +17,5 @@ def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0): protoc="@protobuf_archive//:protoc", testonly=testonly, visibility=visibility,) + +ORC_JIT_MEMORY_MAPPER_TARGETS = [] diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 710bb6ff25bf649693165c5e9fb6bc50e81db4ca..127e5e81ac6d21945c7125ef913d236e8892758e 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -167,6 +167,14 @@ message DebugOptions { // computation will run 2! * 4! times. bool xla_test_all_input_layouts = 91; + // Assign colors based on sharding information when generating the Graphviz + // HLO graph. + bool xla_hlo_graph_sharding_color = 92; + + // Prefix the name scopes of the TF graph exports with "devX" device + // assignments, if available. + bool xla_hlo_tfgraph_device_scopes = 93; + // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. map xla_backend_extra_options = 500; diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 06987e0044d7f69637c9ca0e1a2b40d91cd74713..215707634bc29263bc1ef472f498ac1bb1ca9181 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -46,6 +46,12 @@ enum PrimitiveType { // converted to f16 from f32 at arbirary points in the computation. F16 = 10; F32 = 11; + + // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent + // and 7 bits for the mantissa. 
+ BF16 = 16; + F64 = 12; // Complex values of fixed width. @@ -63,6 +69,8 @@ enum PrimitiveType { // An opaque type used for passing context specific data to a custom // operation. OPAQUE = 14; + + // Next = 17 } // Describes the value held inside padding elements. @@ -310,7 +318,10 @@ message LiteralProto { repeated double f64s = 9; repeated float c64s = 12; // Stored as interleaved real, imag floats. repeated LiteralProto tuple_literals = 10; - bytes f16s = 11; // Note: the F16s are encoded in little endian byte order + // The F16s and BF16s are encoded in little endian byte order + bytes f16s = 11; + bytes bf16s = 13; + // Next = 14 } message WindowDimension { @@ -346,6 +357,10 @@ message WindowDimension { // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly // placed between each base area element. See documentation for convolution. int64 base_dilation = 6; + + // Window reversal means that this dimension was logically reversed before the + // operation. + bool window_reversal = 7; } // Describes the windowing in an operation such as convolution. @@ -402,15 +417,9 @@ message ConvolutionDimensionNumbers { // The number of the dimension that represents features in the input. int64 input_feature_dimension = 8; - // The number of the dimension that represents batch in the output. - int64 output_batch_dimension = 9; - - // The number of the dimension that represents features in the output. - int64 output_feature_dimension = 10; - // The dimension numbers for the spatial dimensions that the window - // moves through in the input (lhs) and output. - repeated int64 spatial_dimensions = 5; + // moves through in the input. + repeated int64 input_spatial_dimensions = 11; // The number of the dimension that represents input features in the // convolutional kernel (rhs). @@ -424,12 +433,24 @@ message ConvolutionDimensionNumbers { // moves through in the kernel (rhs). window.strides(0) is the // stride in the kernel_spatial_dimensions(0) dimension. 
repeated int64 kernel_spatial_dimensions = 6; + + // The number of the dimension that represents batch in the output. + int64 output_batch_dimension = 9; + + // The number of the dimension that represents features in the output. + int64 output_feature_dimension = 10; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the output. + repeated int64 output_spatial_dimensions = 12; + + // Next = 13 }; message ConvolveRequest { ComputationDataHandle lhs = 2; ComputationDataHandle rhs = 3; // This is the filter/kernel. - Window window = 4; // Describes the filter/kenel. + Window window = 4; // Describes the filter/kernel. ConvolutionDimensionNumbers dimension_numbers = 5; } @@ -477,6 +498,23 @@ message CustomCallRequest { Shape shape = 4; } +message DotDimensionNumbers { + // The dimension numbers that represent the 'lhs' contracting dimensions. + repeated int64 lhs_contracting_dimensions = 1; + // The dimension numbers that represent the 'rhs' contracting dimensions. + repeated int64 rhs_contracting_dimensions = 2; + // The dimension numbers that represent the 'lhs' batch dimensions. + repeated int64 lhs_batch_dimensions = 3; + // The dimension numbers that represent the 'rhs' batch dimensions. 
+ repeated int64 rhs_batch_dimensions = 4; +}; + +message DotRequest { + ComputationDataHandle lhs = 2; + ComputationDataHandle rhs = 3; + DotDimensionNumbers dimension_numbers = 4; +} + message MapRequest { repeated ComputationDataHandle operands = 2; ComputationHandle to_apply = 3; @@ -630,6 +668,14 @@ message ConcatenateRequest { int64 dimension = 3; } +message ConditionalRequest { + ComputationDataHandle predicate = 2; + ComputationDataHandle true_operand = 3; + ComputationHandle true_computation = 4; + ComputationDataHandle false_operand = 5; + ComputationHandle false_computation = 6; +} + message WhileRequest { ComputationHandle condition = 2; ComputationHandle body = 3; @@ -711,9 +757,6 @@ enum BinaryOperation { BINOP_LT = 9; BINOP_NE = 10; - // Dot product, matrix multiply. - BINOP_DOT = 12; - // Element-wise maximum. BINOP_MAX = 14; @@ -825,8 +868,10 @@ message OpSharding { REPLICATED = 0; // This sharding is maximal - one device runs the entire operation. MAXIMAL = 1; - // Neither of the above; tile_shape and tile_assignment are both used. - OTHER = 2; + // This sharding is a tuple - only the tuple_shardings field is valid. + TUPLE = 2; + // None of the above; tile_shape and tile_assignment are both used. + OTHER = 3; } Type type = 1; // The shape of the sharded tile. @@ -838,6 +883,13 @@ message OpSharding { // Flattened list of device IDs. The order of flattening is the same as used // by IndexUtil::MultiToLinearIndex(tile_assignment_shape). repeated int64 tile_assignment_devices = 4; + // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape, + // in pre-order. The tuple shape could be nested; here we store just a + // flattened list of all leaves in the tuple shape. Note that the tuple shape + // is not stored here; shardings do not store the shapes to which they are + // applied, this is inferred from the instruction this sharding gets attached + // to. 
+ repeated OpSharding tuple_shardings = 5; } message OpRequest { @@ -855,6 +907,7 @@ message OpRequest { ConvolveRequest convolve_request = 8; CrossReplicaSumRequest cross_replica_sum_request = 9; CustomCallRequest custom_call_request = 10; + DotRequest dot_request = 43; DynamicSliceRequest dynamic_slice_request = 11; DynamicUpdateSliceRequest dynamic_update_slice_request = 12; GetTupleElementRequest get_tuple_element_request = 13; @@ -883,7 +936,9 @@ message OpRequest { BatchNormGradRequest batch_norm_grad_request = 37; BatchNormInferenceRequest batch_norm_inference_request = 38; FftRequest fft_request = 41; - // Next: 42 + ConvertRequest bitcast_convert_request = 42; + ConditionalRequest conditional_request = 44; + // Next: 45 } } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 3d53cbba5652c902855972f6e4e3ee78a3e1bcc7..604c41bf8acc910b47f8ee4a871d4740a2f1ba2f 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -9,7 +9,12 @@ load("//third_party/mpi:mpi.bzl", "if_mpi") py_library( name = "contrib_py", - srcs = glob(["**/*.py"]), + srcs = glob( + ["**/*.py"], + exclude = [ + "**/*_test.py", + ], + ), srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ @@ -51,6 +56,7 @@ py_library( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", + "//tensorflow/contrib/lite/python:lite", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/losses:metric_learning_py", @@ -63,6 +69,7 @@ py_library( "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", + "//tensorflow/contrib/periodic_resample:init_py", "//tensorflow/contrib/predictor", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", diff --git a/tensorflow/contrib/__init__.py 
b/tensorflow/contrib/__init__.py index 3068e9ed8f53e3e0f7cbf2d0222121a5752a2a56..08247c6b38a4df663ad28a6b4d3c41a1da41a020 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -55,6 +55,7 @@ from tensorflow.contrib import model_pruning from tensorflow.contrib import nccl from tensorflow.contrib import nn from tensorflow.contrib import opt +from tensorflow.contrib import periodic_resample from tensorflow.contrib import predictor from tensorflow.contrib import quantization from tensorflow.contrib import quantize @@ -79,6 +80,7 @@ from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager +from tensorflow.contrib.lite.python import lite from tensorflow.contrib.ndlstm import python as ndlstm from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs diff --git a/tensorflow/contrib/android/README.md b/tensorflow/contrib/android/README.md index f49e5857fe5255c2459793cb1389052a2ff5f88f..c7c128bf14f03d3769ef08e83da61f6d2f91fbd2 100644 --- a/tensorflow/contrib/android/README.md +++ b/tensorflow/contrib/android/README.md @@ -15,9 +15,9 @@ For prebuilt libraries, see the page for a recent build. 
The TensorFlow Inference Interface is also available as a -[JCenter package](https://bintray.com/google/tensorflow/tensorflow-android) and -can be included quite simply in your android project with a couple of lines in -the project's `build.gradle` file: +[JCenter package](https://bintray.com/google/tensorflow/tensorflow) +(see the tensorflow-android directory) and can be included quite simply in your +android project with a couple of lines in the project's `build.gradle` file: ``` allprojects { diff --git a/tensorflow/contrib/android/asset_manager_filesystem.cc b/tensorflow/contrib/android/asset_manager_filesystem.cc index 9e4d3290c3d99fab42f512f7144defde54f8ece8..380a652435ad089f46f3ca80e4fd43097fd96e10 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.cc +++ b/tensorflow/contrib/android/asset_manager_filesystem.cc @@ -97,7 +97,7 @@ class RandomAccessFileFromAsset : public RandomAccessFile { off64_t new_offset = AAsset_seek64(asset.get(), offset, SEEK_SET); off64_t length = AAsset_getLength64(asset.get()); if (new_offset < 0) { - result->set(scratch, 0); + *result = StringPiece(scratch, 0); return errors::OutOfRange("Read after file end."); } const off64_t region_left = @@ -106,7 +106,7 @@ class RandomAccessFileFromAsset : public RandomAccessFile { if (read < 0) { return errors::Internal("Error reading from asset."); } - result->set(scratch, region_left); + *result = StringPiece(scratch, region_left); return (region_left == to_read) ? 
Status::OK() : errors::OutOfRange("Read less bytes than requested."); diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt index 25ada5ba27aa167e4aaf4cebd6517e3b80aa1058..a115d1610e2334a6626f29674f3dd195e3a3c648 100644 --- a/tensorflow/contrib/android/cmake/CMakeLists.txt +++ b/tensorflow/contrib/android/cmake/CMakeLists.txt @@ -34,10 +34,12 @@ add_library(lib_tf STATIC IMPORTED ) set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION ${PREBUILT_DIR}/lib/libtensorflow-core.a) # Change to compile flags should be replicated into bazel build file +# TODO: Consider options other than -O2 for binary size. +# e.g. -Os for gcc, and -Oz for clang. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \ -std=c++11 -fno-rtti -fno-exceptions \ -O2 -Wno-narrowing -fomit-frame-pointer \ - -mfpu=neon -mfloat-abi=softfp -fPIE \ + -mfpu=neon -mfloat-abi=softfp -fPIE -fPIC \ -ftemplate-depth=900 \ -DGOOGLE_PROTOBUF_NO_RTTI \ -DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER") diff --git a/tensorflow/contrib/android/cmake/README.md b/tensorflow/contrib/android/cmake/README.md index 6f19b657fe72064bd7b005b568540cd52a5e19e8..934b58c7242fc06064ee3c06bc8f4c2740bd24ef 100644 --- a/tensorflow/contrib/android/cmake/README.md +++ b/tensorflow/contrib/android/cmake/README.md @@ -14,7 +14,7 @@ Add TensorFlow-Android-Inference as a dependency of your Android application ``` include ':TensorFlow-Android-Inference' -findProject(":TensorFlow-Android-Inference").projectDir = +findProject(":TensorFlow-Android-Inference").projectDir = new File("${/path/to/tensorflow_repo}/contrib/android/cmake") ``` diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index 1f423a7a5bf6a115dc627ddd6f5e98c074282585..dc5b9fb88742d78d0f40207b589e29451a6358dd 100644 --- 
a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -160,7 +160,7 @@ public class TensorFlowInferenceInterface { throw new RuntimeException("Failed to load model from the input stream", e); } } - + /* * Construct a TensorFlowInferenceInterface with provided Graph * @@ -168,7 +168,7 @@ public class TensorFlowInferenceInterface { */ public TensorFlowInferenceInterface(Graph g) { prepareNativeRuntime(); - + // modelName is redundant here, here is for // avoiding error in initialization as modelName is marked final. this.modelName = ""; @@ -290,7 +290,7 @@ public class TensorFlowInferenceInterface { */ public void feed(String inputName, boolean[] src, long... dims) { byte[] b = new byte[src.length]; - + for (int i = 0; i < src.length; i++) { b[i] = src[i] ? (byte) 1 : (byte) 0; } diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h index 6ed177e001758ad8c566c7965e1ec10ae5235fc8..a2cb146b8d69b6cc0eda8912a9c840ac4e0c7030 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -16,9 +16,11 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#include #include #include #include +#include #include #include @@ -42,19 +44,36 @@ template class ASBSQueue; } // namespace internal +// EXPERIMENTAL: API MAY BE SUBJECTED TO SUDDEN CHANGES. +// // Shared batch scheduler designed to minimize latency. The scheduler keeps // track of a number of queues (one per model or model version) which are // continuously enqueuing requests. 
The scheduler groups the requests into // batches which it periodically sends off for processing (see // shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler // prioritizes batches by age (i.e. the batch's oldest request) irrespective of -// queue. The scheduler will process the oldest batch at an adjustable rate, -// regardless of batch size. The user can provide feedback to help set this rate -// to achieve some goal (i.e. minimize overall latency, limit cpu usage, etc). +// queue or batch size. +// +// The scheduling decision currently exists in two flavors, controlled by the +// option use_in_flight_batches_implementation. It is expected that setting this +// option to true will give universally better results; after a period of +// testing to confirm, the old implementation will be removed. // -// The rate (or rather, the corresponding period) is adjusted each time a batch -// is processed, using an exponentially weighted moving average to smooth -// potentially noisy feedback: +// If use_in_flight_batches_implementation is set to true, the scheduler +// limits the number of batches which can be processed concurrently. If a new +// batch is created, and the number of in flight batches is below the limit, +// the next (i.e. oldest) batch is immediately scheduled. Similarly, when a +// batch finishes processing, the limit is rechecked, and another batch may be +// scheduled. To avoid the need to carefully tune the limit for workload, +// model type, platform, etc, it is dynamically adjusted in order to provide the +// lowest latency. +// +// If use_in_flight_batches_implementation is set to false, the scheduler will +// process the oldest batch at an adjustable rate, regardless of batch size. +// The user can provide feedback to help set this rate to achieve some goal +// (i.e. minimize overall latency, limit cpu usage, etc). 
The rate (or rather, +// the corresponding period) is adjusted each time a batch is processed, using +// an exponentially weighted moving average to smooth noisy feedback: // ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N // period *= (1 + K * emwa_feedback) // @@ -82,6 +101,20 @@ class AdaptiveSharedBatchScheduler int64 num_batch_threads = port::NumSchedulableCPUs(); // The environment to use (typically only overridden by test code). Env* env = Env::Default(); + // Which implementation to use (described in class comments above). + bool use_in_flight_batches_implementation = false; + // Initial limit for number of batches being concurrently processed. + // Non-integer values correspond to probabilistic limits - i.e. a value of + // 3.2 results in an actual cap of 3 80% of the time, and 4 20% of the time. + double initial_in_flight_batches_limit = 3; + // Number of batches between adjustments of in_flight_batches_limit. Larger + // numbers will give less noisy latency measurements, but will be less + // responsive to changes in workload. + int64 batches_to_average_over = 1000; + + // TODO(kte): remove the rate based implementation and corresponding options + // below once testing confirms the superiority of the in flight batches + // implementation. // Initial batch scheduling period in microseconds. Will be altered for // non-zero rate_feedback. double initial_scheduling_period_micros = 500; @@ -122,6 +155,11 @@ class AdaptiveSharedBatchScheduler BatchProcessor process_batch_callback, std::unique_ptr>* queue); + double in_flight_batches_limit() { + mutex_lock l(mu_); + return in_flight_batches_limit_; + } + private: // access to AddBatch, RemoveQueue, GetEnv. friend class internal::ASBSQueue; @@ -129,10 +167,20 @@ class AdaptiveSharedBatchScheduler explicit AdaptiveSharedBatchScheduler(const Options& options); // Batch scheduling function which runs every scheduling_period_ microseconds. 
+ // Only used when options_.use_in_flight_batches_implementation == false. void ProcessOneBatch(); + // Tracks processing latency and adjusts in_flight_batches_limit to minimize. + // Only used when options_.use_in_flight_batches_implementation == true. + void CallbackWrapper(const internal::ASBSBatch* batch, + BatchProcessor callback); + + // Schedules batch if in_flight_batches_limit_ is not met. + // Only used when options_.use_in_flight_batches_implementation == true. + void MaybeScheduleNextBatch() EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Notifies scheduler of non-empty batch which is eligible for processing. - void AddBatch(internal::ASBSBatch*); + void AddBatch(const internal::ASBSBatch* batch); // Removes queue from scheduler. void RemoveQueue(const internal::ASBSQueue* queue); @@ -149,7 +197,8 @@ class AdaptiveSharedBatchScheduler // Collection of batches added by AddBatch, ordered by age. Owned by scheduler // until they are released for processing. std::priority_queue*, - std::vector*>, BatchCompare> + std::vector*>, + BatchCompare> batches_ GUARDED_BY(mu_); // Unowned queues and callbacks added by AddQueue. @@ -160,19 +209,56 @@ class AdaptiveSharedBatchScheduler // Responsible for running ProcessOneBatch. PeriodicFunction was used in order // to check for deletion so that the thread can be shut down. + // Only used when options_.use_in_flight_batches_implementation == false. std::unique_ptr scheduling_thread_; // Responsible for running the batch processing callbacks. std::unique_ptr batch_thread_pool_; // Time interval in microseconds between successive ProcessOneBatch calls. + // Only used when options_.use_in_flight_batches_implementation == false. double scheduling_period_; // Exponentially weighted moving average of // options_.scheduling_period_feedback() evaluated in each ProcessOneBatch // call. + // Only used when options_.use_in_flight_batches_implementation == false. 
double ewma_feedback_ = 0; + // Limit on number of batches which can be concurrently processed. + // Non-integer values correspond to probabilistic limits - i.e. a value of 3.2 + // results in an actual cap of 3 80% of the time, and 4 20% of the time. + // Only used when options_.use_in_flight_batches_implementation == true. + double in_flight_batches_limit_ GUARDED_BY(mu_); + + // Number of batches currently being processed. + // Only used when options_.use_in_flight_batches_implementation == true. + int64 in_flight_batches_ GUARDED_BY(mu_) = 0; + + // RNG engine and distribution. + // Only used when options_.use_in_flight_batches_implementation == true. + std::default_random_engine rand_engine_; + std::uniform_real_distribution rand_double_; + + // Fields controlling the dynamic adjustment of in_flight_batches_limit_. + // Only used when options_.use_in_flight_batches_implementation == true. + // Number of batches since the last in_flight_batches_limit_ adjustment. + int64 batch_count_ GUARDED_BY(mu_) = 0; + // Sum of processing latency for batches counted by batch_count_. + int64 batch_latency_sum_ GUARDED_BY(mu_) = 0; + // Average batch latency for previous value of in_flight_batches_limit_. + double last_avg_latency_ms_ GUARDED_BY(mu_) = 0; + // Did last_avg_latency_ms_ decrease from the previous last_avg_latency_ms_? + bool last_latency_decreased_ GUARDED_BY(mu_) = false; + // Current direction (+-) to adjust in_flight_batches_limit_. + int step_direction_ GUARDED_BY(mu_) = 1; + // Max adjustment size (as a fraction of in_flight_batches_limit_). + constexpr static double kMaxStepSizeMultiplier = 0.125; // 1/8 + // Min adjustment size (as a fraction of in_flight_batches_limit_). + constexpr static double kMinStepSizeMultiplier = 0.0078125; // 1/128 + // Current adjustment size (as a fraction of in_flight_batches_limit_).
+ double step_size_multiplier_ GUARDED_BY(mu_) = kMaxStepSizeMultiplier; + TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler); }; @@ -208,6 +294,8 @@ class ASBSQueue : public BatchScheduler { // place any more tasks in this batch. void ReleaseBatch(const ASBSBatch* batch); + size_t max_task_size() const override { return options_.max_batch_size; } + private: std::shared_ptr> scheduler_; const QueueOptions options_; @@ -241,6 +329,12 @@ class ASBSBatch : public Batch { // ---------------- AdaptiveSharedBatchScheduler ---------------- +template +constexpr double AdaptiveSharedBatchScheduler::kMaxStepSizeMultiplier; + +template +constexpr double AdaptiveSharedBatchScheduler::kMinStepSizeMultiplier; + template Status AdaptiveSharedBatchScheduler::Create( const Options& options, @@ -275,6 +369,25 @@ Status AdaptiveSharedBatchScheduler::Create( "feedback_smoothing_batches must be positive; was ", options.feedback_smoothing_batches); } + if (options.initial_in_flight_batches_limit > options.num_batch_threads) { + return errors::InvalidArgument( + "initial_in_flight_batches_limit (", + options.initial_in_flight_batches_limit, + ") should not be larger than num_batch_threads (", + options.num_batch_threads, ")"); + } + if (options.initial_in_flight_batches_limit < 1) { + return errors::InvalidArgument( + "initial_in_flight_batches_limit should be " + "greater than or equal to 1; was ", + options.initial_in_flight_batches_limit); + } + if (options.batches_to_average_over < 1) { + return errors::InvalidArgument( + "batches_to_average_over should be " + "greater than or equal to 1; was ", + options.batches_to_average_over); + } scheduler->reset(new AdaptiveSharedBatchScheduler(options)); return Status::OK(); } @@ -283,14 +396,20 @@ template AdaptiveSharedBatchScheduler::AdaptiveSharedBatchScheduler( const Options& options) : options_(options), - scheduling_period_(options.initial_scheduling_period_micros) { + scheduling_period_(options.initial_scheduling_period_micros), 
+ in_flight_batches_limit_(options.initial_in_flight_batches_limit), + rand_double_(0.0, 1.0) { + std::random_device device; + rand_engine_.seed(device()); PeriodicFunction::Options opts; opts.thread_name_prefix = "scheduling_thread"; opts.env = GetEnv(); - scheduling_thread_.reset( - new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts)); batch_thread_pool_.reset(new thread::ThreadPool( GetEnv(), options.thread_pool_name, options.num_batch_threads)); + if (!options.use_in_flight_batches_implementation) { + scheduling_thread_.reset( + new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts)); + } } template @@ -316,9 +435,12 @@ Status AdaptiveSharedBatchScheduler::AddQueue( template void AdaptiveSharedBatchScheduler::AddBatch( - internal::ASBSBatch* batch) { + const internal::ASBSBatch* batch) { mutex_lock l(mu_); batches_.push(batch); + if (options_.use_in_flight_batches_implementation) { + MaybeScheduleNextBatch(); + } } template @@ -328,10 +450,78 @@ void AdaptiveSharedBatchScheduler::RemoveQueue( queues_and_callbacks_.erase(queue); } +template +void AdaptiveSharedBatchScheduler::MaybeScheduleNextBatch() { + if (batches_.empty() || in_flight_batches_ >= in_flight_batches_limit_) + return; + // Non-integer limit handled probabilistically. + if (in_flight_batches_limit_ - in_flight_batches_ < 1 && + rand_double_(rand_engine_) > + (in_flight_batches_limit_ - in_flight_batches_)) + return; + const internal::ASBSBatch* batch = batches_.top(); + batches_.pop(); + // Queue may destroy itself after ReleaseBatch is called.
+ batch->queue()->ReleaseBatch(batch); + batch_thread_pool_->Schedule( + std::bind(&AdaptiveSharedBatchScheduler::CallbackWrapper, this, + batch, queues_and_callbacks_[batch->queue()])); + in_flight_batches_++; +} + +template +void AdaptiveSharedBatchScheduler::CallbackWrapper( + const internal::ASBSBatch* batch, + AdaptiveSharedBatchScheduler::BatchProcessor callback) { + int64 start_time = batch->creation_time_micros(); + callback(std::unique_ptr>( + const_cast*>(batch))); + int64 end_time = GetEnv()->NowMicros(); + mutex_lock l(mu_); + in_flight_batches_--; + batch_count_++; + batch_latency_sum_ += end_time - start_time; + // Occasionally adjust in_flight_batches_limit_ to minimize average latency. + // Although the optimal value may depend on the workload, the latency should + // be a simple convex function of in_flight_batches_limit_, allowing us to + // locate the global minimum relatively quickly. + if (batch_count_ == options_.batches_to_average_over) { + double current_avg_latency_ms = (batch_latency_sum_ / 1000.) / batch_count_; + bool current_latency_decreased = + current_avg_latency_ms < last_avg_latency_ms_; + if (current_latency_decreased) { + // If latency improvement was because we're moving in the correct + // direction, increase step_size so that we can get to the minimum faster. + // If latency improvement was due to backtracking from a previous failure, + // decrease step_size in order to refine our location. + step_size_multiplier_ *= (last_latency_decreased_ ? 2 : 0.5); + step_size_multiplier_ = + std::min(step_size_multiplier_, kMaxStepSizeMultiplier); + step_size_multiplier_ = + std::max(step_size_multiplier_, kMinStepSizeMultiplier); + } else { + // Return (nearly) to previous position and confirm that latency is better + // there before decreasing step size. 
+ step_direction_ = -step_direction_; + } + in_flight_batches_limit_ += + step_direction_ * in_flight_batches_limit_ * step_size_multiplier_; + in_flight_batches_limit_ = + std::min(in_flight_batches_limit_, + static_cast(options_.num_batch_threads)); + in_flight_batches_limit_ = std::max(in_flight_batches_limit_, 1.0); + last_avg_latency_ms_ = current_avg_latency_ms; + last_latency_decreased_ = current_latency_decreased; + batch_count_ = 0; + batch_latency_sum_ = 0; + } + MaybeScheduleNextBatch(); +} + template void AdaptiveSharedBatchScheduler::ProcessOneBatch() { static const double kFeedbackMultiplier = .001; - internal::ASBSBatch* batch = nullptr; + const internal::ASBSBatch* batch = nullptr; BatchProcessor callback; const int64 start_time_micros = GetEnv()->NowMicros(); { @@ -355,7 +545,8 @@ void AdaptiveSharedBatchScheduler::ProcessOneBatch() { // Queue may destroy itself after ReleaseBatch is called. batch->queue()->ReleaseBatch(batch); batch_thread_pool_->Schedule([callback, batch] { - callback(std::unique_ptr>(batch)); + callback(std::unique_ptr>( + const_cast*>(batch))); }); } const int64 sleep_time = @@ -425,6 +616,7 @@ Status ASBSQueue::Schedule(std::unique_ptr* task) { current_batch_->AddTask(std::move(*task)); num_enqueued_tasks_++; } + // AddBatch must be called outside of lock, since it may call ReleaseBatch. 
if (new_batch != nullptr) scheduler_->AddBatch(new_batch); return Status::OK(); } diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc index a07cd6d834fa28904bf7748b16972cca217503c1..18f1e554525a306ffe07460a889411ed4755b89f 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc @@ -141,6 +141,16 @@ TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) { options = Scheduler::Options(); options.feedback_smoothing_batches = 0; EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.initial_in_flight_batches_limit = 0.5; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.num_batch_threads = 5; + options.initial_in_flight_batches_limit = 8; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.batches_to_average_over = -5; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); } TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) { @@ -186,6 +196,7 @@ TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) { queue_options.max_enqueued_batches = 2; TF_ASSERT_OK( scheduler->AddQueue(queue_options, queue_0_callback, &queue_0)); + EXPECT_EQ(10, queue_0->max_task_size()); queue_options.max_batch_size = 0; // Queue must have max_batch_size > 0. 
EXPECT_FALSE( @@ -433,6 +444,107 @@ TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) { } stop_teardown.Notify(); } + +TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesImplementation) { + AdaptiveSharedBatchScheduler::Options options; + options.use_in_flight_batches_implementation = true; + options.initial_in_flight_batches_limit = 2; + options.batches_to_average_over = 1000; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + mutex mu; + int processed_batches = 0; + Notification finish_processing; + auto queue_callback = [&mu, &processed_batches, &finish_processing]( + std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + mu.lock(); + int batch_num = ++processed_batches; + mu.unlock(); + if (batch_num == 2) { + // Give third batch a chance to process if it's going to. + Env::Default()->SleepForMicroseconds(1000); + finish_processing.Notify(); + } + if (batch_num == 3) { + ASSERT_TRUE(finish_processing.HasBeenNotified()); + } + finish_processing.WaitForNotification(); + }; + + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + + // Enqueue 3 batches. 
+ for (int i = 0; i < 3; i++) { + TF_ASSERT_OK(ScheduleTask(100, queue.get())); + } +} + +TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesLimitTuning) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + AdaptiveSharedBatchScheduler::Options options; + options.env = &env; + options.use_in_flight_batches_implementation = true; + options.initial_in_flight_batches_limit = 2; + options.batches_to_average_over = 1; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + auto queue_callback = [&env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + switch (batch->size()) { + case 0: + env.AdvanceByMicroseconds(10); + break; + case 1: + env.AdvanceByMicroseconds(15); + break; + case 2: + env.AdvanceByMicroseconds(10); + break; + case 3: + env.AdvanceByMicroseconds(11); + break; + } + }; + + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + TF_ASSERT_OK(ScheduleTask(0, queue.get())); + double in_flight_batches_limit = 2; + while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) { + } + // Initial direction will be negative. + EXPECT_LT(scheduler->in_flight_batches_limit(), in_flight_batches_limit); + in_flight_batches_limit = scheduler->in_flight_batches_limit(); + TF_ASSERT_OK(ScheduleTask(1, queue.get())); + while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) { + } + // Latency increased -> change direction. + EXPECT_GT(scheduler->in_flight_batches_limit(), in_flight_batches_limit); + in_flight_batches_limit = scheduler->in_flight_batches_limit(); + TF_ASSERT_OK(ScheduleTask(2, queue.get())); + while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) { + } + // Latency decreased -> keep going in same direction. 
+ EXPECT_GT(scheduler->in_flight_batches_limit(), in_flight_batches_limit); + in_flight_batches_limit = scheduler->in_flight_batches_limit(); + TF_ASSERT_OK(ScheduleTask(3, queue.get())); + while (scheduler->in_flight_batches_limit() == in_flight_batches_limit) { + } + // Latency increased -> change direction. + EXPECT_LT(scheduler->in_flight_batches_limit(), in_flight_batches_limit); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} } // namespace anonymous } // namespace serving } // namespace tensorflow diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h index 9d3805fbaf39978159dd2f4a754e6d41a07acf6a..91065db2499dffd2687a53bd6304d9b7593f7b3a 100644 --- a/tensorflow/contrib/batching/basic_batch_scheduler.h +++ b/tensorflow/contrib/batching/basic_batch_scheduler.h @@ -192,6 +192,10 @@ class BasicBatchScheduler : public BatchScheduler { size_t NumEnqueuedTasks() const override; size_t SchedulingCapacity() const override; + size_t max_task_size() const override { + return shared_scheduler_queue_->max_task_size(); + } + private: explicit BasicBatchScheduler( std::unique_ptr> shared_scheduler_queue); diff --git a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc index e020301795c7dadee2815c0e0d727e53e5fb9e6e..187823151cf840dcf8058677fcf74d1beffc3bc2 100644 --- a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc @@ -73,6 +73,7 @@ TEST(BasicBatchSchedulerTest, Basic) { std::unique_ptr> scheduler; TF_ASSERT_OK( BasicBatchScheduler::Create(options, callback, &scheduler)); + EXPECT_EQ(10, scheduler->max_task_size()); EXPECT_EQ(0, scheduler->NumEnqueuedTasks()); EXPECT_EQ(3 * 10, scheduler->SchedulingCapacity()); TF_ASSERT_OK(ScheduleTask(3, scheduler.get())); diff --git a/tensorflow/contrib/batching/batch_scheduler.h 
b/tensorflow/contrib/batching/batch_scheduler.h index a5072f439abad3c5db79a514a7f2baff0b021b39..e18cf6c35059e4d720768e3b2c02b03727a6bac4 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/batching/batch_scheduler.h @@ -178,6 +178,10 @@ class BatchScheduler { // This method is useful for monitoring, or for guaranteeing a future slot in // the schedule (but being mindful about the caveats listed above). virtual size_t SchedulingCapacity() const = 0; + + // Returns the maximum allowed size of tasks submitted to the scheduler. (This + // is typically equal to a configured maximum batch size.) + virtual size_t max_task_size() const = 0; }; ////////// diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h index 1853827dc0433869d7e38870d83e007bb0cb1bb1..86c45bdc2e66e30fbde15f6cafe481cf969c14d0 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/shared_batch_scheduler.h @@ -248,6 +248,9 @@ class Queue { // BatchScheduler::SchedulingCapacity(). size_t SchedulingCapacity() const; + // Returns the maximum allowed size of tasks submitted to the queue. + size_t max_task_size() const { return options_.max_batch_size; } + // Called by a thread that is ready to process a batch, to request one from // this queue. Either returns a batch that is ready to be processed, or // nullptr if the queue declines to schedule a batch at this time. If it @@ -338,6 +341,8 @@ class QueueHandle : public BatchScheduler { size_t NumEnqueuedTasks() const override; size_t SchedulingCapacity() const override; + size_t max_task_size() const override { return queue_->max_task_size(); } + private: // The scheduler that owns 'queue_'. 
std::shared_ptr> scheduler_; diff --git a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc index 3e924ae5f13519b4fe9a3f4b510773ca2bddaf23..3ac79a8fdc47389816db8ca09f27846d1c4623c2 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc @@ -429,6 +429,7 @@ TEST(SharedBatchSchedulerTest, ConstMethods) { queue_options.max_enqueued_batches = max_enqueued_batches; std::unique_ptr> queue; TF_ASSERT_OK(scheduler->AddQueue(queue_options, callback, &queue)); + EXPECT_EQ(2, queue->max_task_size()); EXPECT_EQ(0, queue->NumEnqueuedTasks()); EXPECT_EQ(max_enqueued_batches * 2, queue->SchedulingCapacity()); diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index 213ae01c3bf69adf7514ade560fd055b0bb3fe7d..a262d4aecdbb69dfcd8b88bc0a09060500d6b1c9 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -19,9 +19,9 @@ py_library( srcs = ["__init__.py"] + glob(["python/ops/*.py"]), srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:functional_ops", @@ -32,12 +32,8 @@ py_library( "//tensorflow/python:platform", "//tensorflow/python:random_ops", "//tensorflow/python:state_ops", - "//tensorflow/python:training", "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/distributions", "//third_party/py/numpy", - "@six_archive//:six", ], ) @@ -103,6 +99,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "layers_dense_variational_test", + size = "small", + srcs = ["python/kernel_tests/layers_dense_variational_test.py"], + additional_deps = [ + ":bayesflow_py", + 
"//third_party/py/numpy", + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python/ops/distributions", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradients", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + ], +) + cuda_py_test( name = "monte_carlo_test", size = "small", @@ -124,6 +139,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "halton_sequence_test", + size = "small", + srcs = ["python/kernel_tests/halton_sequence_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "hmc_test", size = "medium", @@ -145,6 +179,27 @@ cuda_py_test( ], ) +cuda_py_test( + name = "sgld_optimizer_test", + size = "small", + srcs = ["python/kernel_tests/sgld_optimizer_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python/ops/distributions", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_seed", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index 
b98bc369542679b05169db092aee86e884ca1625..95b9452b1ada60c44672f37800ced2133d2bd8b2 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -23,16 +23,30 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence from tensorflow.contrib.bayesflow.python.ops import custom_grad +from tensorflow.contrib.bayesflow.python.ops import halton_sequence from tensorflow.contrib.bayesflow.python.ops import hmc +from tensorflow.contrib.bayesflow.python.ops import layers from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings from tensorflow.contrib.bayesflow.python.ops import monte_carlo +from tensorflow.contrib.bayesflow.python.ops import optimizers # pylint: enable=unused-import,line-too-long from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['csiszar_divergence', 'custom_grad', 'entropy', - 'metropolis_hastings', 'monte_carlo', 'hmc', 'special_math', - 'stochastic_variables', 'variational_inference'] +_allowed_symbols = [ + 'csiszar_divergence', + 'custom_grad', + 'entropy', + 'halton_sequence', + 'hmc', + 'layers', + 'metropolis_hastings', + 'monte_carlo', + 'optimizers', + 'special_math', + 'stochastic_variables', + 'variational_inference', +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0a85862abfd744a86b9a38e10dbb5b985d0a0e94 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for halton_sequence.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.bayesflow.python.ops import halton_sequence as halton +from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.platform import test + + +mc = monte_carlo_lib + + +class HaltonSequenceTest(test.TestCase): + + def test_known_values_small_bases(self): + with self.test_session(): + # The first five elements of the Halton sequence with base 2 and 3 + expected = np.array(((1. / 2, 1. / 3), + (1. / 4, 2. / 3), + (3. / 4, 1. / 9), + (1. / 8, 4. / 9), + (5. / 8, 7. 
/ 9)), dtype=np.float32) + sample = halton.sample(2, num_samples=5) + self.assertAllClose(expected, sample.eval(), rtol=1e-6) + + def test_sample_indices(self): + with self.test_session(): + dim = 5 + indices = math_ops.range(10, dtype=dtypes.int32) + sample_direct = halton.sample(dim, num_samples=10) + sample_from_indices = halton.sample(dim, sample_indices=indices) + self.assertAllClose(sample_direct.eval(), sample_from_indices.eval(), + rtol=1e-6) + + def test_dtypes_works_correctly(self): + with self.test_session(): + dim = 3 + sample_float32 = halton.sample(dim, num_samples=10, dtype=dtypes.float32) + sample_float64 = halton.sample(dim, num_samples=10, dtype=dtypes.float64) + self.assertEqual(sample_float32.eval().dtype, np.float32) + self.assertEqual(sample_float64.eval().dtype, np.float64) + + def test_normal_integral_mean_and_var_correctly_estimated(self): + n = int(1000) + # This test is almost identical to the similarly named test in + # monte_carlo_test.py. The only difference is that we use the Halton + # samples instead of the random samples to evaluate the expectations. + # MC with pseudo random numbers converges at the rate of 1/ Sqrt(N) + # (N=number of samples). For QMC in low dimensions, the expected convergence + # rate is ~ 1/N. Hence we should only need 1e3 samples as compared to the + # 1e6 samples used in the pseudo-random monte carlo. + with self.test_session(): + mu_p = array_ops.constant([-1.0, 1.0], dtype=dtypes.float64) + mu_q = array_ops.constant([0.0, 0.0], dtype=dtypes.float64) + sigma_p = array_ops.constant([0.5, 0.5], dtype=dtypes.float64) + sigma_q = array_ops.constant([1.0, 1.0], dtype=dtypes.float64) + p = normal_lib.Normal(loc=mu_p, scale=sigma_p) + q = normal_lib.Normal(loc=mu_q, scale=sigma_q) + + cdf_sample = halton.sample(2, num_samples=n, dtype=dtypes.float64) + q_sample = q.quantile(cdf_sample) + + # Compute E_p[X]. 
+ e_x = mc.expectation_importance_sampler( + f=lambda x: x, log_p=p.log_prob, sampling_dist_q=q, z=q_sample, + seed=42) + + # Compute E_p[X^2]. + e_x2 = mc.expectation_importance_sampler( + f=math_ops.square, log_p=p.log_prob, sampling_dist_q=q, z=q_sample, + seed=42) + + stddev = math_ops.sqrt(e_x2 - math_ops.square(e_x)) + # Keep the tolerance levels the same as in monte_carlo_test.py. + self.assertEqual(p.batch_shape, e_x.get_shape()) + self.assertAllClose(p.mean().eval(), e_x.eval(), rtol=0.01) + self.assertAllClose(p.stddev().eval(), stddev.eval(), rtol=0.02) + + def test_docstring_example(self): + # Produce the first 1000 members of the Halton sequence in 3 dimensions. + num_samples = 1000 + dim = 3 + with self.test_session(): + sample = halton.sample(dim, num_samples=num_samples) + + # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional + # hypercube. + powers = math_ops.range(1.0, limit=dim + 1) + integral = math_ops.reduce_mean( + math_ops.reduce_prod(sample ** powers, axis=-1)) + true_value = 1.0 / math_ops.reduce_prod(powers + 1.0) + + # Produces a relative absolute error of 1.7%. + self.assertAllClose(integral.eval(), true_value.eval(), rtol=0.02) + + # Now skip the first 1000 samples and recompute the integral with the next + # thousand samples. The sample_indices argument can be used to do this. 
+ + sample_indices = math_ops.range(start=1000, limit=1000 + num_samples, + dtype=dtypes.int32) + sample_leaped = halton.sample(dim, sample_indices=sample_indices) + + integral_leaped = math_ops.reduce_mean( + math_ops.reduce_prod(sample_leaped ** powers, axis=-1)) + self.assertAllClose(integral_leaped.eval(), true_value.eval(), rtol=0.001) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py index b1f108e5f01e4945ee83d8262f1d99877f0fe9f0..cbc66b6dc13db62c25952de6b6c13b2fdfe27f12 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for Hamiltonian Monte Carlo. -""" +"""Tests for Hamiltonian Monte Carlo.""" from __future__ import absolute_import from __future__ import division @@ -27,6 +26,7 @@ from tensorflow.contrib.bayesflow.python.ops import hmc from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -46,6 +46,9 @@ class HMCTest(test.TestCase): random_seed.set_random_seed(10003) np.random.seed(10003) + def assertAllFinite(self, x): + self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x)) + def _log_gamma_log_prob(self, x, event_dims=()): """Computes log-pdf of a log-gamma random variable. @@ -345,5 +348,97 @@ class HMCTest(test.TestCase): def testAIS12(self): self._ais_gets_correct_log_normalizer_wrapper([1, 2]) + def testNanRejection(self): + """Tests that an update that yields NaN potentials gets rejected. 
+ + We run HMC with a target distribution that returns NaN + log-likelihoods if any element of x < 0, and unit-scale + exponential log-likelihoods otherwise. The exponential potential + pushes x towards 0, ensuring that any reasonably large update will + push us over the edge into NaN territory. + """ + def _unbounded_exponential_log_prob(x): + """An exponential distribution with log-likelihood NaN for x < 0.""" + per_element_potentials = array_ops.where(x < 0, + np.nan * array_ops.ones_like(x), + -x) + return math_ops.reduce_sum(per_element_potentials) + + with self.test_session() as sess: + initial_x = math_ops.linspace(0.01, 5, 10) + updated_x, acceptance_probs, _, _ = hmc.kernel( + 2., 5, initial_x, _unbounded_exponential_log_prob, [0]) + initial_x_val, updated_x_val, acceptance_probs_val = sess.run( + [initial_x, updated_x, acceptance_probs]) + + logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) + logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) + logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) + + self.assertAllEqual(initial_x_val, updated_x_val) + self.assertEqual(acceptance_probs_val, 0.) + + def testNanFromGradsDontPropagate(self): + """Test that update with NaN gradients does not cause NaN in results.""" + def _nan_log_prob_with_nan_gradient(x): + return np.nan * math_ops.reduce_sum(x) + + with self.test_session() as sess: + initial_x = math_ops.linspace(0.01, 5, 10) + updated_x, acceptance_probs, new_log_prob, new_grad = hmc.kernel( + 2., 5, initial_x, _nan_log_prob_with_nan_gradient, [0]) + initial_x_val, updated_x_val, acceptance_probs_val = sess.run( + [initial_x, updated_x, acceptance_probs]) + + logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) + logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) + logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) + + self.assertAllEqual(initial_x_val, updated_x_val) + self.assertEqual(acceptance_probs_val, 0.) 
+ + self.assertAllFinite( + gradients_impl.gradients(updated_x, initial_x)[0].eval()) + self.assertTrue( + gradients_impl.gradients(new_grad, initial_x)[0] is None) + + # Gradients of the acceptance probs and new log prob are not finite. + _ = new_log_prob # Prevent unused arg error. + # self.assertAllFinite( + # gradients_impl.gradients(acceptance_probs, initial_x)[0].eval()) + # self.assertAllFinite( + # gradients_impl.gradients(new_log_prob, initial_x)[0].eval()) + + def testChainWorksIn64Bit(self): + def log_prob(x): + return - math_ops.reduce_sum(x * x, axis=-1) + states, acceptance_probs = hmc.chain( + n_iterations=10, + step_size=np.float64(0.01), + n_leapfrog_steps=10, + initial_x=np.zeros(5).astype(np.float64), + target_log_prob_fn=log_prob, + event_dims=[-1]) + with self.test_session() as sess: + states_, acceptance_probs_ = sess.run([states, acceptance_probs]) + self.assertEqual(np.float64, states_.dtype) + self.assertEqual(np.float64, acceptance_probs_.dtype) + + def testChainWorksIn16Bit(self): + def log_prob(x): + return - math_ops.reduce_sum(x * x, axis=-1) + states, acceptance_probs = hmc.chain( + n_iterations=10, + step_size=np.float16(0.01), + n_leapfrog_steps=10, + initial_x=np.zeros(5).astype(np.float16), + target_log_prob_fn=log_prob, + event_dims=[-1]) + with self.test_session() as sess: + states_, acceptance_probs_ = sess.run([states, acceptance_probs]) + self.assertEqual(np.float16, states_.dtype) + self.assertEqual(np.float16, acceptance_probs_.dtype) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py new file mode 100644 index 0000000000000000000000000000000000000000..50358fd1c2b7635ffe2d08c5af3219bb0a11498b --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py @@ -0,0 +1,304 @@ +# Copyright 2017 The TensorFlow Authors. 
class Counter(object):
  """Callable tally that hands out consecutive integers.

  The stored number starts at -1. Each call bumps it by one and returns the
  new number, so the first call yields 0. `value` reads the current number
  without changing it.
  """

  def __init__(self):
    self._value = -1

  def __call__(self):
    bumped = self._value + 1
    self._value = bumped
    return bumped

  @property
  def value(self):
    return self._value
class DenseVariationalLocalReparametrization(test.TestCase):
  """Tests `DenseVariational` with and without local reparameterization.

  The first two tests check how many entries the layer adds to the
  `REGULARIZATION_LOSSES` collection. The last two replace the posterior,
  prior and divergence callables with mocks that return preset tensors, so
  the layer's outputs and recorded divergence arguments can be compared
  against values computed directly in the test.
  """

  def testKLPenaltyKernel(self):
    """A default layer registers exactly one regularization loss when called."""
    with self.test_session():
      dense_vi = prob_layers_lib.DenseVariational(units=2)
      inputs = random_ops.random_uniform([2, 3], seed=1)

      # Before the layer is called nothing has been registered.
      loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
      self.assertEqual(len(loss_keys), 0)
      self.assertListEqual(dense_vi.losses, loss_keys)

      _ = dense_vi(inputs)

      # Calling the layer adds one loss (presumably the kernel KL term, going
      # by the test name), and `losses` mirrors the collection.
      loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
      self.assertEqual(len(loss_keys), 1)
      self.assertListEqual(dense_vi.losses, loss_keys)

  def testKLPenaltyBoth(self):
    """With a bias posterior and prior supplied, two losses are registered."""
    def _make_normal(dtype, *args):  # pylint: disable=unused-argument
      # Standard normal in the requested dtype; remaining args are ignored.
      return normal_lib.Normal(
          loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.))
    with self.test_session():
      dense_vi = prob_layers_lib.DenseVariational(
          units=2,
          bias_posterior_fn=prob_layers_lib.default_mean_field_normal_fn(),
          bias_prior_fn=_make_normal)
      inputs = random_ops.random_uniform([2, 3], seed=1)

      # Before the layer is called nothing has been registered.
      loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
      self.assertEqual(len(loss_keys), 0)
      self.assertListEqual(dense_vi.losses, loss_keys)

      _ = dense_vi(inputs)

      # Calling the layer now adds two losses, and `losses` mirrors the
      # collection.
      loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
      self.assertEqual(len(loss_keys), 2)
      self.assertListEqual(dense_vi.losses, loss_keys)

  def testVariationalNonLocal(self):
    """Without local reparameterization: outputs = inputs @ kernel + bias."""
    batch_size, in_size, out_size = 2, 3, 4
    with self.test_session() as sess:
      # Each `seed()` call yields the next integer, so every random tensor
      # below gets its own fixed seed; the call order therefore matters.
      seed = Counter()
      inputs = random_ops.random_uniform([batch_size, in_size], seed=seed())

      kernel_size = [in_size, out_size]
      kernel_posterior = MockDistribution(
          result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()),
          result_sample=random_ops.random_uniform(kernel_size, seed=seed()))
      kernel_prior = MockDistribution(
          result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()),
          result_sample=random_ops.random_uniform(kernel_size, seed=seed()))
      kernel_divergence = MockKLDivergence(
          result=random_ops.random_uniform(kernel_size, seed=seed()))

      bias_size = [out_size]
      bias_posterior = MockDistribution(
          result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
          result_sample=random_ops.random_uniform(bias_size, seed=seed()))
      bias_prior = MockDistribution(
          result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
          result_sample=random_ops.random_uniform(bias_size, seed=seed()))
      bias_divergence = MockKLDivergence(
          result=random_ops.random_uniform(bias_size, seed=seed()))

      # Reference result built directly from the mocks' preset samples.
      expected_outputs = (
          math_ops.matmul(inputs, kernel_posterior.result_sample) +
          bias_posterior.result_sample)

      # NOTE(review): `units=2` differs from `out_size` (4); presumably the
      # mocked posterior sample decides the output width here -- confirm.
      dense_vi = prob_layers_lib.DenseVariational(
          units=2,
          kernel_use_local_reparameterization=False,
          kernel_posterior_fn=lambda *args: kernel_posterior,
          kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
          kernel_prior_fn=lambda *args: kernel_prior,
          kernel_divergence_fn=kernel_divergence,
          bias_posterior_fn=lambda *args: bias_posterior,
          bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
          bias_prior_fn=lambda *args: bias_prior,
          bias_divergence_fn=bias_divergence)

      outputs = dense_vi(inputs)

      kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)

      # Evaluate every expected/actual pair in a single run so both sides of
      # each pair see the same random draws. Entry 0 of the collection is
      # compared with the kernel divergence and entry 1 with the bias one.
      [
          expected_outputs_, actual_outputs_,
          expected_kernel_, actual_kernel_,
          expected_kernel_divergence_, actual_kernel_divergence_,
          expected_bias_, actual_bias_,
          expected_bias_divergence_, actual_bias_divergence_,
      ] = sess.run([
          expected_outputs, outputs,
          kernel_posterior.result_sample, dense_vi.kernel.posterior_tensor,
          kernel_divergence.result, kl_penalty[0],
          bias_posterior.result_sample, dense_vi.bias.posterior_tensor,
          bias_divergence.result, kl_penalty[1],
      ])

      self.assertAllClose(
          expected_kernel_, actual_kernel_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_bias_, actual_bias_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_outputs_, actual_outputs_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_kernel_divergence_, actual_kernel_divergence_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_bias_divergence_, actual_bias_divergence_,
          rtol=1e-6, atol=0.)

      # Each divergence mock must have been called exactly once, with
      # (posterior, prior, posterior sample) as positional arguments.
      self.assertAllEqual(
          [[kernel_posterior, kernel_prior, kernel_posterior.result_sample]],
          kernel_divergence.args)

      self.assertAllEqual(
          [[bias_posterior, bias_prior, bias_posterior.result_sample]],
          bias_divergence.args)

  def testVariationalLocal(self):
    """With local reparameterization the kernel term is sampled post-matmul."""
    batch_size, in_size, out_size = 2, 3, 4
    with self.test_session() as sess:
      # Each `seed()` call yields the next integer, so every random tensor
      # below gets its own fixed seed; the call order therefore matters.
      seed = Counter()
      inputs = random_ops.random_uniform([batch_size, in_size], seed=seed())

      kernel_size = [in_size, out_size]
      # Only the kernel posterior needs `loc`/`scale`: they feed the expected
      # affine-transformed Normal constructed below.
      kernel_posterior = MockDistribution(
          loc=random_ops.random_uniform(kernel_size, seed=seed()),
          scale=random_ops.random_uniform(kernel_size, seed=seed()),
          result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()),
          result_sample=random_ops.random_uniform(kernel_size, seed=seed()))
      kernel_prior = MockDistribution(
          result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()),
          result_sample=random_ops.random_uniform(kernel_size, seed=seed()))
      kernel_divergence = MockKLDivergence(
          result=random_ops.random_uniform(kernel_size, seed=seed()))

      bias_size = [out_size]
      bias_posterior = MockDistribution(
          result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
          result_sample=random_ops.random_uniform(bias_size, seed=seed()))
      bias_prior = MockDistribution(
          result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
          result_sample=random_ops.random_uniform(bias_size, seed=seed()))
      bias_divergence = MockKLDivergence(
          result=random_ops.random_uniform(bias_size, seed=seed()))

      # Reference distribution of `inputs @ kernel`:
      #   loc   = inputs @ loc
      #   scale = sqrt(inputs**2 @ scale**2)
      # sampled with the same seed (42) the layer's tensor fn uses.
      expected_kernel_posterior_affine = normal_lib.Normal(
          loc=math_ops.matmul(inputs, kernel_posterior.result_loc),
          scale=math_ops.matmul(
              inputs**2., kernel_posterior.result_scale**2)**0.5)
      expected_kernel_posterior_affine_tensor = (
          expected_kernel_posterior_affine.sample(seed=42))
      expected_outputs = (expected_kernel_posterior_affine_tensor +
                          bias_posterior.result_sample)

      # NOTE(review): `units=2` differs from `out_size` (4); presumably the
      # mocked distributions decide the output width here -- confirm.
      dense_vi = prob_layers_lib.DenseVariational(
          units=2,
          kernel_use_local_reparameterization=True,
          kernel_posterior_fn=lambda *args: kernel_posterior,
          kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
          kernel_prior_fn=lambda *args: kernel_prior,
          kernel_divergence_fn=kernel_divergence,
          bias_posterior_fn=lambda *args: bias_posterior,
          bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
          bias_prior_fn=lambda *args: bias_prior,
          bias_divergence_fn=bias_divergence)

      outputs = dense_vi(inputs)

      kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)

      # Evaluate every expected/actual pair in a single run so both sides of
      # each pair see the same random draws. No kernel tensor is compared in
      # this mode; only outputs, bias and the two divergences are.
      [
          expected_outputs_, actual_outputs_,
          expected_kernel_divergence_, actual_kernel_divergence_,
          expected_bias_, actual_bias_,
          expected_bias_divergence_, actual_bias_divergence_,
      ] = sess.run([
          expected_outputs, outputs,
          kernel_divergence.result, kl_penalty[0],
          bias_posterior.result_sample, dense_vi.bias.posterior_tensor,
          bias_divergence.result, kl_penalty[1],
      ])

      self.assertAllClose(
          expected_bias_, actual_bias_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_outputs_, actual_outputs_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_kernel_divergence_, actual_kernel_divergence_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_bias_divergence_, actual_bias_divergence_,
          rtol=1e-6, atol=0.)

      # In local mode the kernel divergence is expected to receive `None` in
      # place of a posterior sample; the bias divergence still gets one.
      self.assertAllEqual(
          [[kernel_posterior, kernel_prior, None]],
          kernel_divergence.args)

      self.assertAllEqual(
          [[bias_posterior, bias_prior, bias_posterior.result_sample]],
          bias_divergence.args)
class SGLDOptimizerTest(test.TestCase):
  """Functional tests for `SGLDOptimizer`.

  Every test runs for half, float32 and float64. Where parameter values are
  checked, the expected value after one step is computed in the test as

    var - 3.0 * (0.5 * g / sqrt(decay_rate + (1 - decay_rate) * g**2 + 1e-8))

  with `g` the constant gradient fed to `apply_gradients` and 3.0 the
  learning rate passed to the optimizer.
  NOTE(review): the bare `decay_rate` term presumably reflects an accumulator
  initialised to 1 -- confirm against the optimizer implementation.
  """

  def testBasic(self):
    """One step with dense constant gradients and a Python-float config."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = variables.Variable([1.1, 2.1], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        decay_rate = 0.53
        sgd_op = SGLDOptimizer(
            3.0, preconditioner_decay_rate=decay_rate).apply_gradients(
                zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
        # Run 1 step of sgd
        sgd_op.run()
        # Validate updated params
        grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate +
                                              (1 - decay_rate) * 0.1**2 + 1e-8))
        self.assertAllCloseAccordingToType(
            [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval())
        grads_scaled = (0.5 * 0.01 / math.sqrt(
            decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8))
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval())

  def testBasicMultiInstance(self):
    """Two optimizers in one graph update independently, in distinct scopes."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = variables.Variable([1.1, 2.1], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        # Second, identical set of variables/gradients for the second
        # optimizer instance.
        vara = variables.Variable([1.1, 2.1], dtype=dtype)
        varb = variables.Variable([3.0, 4.0], dtype=dtype)
        gradsa = constant_op.constant([0.1, 0.1], dtype=dtype)
        gradsb = constant_op.constant([0.01, 0.01], dtype=dtype)
        decay_rate = 0.5
        sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate)
        sgd_op = sgd_optimizer.apply_gradients(
            zip([grads0, grads1], [var0, var1]))
        sgd_optimizer2 = SGLDOptimizer(
            3.0, preconditioner_decay_rate=decay_rate)
        sgd_op2 = sgd_optimizer2.apply_gradients(
            zip([gradsa, gradsb], [vara, varb]))
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
        self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval())

        # Run 1 step of sgd
        sgd_op.run()
        sgd_op2.run()
        # Validate updated params: both instances must yield the same values.
        grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate +
                                              (1 - decay_rate) * 0.1**2 + 1e-8))
        self.assertAllCloseAccordingToType(
            [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval())
        self.assertAllCloseAccordingToType(
            [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], vara.eval())

        grads_scaled = (0.5 * 0.01 / math.sqrt(
            decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8))
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval())
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], varb.eval())
        # The two optimizers must not share a variable scope, by identity or
        # by name.
        self.assertNotEqual(sgd_optimizer.variable_scope,
                            sgd_optimizer2.variable_scope)
        self.assertNotEqual(sgd_optimizer.variable_scope.name,
                            sgd_optimizer2.variable_scope.name)

  def testTensorLearningRate(self):
    """Same as `testBasic`, but learning rate and decay rate are tensors."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = variables.Variable([1.1, 2.1], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        # Both hyper-parameters are built as constants without an explicit
        # dtype, independent of the loop's `dtype`.
        lrate = constant_op.constant(3.0)
        decay_rate = 0.5
        sgd_op = SGLDOptimizer(
            lrate, preconditioner_decay_rate=constant_op.constant(
                decay_rate)).apply_gradients(
                    zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
        # Run 1 step of sgd
        sgd_op.run()
        # Validate updated params
        grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate +
                                              (1 - decay_rate) * 0.1**2 + 1e-8))
        self.assertAllCloseAccordingToType(
            [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval())
        grads_scaled = (0.5 * 0.01 / math.sqrt(
            decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8))
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval())

  def testGradWrtRef(self):
    """`compute_gradients` of `v0 + v1` yields a gradient of 1 per variable."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        opt = SGLDOptimizer(3.0)
        values = [1.0, 3.0]
        vars_ = [variables.Variable([v], dtype=dtype) for v in values]
        grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_)
        variables.global_variables_initializer().run()
        for grad, _ in grads_and_vars:
          self.assertAllCloseAccordingToType([1.0], grad.eval())

  def testWithGlobalStep(self):
    """Passing `global_step` to `apply_gradients` increments it to 1."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        global_step = variables.Variable(0, trainable=False)
        var0 = variables.Variable([1.1, 2.1], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        decay_rate = 0.1
        sgd_op = SGLDOptimizer(
            3.0, preconditioner_decay_rate=decay_rate).apply_gradients(
                zip([grads0, grads1], [var0, var1]), global_step=global_step)
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval())
        self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
        # Run 1 step of sgd
        sgd_op.run()

        # Validate updated params and global_step
        grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate +
                                              (1 - decay_rate) * 0.1**2 + 1e-8))
        self.assertAllCloseAccordingToType(
            [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval())
        grads_scaled = (0.5 * 0.01 / math.sqrt(
            decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8))
        self.assertAllCloseAccordingToType(
            [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval())
        self.assertAllCloseAccordingToType(1, global_step.eval())

  def testSparseBasic(self):
    """`IndexedSlices` gradients only change the rows they index."""
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = variables.Variable([[1.1], [2.1]], dtype=dtype)
        var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
        # Sparse gradient for var0 touches row 0 only; dense shape is [2, 1].
        grads0 = ops.IndexedSlices(
            constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
            constant_op.constant([0]), constant_op.constant([2, 1]))
        # Sparse gradient for var1 touches row 1 only; dense shape is [2, 1].
        grads1 = ops.IndexedSlices(
            constant_op.constant([0.01], shape=[1, 1], dtype=dtype),
            constant_op.constant([1]), constant_op.constant([2, 1]))
        decay_rate = 0.9
        sgd_op = SGLDOptimizer(
            3.0, preconditioner_decay_rate=decay_rate).apply_gradients(
                zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval())
        self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval())
        # Run 1 step of sgd
        sgd_op.run()
        # Validate updated params: row 1 of var0 must be left untouched.
        grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate +
                                              (1 - decay_rate) * 0.1**2 + 1e-8))
        self.assertAllCloseAccordingToType([[1.1 - 3.0 * grads_scaled], [2.1]],
                                           var0.eval())
        grads_scaled = (0.5 * 0.01 / math.sqrt(
            decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8))
        # Row 0 of var1 is expected unchanged (written as a zero-sized step).
        self.assertAllCloseAccordingToType(
            [[3.0 - 3.0 * 0], [4.0 - 3.0 * grads_scaled]], var1.eval())
+ +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector - +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.bayesflow.python.ops.halton_sequence_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented -__all__ = [ - "Sigmoid", +_allowed_symbols = [ + 'sample', ] - -class Sigmoid(bijector.Bijector): - """Bijector which computes `Y = g(X) = 1 / (1 + exp(-X))`.""" - - def __init__(self, validate_args=False, name="sigmoid"): - super(Sigmoid, self).__init__( - event_ndims=0, validate_args=validate_args, name=name) - - def _forward(self, x): - return math_ops.sigmoid(x) - - def _inverse(self, y): - return math_ops.log(y) - math_ops.log1p(-y) - - def _inverse_log_det_jacobian(self, y): - return -math_ops.log(y) - math_ops.log1p(-y) - - def _forward_log_det_jacobian(self, x): - return -nn_ops.softplus(-x) - nn_ops.softplus(x) +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py b/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..8cabf18903b5f15002470acdfb8fdd3ec31a7413 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py @@ -0,0 +1,264 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None):
  r"""Returns a sample from the `dim` dimensional Halton sequence.

  Warning: The sequence elements take values only between 0 and 1. Care must be
  taken to appropriately transform the domain of a function if it differs from
  the unit cube before evaluating integrals using Halton samples. It is also
  important to remember that quasi-random numbers are not a replacement for
  pseudo-random numbers in every context. Quasi random numbers are completely
  deterministic and typically have significant negative autocorrelation (unless
  randomized).

  Computes the members of the low discrepancy Halton sequence in dimension
  `dim`. The `dim`-dimensional sequence takes values in the unit hypercube in
  `dim` dimensions. Currently, only dimensions up to 1000 are supported. The
  prime base for the `k`-th axis is the `k`-th prime starting from 2. For
  example, if dim = 3, then the bases will be [2, 3, 5] respectively and the
  first element of the sequence will be: [0.5, 0.333, 0.2]. For a more complete
  description of the Halton sequences see:
  https://en.wikipedia.org/wiki/Halton_sequence. For low discrepancy sequences
  and their applications see:
  https://en.wikipedia.org/wiki/Low-discrepancy_sequence.

  The user must supply either `num_samples` or `sample_indices` but not both.
  The former is the number of samples to produce starting from the first
  element. If `sample_indices` is given instead, the specified elements of
  the sequence are generated. For example, sample_indices=tf.range(10) is
  equivalent to specifying num_samples=10.

  Example Use:

  ```python
  bf = tf.contrib.bayesflow

  # Produce the first 1000 members of the Halton sequence in 3 dimensions.
  num_samples = 1000
  dim = 3
  sample = bf.halton_sequence.sample(dim, num_samples=num_samples)

  # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional
  # hypercube.
  powers = tf.range(1.0, limit=dim + 1)
  integral = tf.reduce_mean(tf.reduce_prod(sample ** powers, axis=-1))
  true_value = 1.0 / tf.reduce_prod(powers + 1.0)
  with tf.Session() as session:
    values = session.run((integral, true_value))

  # Produces a relative absolute error of 1.7%.
  print("Estimated: %f, True Value: %f" % values)

  # Now skip the first 1000 samples and recompute the integral with the next
  # thousand samples. The sample_indices argument can be used to do this.

  sample_indices = tf.range(start=1000, limit=1000 + num_samples,
                            dtype=tf.int32)
  sample_leaped = bf.halton_sequence.sample(dim, sample_indices=sample_indices)

  integral_leaped = tf.reduce_mean(tf.reduce_prod(sample_leaped ** powers,
                                                  axis=-1))
  with tf.Session() as session:
    values = session.run((integral_leaped, true_value))
  # Now produces a relative absolute error of 0.05%.
  print("Leaped Estimated: %f, True Value: %f" % values)
  ```

  Args:
    dim: Positive Python `int` representing each sample's `event_size.` Must
      not be greater than 1000.
    num_samples: (Optional) positive Python `int`. The number of samples to
      generate. Either this parameter or sample_indices must be specified but
      not both. If this parameter is None, then the behaviour is determined by
      the `sample_indices`.
    sample_indices: (Optional) `Tensor` of dtype int32 and rank 1. The elements
      of the sequence to compute specified by their position in the sequence.
      The entries index into the Halton sequence starting with 0 and hence,
      must be whole numbers. For example, sample_indices=[0, 5, 6] will produce
      the first, sixth and seventh elements of the sequence. If this parameter
      is None, then the `num_samples` parameter must be specified which gives
      the number of desired samples starting from the first sample.
    dtype: (Optional) The dtype of the sample. One of `float32` or `float64`,
      given as a TensorFlow `DType` or anything `dtypes.as_dtype` accepts
      (e.g. a NumPy dtype). Default is `float32`.
    name: (Optional) Python `str` describing ops managed by this function. If
      not supplied the name of this function is used.

  Returns:
    halton_elements: Elements of the Halton sequence. `Tensor` of supplied dtype
      and `shape` `[num_samples, dim]` if `num_samples` was specified or shape
      `[s, dim]` where s is the size of `sample_indices` if `sample_indices`
      were specified.

  Raises:
    ValueError: if both (or neither of) `sample_indices` and `num_samples`
      were specified, if dimension `dim` is less than 1 or greater than 1000,
      or if `dtype` is not a floating point type.
  """
  if dim < 1 or dim > _MAX_DIMENSION:
    raise ValueError(
        'Dimension must be between 1 and {}. Supplied {}'.format(_MAX_DIMENSION,
                                                                 dim))
  if (num_samples is None) == (sample_indices is None):
    raise ValueError('Either `num_samples` or `sample_indices` must be'
                     ' specified but not both.')

  # Normalise to a TensorFlow DType so NumPy dtypes and dtype names are
  # accepted too, instead of failing on the `is_floating` attribute lookup.
  dtype = dtypes.as_dtype(dtype or dtypes.float32)
  if not dtype.is_floating:
    raise ValueError('dtype must be of `float`-type')

  with ops.name_scope(name, 'sample', values=[sample_indices]):
    # Here and in the following, the shape layout is as follows:
    # [sample dimension, event dimension, coefficient dimension].
    # The coefficient dimension is an intermediate axis which will hold the
    # weights of the starting integer when expressed in the (prime) base for
    # an event dimension.
    indices = _get_indices(num_samples, sample_indices, dtype)
    radixes = array_ops.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1])

    # Number of digits the largest requested index has in each base.
    max_sizes_by_axes = _base_expansion_size(math_ops.reduce_max(indices),
                                             radixes)

    max_size = math_ops.reduce_max(max_sizes_by_axes)

    # The powers of the radixes that we will need. Note that there is a bit
    # of an excess here. Suppose we need the place value coefficients of 7
    # in base 2 and 3. For 2, we will have 3 digits but we only need 2 digits
    # for base 3. However, we can only create rectangular tensors so we
    # store both expansions in a [2, 3] tensor. This leads to the problem that
    # we might end up attempting to raise large numbers to large powers. For
    # example, base 2 expansion of 1024 has 10 digits. If we were in 10
    # dimensions, then the 10th prime (29) we will end up computing 29^10 even
    # though we don't need it. We avoid this by setting the exponents for each
    # axes to 0 beyond the maximum value needed for that dimension.
    exponents_by_axes = array_ops.tile([math_ops.range(max_size)], [dim, 1])
    # Exponents are 0-based, so a base whose expansion has `s` digits only
    # needs exponents 0..s-1; everything from `s` upwards is masked (`>=`).
    weight_mask = exponents_by_axes >= max_sizes_by_axes
    capped_exponents = array_ops.where(
        weight_mask, array_ops.zeros_like(exponents_by_axes), exponents_by_axes)
    weights = radixes ** capped_exponents
    # floor(index / b^j) mod b is the j-th digit of `index` in base b; masked
    # positions are zeroed before the modulus so they contribute nothing.
    coeffs = math_ops.floor_div(indices, weights)
    coeffs *= 1 - math_ops.cast(weight_mask, dtype)
    coeffs = (coeffs % radixes) / radixes
    # Radical inverse: sum_j digit_j / b^(j+1), reduced over the digit axis.
    return math_ops.reduce_sum(coeffs / weights, axis=-1)
+ indices = sample_indices + 1 + + # Reshape to make space for the event dimension and the place value + # coefficients. + return array_ops.reshape(indices, [-1, 1, 1]) + + +def _base_expansion_size(num, bases): + """Computes the number of terms in the place value expansion. + + Let num = a0 + a1 b + a2 b^2 + ... ak b^k be the place value expansion of + `num` in base b (ak <> 0). This function computes and returns `k` for each + base `b` specified in `bases`. + + This can be inferred from the base `b` logarithm of `num` as follows: + $$k = Floor(log_b (num)) + 1 = Floor( log(num) / log(b)) + 1$$ + + Args: + num: Scalar `Tensor` of dtype either `float32` or `float64`. The number to + compute the base expansion size of. + bases: `Tensor` of the same dtype as num. The bases to compute the size + against. + + Returns: + Tensor of same dtype and shape as `bases` containing the size of num when + written in that base. + """ + return math_ops.floor(math_ops.log(num) / math_ops.log(bases)) + 1 + + +def _primes_less_than(n): + # Based on + # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188 + """Returns sorted array of primes such that `2 <= prime < n`.""" + small_primes = np.array((2, 3, 5)) + if n <= 6: + return small_primes[small_primes < n] + sieve = np.ones(n // 3 + (n % 6 == 2), dtype=np.bool) + sieve[0] = False + m = int(n ** 0.5) // 3 + 1 + for i in range(m): + if not sieve[i]: + continue + k = 3 * i + 1 | 1 + sieve[k ** 2 // 3::2 * k] = False + sieve[(k ** 2 + 4 * k - 2 * k * (i & 1)) // 3::2 * k] = False + return np.r_[2, 3, 3 * np.nonzero(sieve)[0] + 1 | 1] + +_PRIMES = _primes_less_than(7919+1) + +assert len(_PRIMES) == _MAX_DIMENSION diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py index 333dce929530adceb30dcb63653a5bd009c059e0..5685a942e98800a39ec718adc67bcfd43aeafd52 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py 
+++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py @@ -27,6 +27,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -174,9 +175,11 @@ def chain(n_iterations, step_size, n_leapfrog_steps, initial_x, potential_and_grad = _make_potential_and_grad(target_log_prob_fn) potential, grad = potential_and_grad(initial_x) - return functional_ops.scan(body, array_ops.zeros(n_iterations), - (initial_x, array_ops.zeros(non_event_shape), - -potential, -grad))[:2] + return functional_ops.scan( + body, array_ops.zeros(n_iterations, dtype=initial_x.dtype), + (initial_x, + array_ops.zeros(non_event_shape, dtype=initial_x.dtype), + -potential, -grad))[:2] def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, @@ -298,8 +301,9 @@ def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, return updated_x, acceptance_probs, w x, acceptance_probs, w = functional_ops.scan( - _body, beta_series, (initial_x, array_ops.zeros(non_event_shape), - array_ops.zeros(non_event_shape))) + _body, beta_series, + (initial_x, array_ops.zeros(non_event_shape, dtype=initial_x.dtype), + array_ops.zeros(non_event_shape, dtype=initial_x.dtype))) return w[-1], x[-1], acceptance_probs[-1] @@ -446,9 +450,10 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), """ with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]): potential_and_grad = _make_potential_and_grad(target_log_prob_fn) + x = ops.convert_to_tensor(x, name='x') x_shape = array_ops.shape(x) - m = random_ops.random_normal(x_shape) + m = random_ops.random_normal(x_shape, dtype=x.dtype) kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims) @@ -468,26 +473,33 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), kinetic_1 = 0.5 * 
math_ops.reduce_sum(math_ops.square(new_m), event_dims) - # TODO(mhoffman): It seems like there may be an opportunity for nans here. - # I'm delaying addressing this because we're going to refactor this part - # to use the more general Metropolis abstraction anyway. - acceptance_probs = math_ops.exp(math_ops.minimum(0., log_potential_0 - - log_potential_1 + - kinetic_0 - kinetic_1)) - accepted = math_ops.cast( - random_ops.random_uniform(array_ops.shape(acceptance_probs)) < - acceptance_probs, np.float32) - new_log_prob = (-log_potential_0 * (1. - accepted) - - log_potential_1 * accepted) + energy_change = log_potential_1 - log_potential_0 + kinetic_1 - kinetic_0 + # Treat NaN as infinite energy (and therefore guaranteed rejection). + energy_change = array_ops.where( + math_ops.is_nan(energy_change), + array_ops.fill(array_ops.shape(energy_change), + energy_change.dtype.as_numpy_dtype(np.inf)), + energy_change) + acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.)) + accepted = ( + random_ops.random_uniform( + array_ops.shape(acceptance_probs), dtype=x.dtype) + < acceptance_probs) + new_log_prob = -array_ops.where(accepted, log_potential_1, log_potential_0) # TODO(b/65738010): This should work, but it doesn't for now. # reduced_shape = math_ops.reduced_shape(x_shape, event_dims) reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims, keep_dims=True)) accepted = array_ops.reshape(accepted, reduced_shape) - new_x = x * (1. - accepted) + new_x * accepted - new_grad = -grad_0 * (1. - accepted) - grad_1 * accepted - + accepted = math_ops.logical_or( + accepted, math_ops.cast(array_ops.zeros_like(x), dtypes.bool)) + new_x = array_ops.where(accepted, new_x, x) + new_grad = -array_ops.where(accepted, grad_1, grad_0) + + # TODO(langmore) Gradients of acceptance_probs and new_log_prob with respect + # to initial_x will propagate NaNs (see testNanFromGradsDontPropagate). This + # should be fixed. 
return new_x, acceptance_probs, new_log_prob, new_grad @@ -525,6 +537,7 @@ def leapfrog_integrator(step_size, n_steps, initial_position, initial_momentum, Has shape matching `initial_position`. Example: Simple quadratic potential. + ```python def potential_and_grad(position): return tf.reduce_sum(0.5 * tf.square(position)), position @@ -600,6 +613,7 @@ def leapfrog_step(step_size, position, momentum, potential_and_grad, grad, Has shape matching `position`. Example: Simple quadratic potential. + ```python def potential_and_grad(position): # Simple quadratic potential diff --git a/tensorflow/contrib/bayesflow/python/ops/layers.py b/tensorflow/contrib/bayesflow/python/ops/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..dcead38af826a12e776160bdb251ba021e6b953c --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers.py @@ -0,0 +1,37 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Probabilistic neural layers. + +See ${python/contrib.bayesflow.layers}. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'DenseVariational', + 'dense_variational', + 'default_loc_scale_fn', + 'default_mean_field_normal_fn', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..b05ce0ffc1dd55ffb029b339a846a9aa5c877620 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py @@ -0,0 +1,797 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dense Bayesian layer using KL-divergence based variational inference. 
+ +@@DenseVariational +@@dense_variational + +@@default_loc_scale_fn +@@default_mean_field_normal_fn +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as layers_lib +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops.distributions import kullback_leibler as kl_lib +from tensorflow.python.ops.distributions import normal as normal_lib + + +__all__ = [ + "DenseVariational", + "dense_variational", + "default_loc_scale_fn", + "default_mean_field_normal_fn", +] + + +def default_loc_scale_fn( + is_singular=False, + loc_initializer=init_ops.random_normal_initializer(stddev=0.1), + untransformed_scale_initializer=init_ops.random_normal_initializer( + mean=-3., stddev=0.1), + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. + + This function produces a closure which produces `loc`, `scale` using + `tf.get_variable`. The closure accepts the following arguments: + + dtype: Type of parameter's event. + shape: Python `list`-like representing the parameter's event shape. + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. 
+ + Args: + is_singular: Python `bool` indicating if `scale is None`. Default: `False`. + loc_initializer: Initializer function for the `loc` parameters. + The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. Default value: `tf.random_normal_initializer(mean=-3., + stddev=0.1)`. This implies the softplus transformed result has mean + approximately `0.05` and std. deviation approximately `0.005`. + loc_regularizer: Regularizer function for the `loc` parameters. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. The default (`None`) is to use the `tf.get_variable` default. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. The default + (`None`) is to use the `tf.get_variable` default. + + Returns: + default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` + parameters from args: `dtype, shape, name, trainable, add_variable_fn`. 
+ """ + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates `loc`, `scale` parameters.""" + loc = add_variable_fn( + name=name + "_loc", + shape=shape, + initializer=loc_initializer, + regularizer=loc_regularizer, + constraint=loc_constraint, + dtype=dtype, + trainable=trainable) + if is_singular: + return loc, None + untransformed_scale = add_variable_fn( + name=name + "_untransformed_scale", + shape=shape, + initializer=untransformed_scale_initializer, + regularizer=untransformed_scale_regularizer, + constraint=untransformed_scale_constraint, + dtype=dtype, + trainable=trainable) + scale = (np.finfo(dtype.as_numpy_dtype).eps + + nn_ops.softplus(untransformed_scale)) + return loc, scale + return _fn + + +def default_mean_field_normal_fn( + is_singular=False, + loc_initializer=None, + untransformed_scale_initializer=None, + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Creates a function to build Normal distributions with trainable params. + + This function produces a closure which produces `tf.distributions.Normal` + parameterized by a `loc` and `scale` each created using `tf.get_variable`. The + produced closure accepts the following arguments: + + dtype: Type of parameter's event. + shape: Python `list`-like representing the parameter's event shape. + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Args: + is_singular: Python `bool` if `True`, forces the special case limit of + `scale->0`, i.e., a `Deterministic` distribution. + loc_initializer: Initializer function for the `loc` parameters. 
+ If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + loc_regularizer: Regularizer function for the `loc` parameters. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. + + Returns: + make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` + using from args: `dtype, shape, name, trainable, add_variable_fn`. 
+ """ + loc_scale_fn_ = default_loc_scale_fn( + is_singular, + loc_initializer, + untransformed_scale_initializer, + loc_regularizer, + untransformed_scale_regularizer, + loc_constraint, + untransformed_scale_constraint) + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates a batch of `Deterministic` or `Normal` distributions.""" + loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) + if scale is None: + return deterministic_lib.Deterministic(loc=loc) + return normal_lib.Normal(loc=loc, scale=scale) + return _fn + + +class DenseVariational(layers_lib.Layer): + """Densely-connected variational class. + + This layer implements the Bayesian variational inference analogue to: + `outputs = activation(matmul(inputs, kernel) + bias)` + by assuming the `kernel` and/or the `bias` are random variables. + + The layer implements a stochastic dense calculation by making a Monte Carlo + approximation of a [variational Bayesian method based on KL divergence]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., + + ```none + -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw + = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw + <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's + = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] + ~= m**-1 sum{ -log p(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } + + KL[q(W|x), p(W)] + ``` + + where `W` denotes the (independent) `kernel` and `bias` random variables, `w` + is a random variate or outcome of `W`, `y` is the label, `x` is the evidence, + and `~=` denotes an approximation which becomes exact as `m->inf`. The above + bound is sometimes referred to as the negative Evidence Lower BOund or + negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this + layer is appropriate to use when the final loss is a negative log-likelihood. + + The Monte-Carlo sum portion is used for the feed-forward calculation of the + DNN. 
The KL divergence portion can be added to the final loss via: + `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + random variables (which together comprise `W`). + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + When `True`, `kernel_posterior_fn` must create an instance of + `tf.distributions.Normal`. + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. 
+ bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + kernel: `VariationalKernelParameter` instance containing all `kernel` + related properties and `callable`s. + bias: `VariationalParameter` instance containing all `bias` + related properties and `callable`s. 
+ """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_use_local_reparameterization=True, + kernel_posterior_fn=default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(DenseVariational, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) + self._units = units + self._activation = activation + self._input_spec = layers_lib.InputSpec(min_ndim=2) + self._kernel_use_local_reparameterization = ( + kernel_use_local_reparameterization) + self._kernel = VariationalKernelParameter( + kernel_posterior_fn, + kernel_posterior_tensor_fn, + kernel_prior_fn, + kernel_divergence_fn) + self._bias = VariationalParameter( + bias_posterior_fn, + bias_posterior_tensor_fn, + bias_prior_fn, + bias_divergence_fn) + + @property + def units(self): + return self._units + + @property + def activation(self): + return self._activation + + @property + def input_spec(self): + return self._input_spec + + @input_spec.setter + def input_spec(self, value): + self._input_spec = value + + @property + def kernel_use_local_reparameterization(self): + return self._kernel_use_local_reparameterization + + @property + def kernel(self): + return self._kernel + + @property + def bias(self): + return self._bias + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + in_size = input_shape.with_rank_at_least(2)[-1].value + if in_size is None: + raise ValueError("The last 
dimension of the inputs to `Dense` " + "should be defined. Found `None`.") + self._input_spec = layers_lib.InputSpec(min_ndim=2, axes={-1: in_size}) + dtype = dtypes.as_dtype(self.dtype) + + # Must have a posterior kernel. + self.kernel.posterior = self.kernel.posterior_fn( + dtype, [in_size, self.units], "kernel_posterior", + self.trainable, self.add_variable) + + if self.kernel.prior_fn is None: + self.kernel_prior = None + else: + self.kernel.prior = self.kernel.prior_fn( + dtype, [in_size, self.units], "kernel_prior", + self.trainable, self.add_variable) + self._built_kernel_divergence = False + + if self.bias.posterior_fn is None: + self.bias.posterior = None + else: + self.bias.posterior = self.bias.posterior_fn( + dtype, [self.units], "bias_posterior", + self.trainable, self.add_variable) + + if self.bias.prior_fn is None: + self.bias.prior = None + else: + self.bias.prior = self.bias.prior_fn( + dtype, [self.units], "bias_prior", + self.trainable, self.add_variable) + self._built_bias_divergence = False + + self.built = True + + def call(self, inputs): + inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) + + outputs = self._apply_variational_kernel(inputs) + outputs = self._apply_variational_bias(outputs) + if self.activation is not None: + outputs = self.activation(outputs) # pylint: disable=not-callable + if not self._built_kernel_divergence: + self._apply_divergence(self.kernel, name="divergence_kernel") + self._built_kernel_divergence = True + if not self._built_bias_divergence: + self._apply_divergence(self.bias, name="divergence_bias") + self._built_bias_divergence = True + return outputs + + def _apply_variational_kernel(self, inputs): + if not self.kernel_use_local_reparameterization: + self.kernel.posterior_tensor = self.kernel.posterior_tensor_fn( + self.kernel.posterior) + self.kernel.posterior_affine = None + self.kernel.posterior_affine_tensor = None + return self._matmul(inputs, self.kernel.posterior_tensor) + if not 
isinstance(self.kernel.posterior, normal_lib.Normal): + raise TypeError("`kernel_use_local_reparameterization=True` requires " + "`kernel_posterior_fn` produce an instance of " + "`tf.distributions.Normal` (saw: \"{}\").".format( + type(self.kernel.posterior).__name__)) + self.kernel.posterior_affine = normal_lib.Normal( + loc=self._matmul(inputs, self.kernel.posterior.loc), + scale=standard_ops.sqrt(self._matmul( + standard_ops.square(inputs), + standard_ops.square(self.kernel.posterior.scale)))) + self.kernel.posterior_affine_tensor = ( + self.kernel.posterior_tensor_fn(self.kernel.posterior_affine)) + self.kernel.posterior_tensor = None + return self.kernel.posterior_affine_tensor + + def _apply_variational_bias(self, inputs): + if self.bias.posterior is None: + self.bias.posterior_tensor = None + return inputs + self.bias.posterior_tensor = self.bias.posterior_tensor_fn( + self.bias.posterior) + return nn.bias_add(inputs, self.bias.posterior_tensor) + + def _apply_divergence(self, param, name): + if (param.divergence_fn is None or + param.posterior is None or + param.prior is None): + param.divergence = None + return + param.divergence = standard_ops.identity( + param.divergence_fn( + param.posterior, param.prior, param.posterior_tensor), + name=name) + self.add_loss(param.divergence) + + def _matmul(self, inputs, kernel): + if inputs.shape.ndims <= 2: + return standard_ops.matmul(inputs, kernel) + # To handle broadcasting, we must use `tensordot`. 
+ return standard_ops.tensordot(inputs, kernel, axes=[[-1], [0]]) + + def _compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2) + if input_shape[-1].value is None: + raise ValueError( + "The innermost dimension of input_shape must be defined, " + "but saw: {}".format(input_shape)) + return input_shape[:-1].concatenate(self.units) + + +def dense_variational( + inputs, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_use_local_reparameterization=True, + kernel_posterior_fn=default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Densely-connected variational layer. + + This layer implements the Bayesian variational inference analogue to: + `outputs = activation(matmul(inputs, kernel) + bias)` + by assuming the `kernel` and/or the `bias` are random variables. 
+ + The layer implements a stochastic dense calculation by making a Monte Carlo + approximation of a [variational Bayesian method based on KL divergence]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., + + ```none + -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw + = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw + <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's + = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] + ~= m**-1 sum{ -log p(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } + + KL[q(W|x), p(W)] + ``` + + where `W` denotes the (independent) `kernel` and `bias` random variables, `w` + is a random variate or outcome of `W`, `y` is the label, `x` is the evidence, + and `~=` denotes an approximation which becomes exact as `m->inf`. The above + bound is sometimes referred to as the negative Evidence Lower BOund or + negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this + layer is appropriate to use when the final loss is a negative log-likelihood. + + The Monte-Carlo sum portion is used for the feed-forward calculation of the + DNN. The KL divergence portion can be added to the final loss via: + `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + random variables (which together comprise `W`). + + Args: + inputs: Tensor input. + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. 
+ When `True`, `kernel_posterior_fn` must create an instance of + `tf.distributions.Normal`. + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. 
The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Returns: + output: `Tensor` representing the affine transformed input under a random + draw from the surrogate posterior distribution. + """ + layer = DenseVariational( + units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_use_local_reparameterization=( + kernel_use_local_reparameterization), + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class NotSet(object): + """Helper to track whether a `VariationalParameter` value has been set.""" + pass + + +class VariationalParameter(object): + """Struct-like container of variational parameter properties. + + A `VariationalParameter` is initialized with Python `callable`s which set the + value of correspondingly named members. Corresponding values have "set once" + semantics, i.e., once set to any value they are immutable. + """ + + def __init__( + self, + posterior_fn, + posterior_tensor_fn, + prior_fn, + divergence_fn): + """Creates the `VariationalParameter` struct-like object. + + Args: + posterior_fn: Python `callable` which creates a + `tf.distribution.Distribution` like object representing the posterior + distribution. 
See `VariationalParameter.posterior_fn` for `callable`'s + required parameters. + posterior_tensor_fn: Python `callable` which computes a `Tensor` + which represents the `posterior`. + prior_fn: Python `callable` which creates a + `tf.distributions.Distribution`-like object representing the prior + distribution. See `VariationalParameter.prior_fn` for `callable`'s + required parameters. + divergence_fn: Python `callable` which computes the KL divergence from + `posterior` to `prior`. See `VariationalParameter.divergence_fn` for + required `callable`'s parameters. + """ + self._posterior_fn = posterior_fn + self._posterior = NotSet() + self._posterior_tensor_fn = posterior_tensor_fn + self._posterior_tensor = NotSet() + self._prior_fn = prior_fn + self._prior = NotSet() + self._divergence_fn = divergence_fn + self._divergence = NotSet() + self._init_helper() + + @property + def posterior_fn(self): + """`callable` which creates `tf.distributions.Distribution`-like posterior. + + The `callable` must accept the following parameters: + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Returns: + posterior_fn: The Python `callable` specified in `__init__`.
+ """ + return self._posterior_fn + + @property + def posterior(self): + """`tf.distributions.Distribution`-like instance representing posterior.""" + return self._posterior + + @posterior.setter + def posterior(self, value): + """One-time setter of the `posterior` distribution.""" + if not isinstance(self._posterior, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior = value + + @property + def posterior_tensor_fn(self): + """Creates `Tensor` representing the `posterior` distribution. + + The `callable` must accept the following parameters: + posterior: `tf.distributions.Distribution`-like instance. + + Returns: + posterior_tensor_fn: The Python `callable` specified in + `__init__`. + """ + return self._posterior_tensor_fn + + @property + def posterior_tensor(self): + """`Tensor` representing the `posterior` distribution.""" + return self._posterior_tensor + + @posterior_tensor.setter + def posterior_tensor(self, value): + """One-time setter of the `posterior_tensor`.""" + if not isinstance(self._posterior_tensor, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior_tensor = value + + @property + def prior_fn(self): + """`callable` which creates `tf.distributions.Distribution`-like prior. + + The `callable` must accept the following parameters: + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Returns: + prior_fn: The Python `callable` specified in `__init__`. 
+ """ + return self._prior_fn + + @property + def prior(self): + """`tf.distributions.Distribution`-like instance representing prior.""" + return self._prior + + @prior.setter + def prior(self, value): + """One-time setter of the `prior` distribution.""" + if not isinstance(self._prior, NotSet): + raise ValueError("Cannot override already set attribute.") + self._prior = value + + @property + def divergence_fn(self): + """`callable` which computes KL-divergence `Tensor` from posterior to prior. + + The `callable` must accept the following parameters: + posterior: `tf.distributions.Distribution`-like instance. + prior: `tf.distributions.Distribution`-like instance. + posterior_tensor: `Tensor` representing value of posterior. + + Returns: + divergence_fn: The Python `callable` specified in `__init__`. + """ + return self._divergence_fn + + @property + def divergence(self): + """`Tensor` representing KL-divergence from posterior to prior.""" + return self._divergence + + @divergence.setter + def divergence(self, value): + """One-time setter of the `divergence`.""" + if not isinstance(self._divergence, NotSet): + raise ValueError("Cannot override already set attribute.") + self._divergence = value + + def _init_helper(self): + pass + + +class VariationalKernelParameter(VariationalParameter): + """Struct-like container of variational kernel properties. + + A `VariationalKernelParameter` is initialized with Python `callable`s which + set the value of correspondingly named members. Corresponding values have "set + once" semantics, i.e., once set to any value they are immutable.
+ """ + + @property + def posterior_affine(self): + """`tf.distributions.Distribution` affine transformed posterior.""" + return self._posterior_affine + + @posterior_affine.setter + def posterior_affine(self, value): + """One-time setter of `posterior_affine`.""" + if not isinstance(self._posterior_affine, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior_affine = value + + @property + def posterior_affine_tensor(self): + """`Tensor` representing the `posterior_affine` distribution.""" + return self._posterior_affine_tensor + + @posterior_affine_tensor.setter + def posterior_affine_tensor(self, value): + """One-time setter of the `posterior_affine_tensor`.""" + if not isinstance(self._posterior_affine_tensor, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior_affine_tensor = value + + def _init_helper(self): + self._posterior_affine = NotSet() + self._posterior_affine_tensor = NotSet() diff --git a/tensorflow/contrib/bayesflow/python/ops/optimizers.py b/tensorflow/contrib/bayesflow/python/ops/optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..ee32e6b5c3d9efaeaf73436638c5eea55f2cfc70 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/optimizers.py @@ -0,0 +1,34 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Probabilistic optimizer modules. + +See ${python/contrib.bayesflow.optimizers}. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.bayesflow.python.ops.sgld_optimizer import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'SGLDOptimizer', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5d36ea7a2b51aa45cdc253992a2a58634c068987 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py @@ -0,0 +1,216 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""An optimizer module for stochastic gradient Langevin dynamics.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope as varscope_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops + + +class SGLDOptimizer(optimizer.Optimizer): + """An optimizer module for stochastic gradient Langevin dynamics. + + This implements the preconditioned Stochastic Gradient Langevin Dynamics + optimizer [1]. The optimization variable is regarded as a sample from the + posterior under Stochastic Gradient Langevin Dynamics with noise rescaled in + each dimension according to RMSProp [2]. + + Note: If a prior is included in the loss, it should be scaled by + `1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches + in the data. I.e., it should be divided by the `num_pseudo_batches` term + described below. + + [1]: "Preconditioned Stochastic Gradient Langevin Dynamics for Deep Neural + Networks." Chunyuan Li, Changyou Chen, David Carlson, Lawrence Carin. + ArXiv:1512.07666, 2015. https://arxiv.org/abs/1512.07666 + [2]: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf + + Args: + learning_rate: Scalar `float`-like `Tensor`. The base learning rate for the + optimizer. Must be tuned to the specific function being minimized. + preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential + decay rate of the rescaling of the preconditioner (RMSprop). (This is + "alpha" in [1]). 
Should be smaller than but nearly `1` to approximate + sampling from the posterior. (Default: `0.95`) + num_pseudo_batches: Scalar `int`-like `Tensor`. The effective number of + minibatches in the data set. Trades off noise and prior with the SGD + likelihood term. Note: Assumes the loss is taken as the mean over a + minibatch. Otherwise if the sum was taken, divide this number by the + batch size. (Default: `1`) + burnin: Scalar `int`-like `Tensor`. The number of iterations to collect + gradient statistics to update the preconditioner before starting to draw + noisy samples. (Default: `25`) + diagonal_bias: Scalar `float`-like `Tensor`. Term added to the diagonal of + the preconditioner to prevent the preconditioner from degenerating. + (Default: `1e-8`) + name: Python `str` describing ops managed by this function. + (Default: `"SGLDOptimizer"`) + variable_scope: Variable scope used for calls to `tf.get_variable`. + If `None`, a new variable scope is created using name + `ops.get_default_graph().unique_name(name or default_name)`. + + Raises: + InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in + `[0, 1]`.
+ """ + + def __init__(self, + learning_rate, + preconditioner_decay_rate=0.95, + num_pseudo_batches=1, + burnin=25, + diagonal_bias=1e-8, + name=None, + variable_scope=None): + default_name = 'SGLDOptimizer' + with ops.name_scope(name, default_name, [ + learning_rate, preconditioner_decay_rate, num_pseudo_batches, burnin, + diagonal_bias + ]): + if variable_scope is None: + var_scope_name = ops.get_default_graph().unique_name( + name or default_name) + with varscope_ops.variable_scope(var_scope_name) as scope: + self._variable_scope = scope + else: + self._variable_scope = variable_scope + + self._preconditioner_decay_rate = ops.convert_to_tensor( + preconditioner_decay_rate, name='preconditioner_decay_rate') + self._num_pseudo_batches = ops.convert_to_tensor( + num_pseudo_batches, name='num_pseudo_batches') + self._burnin = ops.convert_to_tensor(burnin, name='burnin') + self._diagonal_bias = ops.convert_to_tensor( + diagonal_bias, name='diagonal_bias') + self._learning_rate = ops.convert_to_tensor( + learning_rate, name='learning_rate') + + with varscope_ops.variable_scope(self._variable_scope): + self._counter = varscope_ops.get_variable( + 'counter', initializer=0, trainable=False) + + self._preconditioner_decay_rate = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._preconditioner_decay_rate, + message='`preconditioner_decay_rate` must be non-negative'), + check_ops.assert_less_equal( + self._preconditioner_decay_rate, + 1., + message='`preconditioner_decay_rate` must be at most 1.'), + ], self._preconditioner_decay_rate) + + self._num_pseudo_batches = control_flow_ops.with_dependencies([ + check_ops.assert_greater( + self._num_pseudo_batches, + 0, + message='`num_pseudo_batches` must be greater than zero') + ], self._num_pseudo_batches) + + self._burnin = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._burnin, message='`burnin` must be non-negative'), + check_ops.assert_integer( + self._burnin, 
message='`burnin` must be an integer') + ], self._burnin) + + self._diagonal_bias = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._diagonal_bias, + message='`diagonal_bias` must be non-negative') + ], self._diagonal_bias) + + super(SGLDOptimizer, self).__init__(use_locking=False, + name=name or default_name) + + def _create_slots(self, var_list): + for v in var_list: + init_rms = init_ops.ones_initializer(dtype=v.dtype) + self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(), + v.dtype, 'rms', self._name) + + def _prepare(self): + # We need to put the conversion and check here because a user will likely + # want to decay the learning rate dynamically. + self._learning_rate_tensor = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._learning_rate, message='`learning_rate` must be non-negative') + ], ops.convert_to_tensor(self._learning_rate, name='learning_rate_tensor')) + self._decay_tensor = ops.convert_to_tensor( + self._preconditioner_decay_rate, name='preconditioner_decay_rate') + + super(SGLDOptimizer, self)._prepare() + + def _apply_dense(self, grad, var): + rms = self.get_slot(var, 'rms') + + with ops.control_dependencies([ + self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, + var.dtype.base_dtype))]): + new_grad = self._apply_noisy_update(rms, grad) + + return training_ops.apply_gradient_descent( + var, + math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), + new_grad, + use_locking=self._use_locking).op + + def _apply_sparse(self, grad, var): + rms = self.get_slot(var, 'rms') + + with ops.control_dependencies([ + self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, + var.dtype.base_dtype))]): + new_grad = self._apply_noisy_update(rms, grad) + + return training_ops.apply_gradient_descent( + var, + math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), + new_grad, + use_locking=self._use_locking).op + + @property + def 
variable_scope(self): + """Variable scope of all calls to `tf.get_variable`.""" + return self._variable_scope + + def _apply_noisy_update(self, mom, grad): + # Compute and apply the gradient update following + # preconditioned Langevin dynamics + stddev = array_ops.where( + array_ops.squeeze(self._counter > self._burnin), + math_ops.cast(math_ops.rsqrt(self._learning_rate), grad.dtype), + array_ops.zeros([], grad.dtype)) + + preconditioner = math_ops.rsqrt( + mom + math_ops.cast(self._diagonal_bias, grad.dtype)) + return ( + 0.5 * preconditioner * grad * math_ops.cast(self._num_pseudo_batches, + grad.dtype) + + random_ops.random_normal(array_ops.shape(grad), 1.0, dtype=grad.dtype) * + stddev * math_ops.sqrt(preconditioner)) + + def _update_momentum(self, mom, grad, decay): + # Keep an exponentially weighted moving average of squared gradients. + # Not thread safe + return mom.assign_add((1.0 - decay) * (math_ops.square(grad) - mom)) diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 66a04d42e93331de74b6f3d41f83f071115c1097..7072f56420ac9e576b20b62c0aa67498857403a7 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -359,8 +359,8 @@ tf_custom_op_library( ], deps = [ "//tensorflow/contrib/boosted_trees/lib:example_partitioner", - "//tensorflow/contrib/boosted_trees/lib:feature-column-handlers", "//tensorflow/contrib/boosted_trees/lib:models", + "//tensorflow/contrib/boosted_trees/lib:node-stats", "//tensorflow/contrib/boosted_trees/lib:utils", "//tensorflow/contrib/boosted_trees/lib:weighted_quantiles", "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", @@ -404,10 +404,12 @@ tf_kernel_library( name = "split_handler_ops_kernels", srcs = ["kernels/split_handler_ops.cc"], deps = [ - "//tensorflow/contrib/boosted_trees/lib:feature-column-handlers", + "//tensorflow/contrib/boosted_trees/lib:node-stats", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc", 
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:protos_all_cc", + "//third_party/eigen3", ], alwayslink = 1, ) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index ef8dee91b6cc05c4c3dd5eb3c81de4fb65b473e3..6ebc7d7911df878ec91701db8b75feb9a27d18a2 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -33,6 +33,8 @@ from tensorflow.python.platform import gfile from tensorflow.python.saved_model import loader as saved_model_loader from tensorflow.python.saved_model import tag_constants +_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d" + def make_custom_export_strategy(name, convert_fn, @@ -147,13 +149,12 @@ def convert_to_universal_format(dtec, sorted_feature_names, inequality_test.threshold.float_value = split.threshold elif node_type == "sparse_float_binary_split_default_left": split = gtflow_node.sparse_float_binary_split_default_left.split - node.default_direction = ( - generic_tree_model_pb2.BinaryNode.LEFT) - # TODO(nponomareva): adjust this id assignement when we allow multi- - # column sparse tensors. + node.default_direction = (generic_tree_model_pb2.BinaryNode.LEFT) feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test - inequality_test.feature_id.id.value = sorted_feature_names[feature_id] + inequality_test.feature_id.id.value = ( + _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % + (sorted_feature_names[feature_id], split.dimension_id)) inequality_test.type = ( generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) inequality_test.threshold.float_value = split.threshold @@ -165,7 +166,9 @@ def convert_to_universal_format(dtec, sorted_feature_names, # column sparse tensors. 
feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test - inequality_test.feature_id.id.value = sorted_feature_names[feature_id] + inequality_test.feature_id.id.value = ( + _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % + (sorted_feature_names[feature_id], split.dimension_id)) inequality_test.type = ( generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) inequality_test.threshold.float_value = split.threshold @@ -201,10 +204,14 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats, split_column = feature_names[split.feature_column] elif node_type == "sparse_float_binary_split_default_left": split = tree_node.sparse_float_binary_split_default_left.split - split_column = feature_names[split.feature_column + num_dense_floats] + split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % ( + feature_names[split.feature_column + num_dense_floats], + split.dimension_id) elif node_type == "sparse_float_binary_split_default_right": split = tree_node.sparse_float_binary_split_default_right.split - split_column = feature_names[split.feature_column + num_dense_floats] + split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % ( + feature_names[split.feature_column + num_dense_floats], + split.dimension_id) elif node_type == "categorical_id_binary_split": split = tree_node.categorical_id_binary_split split_column = feature_names[split.feature_column + num_dense_floats + diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py index 4ed18b2d34c5af47826ab1c058f5d13797593bd4..492d9ca40c5cfa84e186020605429aacc02af6a6 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for the conversion code from GTFlow format to Chauffeur.""" +"""Tests for the conversion code and for feature importances export. + +Tests that cover conversion from TFBT format to a tensorflow.contrib. +decision_tree generic_tree_model format and feature importances export. +""" from __future__ import absolute_import from __future__ import division @@ -95,10 +99,31 @@ class ConvertModelTest(test_util.TensorFlowTestCase): } } } + nodes { + sparse_float_binary_split_default_right { + split { + feature_column: 1 + dimension_id:3 + threshold: -0.4 + left_id: 7 + right_id: 8 + } + } + node_metadata { + gain: 3600 + } + } + nodes { + leaf { + vector { + value: 0.36 + } + } + } nodes { leaf { vector { - value: 0.3 + value: 18 } } } @@ -108,17 +133,25 @@ class ConvertModelTest(test_util.TensorFlowTestCase): """ dtec = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge(dtec_str, dtec) - feature_columns = ["feature_b", "feature_a", "feature_d"] + feature_columns = [ + "feature_b", + "feature_a", + "feature_a_m", + "feature_d", + ] return dtec, feature_columns def testConvertModel(self): dtec, feature_columns = self._make_trees() + # Assume 2 sparse float columns, one with 1 dimension, the second one with + # 5 dimensions. # The feature columns in the order they were added. out = custom_export_strategy.convert_to_universal_format( - dtec, feature_columns, 1, 1, - 1) + dtec, feature_columns, 1, 2, 1) + # Features a and a_m are sparse float features, a_m is multidimensional. 
expected_tree = """ features { key: "feature_a" } + features { key: "feature_a_m" } features { key: "feature_b" } features { key: "feature_d" } model { @@ -169,7 +202,6 @@ class ConvertModelTest(test_util.TensorFlowTestCase): } } } - nodes { node_id { value: 1 @@ -196,7 +228,7 @@ class ConvertModelTest(test_util.TensorFlowTestCase): inequality_left_child_test { feature_id { id { - value: "feature_a" + value: "feature_a_0" } } threshold { @@ -259,14 +291,51 @@ class ConvertModelTest(test_util.TensorFlowTestCase): node_id { value: 6 } + binary_node { + left_child_id { + value: 7 + } + right_child_id { + value: 8 + } + default_direction: RIGHT + inequality_left_child_test { + feature_id { + id { + value: "feature_a_m_3" + } + } + threshold { + float_value: -0.4 + } + } + } + } + nodes { + node_id { + value: 7 + } leaf { vector { value { - float_value: 0.03 + float_value: 0.036 } } } } + nodes { + node_id { + value: 8 + } + leaf { + vector { + value { + float_value: 1.8 + } + } + } + } + } } submodel_id { @@ -280,12 +349,15 @@ class ConvertModelTest(test_util.TensorFlowTestCase): def testFeatureImportance(self): dtec, feature_columns = self._make_trees() feature_importances = custom_export_strategy._get_feature_importances( - dtec, feature_columns, 1, 1, 1) - self.assertItemsEqual(["feature_b", "feature_a", "feature_d"], - feature_importances.keys()) + dtec, feature_columns, 1, 2, 1) + self.assertItemsEqual( + ["feature_b", "feature_a_0", "feature_a_m_3", "feature_d"], + feature_importances.keys()) self.assertAlmostEqual(50.0, feature_importances["feature_b"], places=4) - self.assertAlmostEqual(50.0, feature_importances["feature_a"], places=4) + self.assertAlmostEqual(50.0, feature_importances["feature_a_0"], places=4) self.assertAlmostEqual(50.0, feature_importances["feature_d"], places=4) + self.assertAlmostEqual( + 360.0, feature_importances["feature_a_m_3"], places=4) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/examples/boston.py 
b/tensorflow/contrib/boosted_trees/examples/boston.py index 2c0a3c4912b82aba88e2f8f1b97a227c894ee2ae..e9dbdb0fd784052eeb36ac1aa9342165ef2ac0a7 100644 --- a/tensorflow/contrib/boosted_trees/examples/boston.py +++ b/tensorflow/contrib/boosted_trees/examples/boston.py @@ -22,7 +22,7 @@ r"""Demonstrates a regression on Boston housing data. python tensorflow/contrib/boosted_trees/examples/boston.py \ --batch_size=404 --output_dir="/tmp/boston" --depth=4 --learning_rate=0.1 \ - --num_eval_steps=1 --num_trees=500 --l2=4 \ + --num_eval_steps=1 --num_trees=500 --l2=0.001 \ --vmodule=training_ops=1 When training is done, mean squared error on eval data is reported. @@ -37,8 +37,10 @@ from __future__ import division from __future__ import print_function import argparse +import os import sys import tensorflow as tf +from tensorflow.contrib.boosted_trees.estimator_batch import custom_export_strategy from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeRegressor from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column @@ -51,22 +53,18 @@ _BOSTON_NUM_FEATURES = 13 def _get_tfbt(output_dir, feature_cols): """Configures TF Boosted Trees estimator based on flags.""" learner_config = learner_pb2.LearnerConfig() - learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate learner_config.regularization.l1 = 0.0 - # Set the regularization per instance in such a way that - # regularization for the full training data is equal to l2 flag. - learner_config.regularization.l2 = FLAGS.l2 / FLAGS.batch_size + learner_config.regularization.l2 = FLAGS.l2 learner_config.constraints.max_tree_depth = FLAGS.depth - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300) # Create a TF Boosted trees regression estimator. 
estimator = GradientBoostedDecisionTreeRegressor( learner_config=learner_config, - # For the WHOLE_TREE strategy, set the examples_per_layer to be equal to - # batch size. + # This should be the number of examples. For large datasets it can be + # larger than the batch_size. examples_per_layer=FLAGS.batch_size, feature_columns=feature_cols, label_dimension=1, @@ -77,6 +75,14 @@ def _get_tfbt(output_dir, feature_cols): return estimator +def _convert_fn(dtec, sorted_feature_names, num_dense, num_sparse_float, + num_sparse_int, export_dir, unused_eval_result): + universal_format = custom_export_strategy.convert_to_universal_format( + dtec, sorted_feature_names, num_dense, num_sparse_float, num_sparse_int) + with tf.gfile.GFile(os.path.join(export_dir, "tree_proto"), "w") as f: + f.write(str(universal_format)) + + def _make_experiment_fn(output_dir): """Creates experiment for gradient boosted decision trees.""" (x_train, y_train), (x_test, @@ -88,21 +94,31 @@ def _make_experiment_fn(output_dir): batch_size=FLAGS.batch_size, num_epochs=None, shuffle=True) - eval_input_fn = tf.estimator.inputs.numpy_input_fn( x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) feature_columns = [ feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES) ] - + feature_spec = tf.contrib.layers.create_feature_spec_for_parsing( + feature_columns) + serving_input_fn = tf.contrib.learn.utils.build_parsing_serving_input_fn( + feature_spec) + # An export strategy that outputs the feature importance and also exports + # the internal tree representation in another format. 
+ export_strategy = custom_export_strategy.make_custom_export_strategy( + "exports", + convert_fn=_convert_fn, + feature_columns=feature_columns, + export_input_fn=serving_input_fn) return tf.contrib.learn.Experiment( estimator=_get_tfbt(output_dir, feature_columns), train_input_fn=train_input_fn, eval_input_fn=eval_input_fn, train_steps=None, eval_steps=FLAGS.num_eval_steps, - eval_metrics=None) + eval_metrics=None, + export_strategies=[export_strategy]) def main(unused_argv): diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index 766982b4f2023310e6046619939f83bef63b0302..f8086b0c2bb93eae6af0336bbe33fc23f8fcde22 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -63,19 +63,26 @@ const char* kPredictionsTensorName = "predictions"; void CalculateTreesToInclude( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, const std::vector& trees_to_drop, const int32 num_trees, - const bool only_finalized, std::vector* trees_to_include) { + const bool only_finalized, const bool center_bias, + std::vector* trees_to_include) { trees_to_include->reserve(num_trees - trees_to_drop.size()); int32 index = 0; // This assumes that trees_to_drop is a sorted list of tree ids. for (int32 tree = 0; tree < num_trees; ++tree) { - if ((!trees_to_drop.empty() && index < trees_to_drop.size() && - trees_to_drop[index] == tree) || - (only_finalized && config.tree_metadata_size() > 0 && - !config.tree_metadata(tree).is_finalized())) { + // Skip the tree if tree is in the list of trees_to_drop. + if (!trees_to_drop.empty() && index < trees_to_drop.size() && + trees_to_drop[index] == tree) { ++index; continue; } + // Or skip if the tree is not finalized and only_finalized is set, + // with the exception of centering bias. 
+ if (only_finalized && !(center_bias && tree == 0) && + config.tree_metadata_size() > 0 && + !config.tree_metadata(tree).is_finalized()) { + continue; + } trees_to_include->push_back(tree); } } @@ -250,7 +257,7 @@ class GradientTreesPredictionOp : public OpKernel { CalculateTreesToInclude( ensemble_resource->decision_tree_ensemble(), dropped_trees, ensemble_resource->decision_tree_ensemble().trees_size(), - only_finalized_trees_, &trees_to_include); + only_finalized_trees_, center_bias_, &trees_to_include); // Allocate output predictions matrix. Tensor* output_predictions_t = nullptr; diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index b08028eb635385357ba13b48d88157936978b6f1..8600c8c53caa5fd4274ba6730fc764d8315d680c 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -50,6 +50,7 @@ const char* const kAreBucketsReadyName = "are_buckets_ready"; const char* const kNumSparseFeaturesName = "num_sparse_features"; const char* const kSparseBucketsName = "sparse_buckets"; const char* const kSparseValuesName = "sparse_values"; +const char* const kSparseIndicesName = "sparse_indices"; const char* const kSparseStreamsStateName = "sparse_streams_state"; const char* const kSparseSummariesName = "sparse_summaries"; const char* const kSparseConfigName = "sparse_config"; @@ -85,9 +86,23 @@ std::vector GetBuckets(const int32 feature, return buckets_vector; } -void QuantizeFeatures(const string& output_name, const OpInputList& values_list, - const OpInputList& buckets_list, - OpKernelContext* const context) { +int32 GetFeatureDimension(const int32 feature_index, const int64 instance, + const OpInputList* const indices_list) { + if (indices_list != nullptr) { + // Sparse multidimensional. + return (*indices_list)[feature_index].matrix()(instance, 1); + } + // No indices, assume one-dimensional tensor. 
+ return 0; +} + +// Allows quantization for each of multiple dimensions of a sparse feature. +void QuantizeFeatures( + const string& output_name, const OpInputList& values_list, + const OpInputList& buckets_list, + const OpInputList* const + indices_list /** Optional, provide for sparse features **/, + OpKernelContext* const context) { if (values_list.size() == 0) { return; } @@ -100,10 +115,13 @@ void QuantizeFeatures(const string& output_name, const OpInputList& values_list, const int64 num_values = values_tensor.dim_size(0); Tensor* output_t = nullptr; + // Output will have bucket id and dimension of the features for that bucket. OP_REQUIRES_OK( - context, output_list.allocate(feature_index, TensorShape({num_values}), - &output_t)); - TTypes::Vec output = output_t->vec(); + context, output_list.allocate(feature_index, + TensorShape({num_values, 2}), &output_t)); + + auto output = output_t->matrix(); + const std::vector& buckets_vector = GetBuckets(feature_index, buckets_list); auto flat_values = values_tensor.flat(); @@ -116,7 +134,11 @@ void QuantizeFeatures(const string& output_name, const OpInputList& values_list, } const int32 bucket = static_cast(bucket_iter - buckets_vector.begin()); - output(instance) = bucket; + // Bucket id. + output(instance, 0) = bucket; + // Dimension. 
+ output(instance, 1) = + GetFeatureDimension(feature_index, instance, indices_list); } } } @@ -851,6 +873,11 @@ class QuantilesOp : public OpKernel { OP_REQUIRES_OK(context, context->input_list(kSparseValuesName, &sparse_float_feature_values_list)); + + OpInputList sparse_float_indices_list; + OP_REQUIRES_OK(context, context->input_list(kSparseIndicesName, + &sparse_float_indices_list)); + OpInputList sparse_buckets_list; OP_REQUIRES_OK( context, context->input_list(kSparseBucketsName, &sparse_buckets_list)); @@ -865,10 +892,10 @@ class QuantilesOp : public OpKernel { // Quantize the feature values QuantizeFeatures(kDenseOutputTensorName, dense_float_features_list, - dense_buckets_list, context); + dense_buckets_list, nullptr, context); QuantizeFeatures(kSparseOutputTensorName, sparse_float_feature_values_list, - sparse_buckets_list, context); + sparse_buckets_list, &sparse_float_indices_list, context); } }; diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 29635bb3c404e54f0561d9b9189270022f063cbe..18b4abd654ea3541d646a43ac901aca1a678446f 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -16,7 +16,7 @@ #include #include -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h" #include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" #include "tensorflow/core/framework/device_base.h" @@ -39,6 +39,10 @@ using boosted_trees::learner::stochastic::GradientStats; using boosted_trees::learner::stochastic::NodeStats; using boosted_trees::learner::LearnerConfig_MultiClassStrategy; +namespace { +const int32 DUMMY_FEATURE_DIMENSION = -1; +} // namespace + class BaseBuildSplitOp : public 
OpKernel { public: explicit BaseBuildSplitOp(OpKernelConstruction* const context) @@ -128,7 +132,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* bucket_ids_t; OP_REQUIRES_OK(context, context->input("bucket_ids", &bucket_ids_t)); - const auto& bucket_ids = bucket_ids_t->vec(); + const auto& bucket_ids = bucket_ids_t->matrix(); const Tensor* gradients_t; OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); @@ -219,7 +223,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { split_info.mutable_split_node()->mutable_dense_float_binary_split(); dense_split->set_feature_column(feature_column_group_id_); dense_split->set_threshold( - bucket_boundaries(bucket_ids(best_bucket_idx))); + bucket_boundaries(bucket_ids(best_bucket_idx, 0))); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); @@ -262,7 +266,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* bucket_ids_t; OP_REQUIRES_OK(context, context->input("bucket_ids", &bucket_ids_t)); - const auto& bucket_ids = bucket_ids_t->vec(); + const auto& bucket_ids_and_dimensions = bucket_ids_t->matrix(); + + const int32 tensor_elements = partition_ids.size(); const Tensor* gradients_t; OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); @@ -273,24 +279,59 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { int class_id; ReadClassId(context, &class_id); - // Find the number of unique partitions before we allocate the output. - std::vector partition_boundaries; + // For each partition (tree node), store starting index for each dimension. + PartitionAndDimensionBoundaries partition_boundaries; + // Stores indices in partition_boundaries for those partitions that are + // not empty (have at least one dimension and a bucket apart from catch-all + // bucket of -1 bucket id and dimension 0. 
std::vector non_empty_partitions; - for (int i = 0; i < partition_ids.size() - 1; ++i) { + bool non_empty_partition = false; + + for (int i = 0; i < partition_ids.size(); ++i) { // Make sure the input is sorted by partition_ids; - CHECK_LE(partition_ids(i), partition_ids(i + 1)); - if (i == 0 || partition_ids(i) != partition_ids(i - 1)) { - partition_boundaries.push_back(i); - // Some partitions might only have bias feature. We don't want to split - // those so check that the partition has at least 2 buckets. - if (partition_ids(i) == partition_ids(i + 1)) { - non_empty_partitions.push_back(partition_boundaries.size() - 1); + if (i > 0) { + CHECK_LE(partition_ids(i - 1), partition_ids(i)) + << "Partition ids should be sorted. Not sorted for " << i; + } + const int32 dimension = bucket_ids_and_dimensions(i, 1); + + if (i == 0 || (partition_ids(i) != partition_ids(i - 1))) { + if (i != 0) { + // Not the first entry, so partition has changed. + if (non_empty_partition) { + // Saves the id of a previous partition in a list of non empty + // partitions, since it was non empty (had more than just a bias + // bucket -1. + non_empty_partitions.push_back(partition_boundaries.size() - 1); + } + // Add dummy dimension to signify the end for the previous dimension. + partition_boundaries.back().emplace_back(DUMMY_FEATURE_DIMENSION, i); } + // Allocate for a new partition. + partition_boundaries.emplace_back(); + // Save info about the first dimension for a new partition. + partition_boundaries.back().emplace_back(dimension, i); + + // Each partition has dummy -1 bucket with all gradients and then info + // for all other dimensions -> if we have >1 elements for a partition, + // then it is not empty. + non_empty_partition = (i < partition_ids.size() - 1) && + (partition_ids(i) == partition_ids(i + 1)); + } else if (bucket_ids_and_dimensions(i, 1) != + bucket_ids_and_dimensions(i - 1, 1)) { + // Dimension changed. 
+ partition_boundaries.back().emplace_back(dimension, i); } } - if (partition_ids.size() > 0) { - partition_boundaries.push_back(partition_ids.size()); + if (tensor_elements > 0) { + if (non_empty_partition) { + non_empty_partitions.push_back(partition_boundaries.size() - 1); + } + // Add dummy dimension to signify the end for the previous dimension. + partition_boundaries.back().emplace_back(DUMMY_FEATURE_DIMENSION, + partition_ids.size()); } + int num_elements = non_empty_partitions.size(); Tensor* output_partition_ids_t = nullptr; OP_REQUIRES_OK(context, @@ -314,73 +355,128 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + // For each tree node that needs to be split. for (int root_idx = 0; root_idx < num_elements; ++root_idx) { + const auto& dimension_boundaries = + partition_boundaries[non_empty_partitions[root_idx]]; + float best_gain = std::numeric_limits::lowest(); - int start_index = partition_boundaries[non_empty_partitions[root_idx]]; - int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1]; - // First bucket ID in each partition should be the bias feature. - OP_REQUIRES(context, bucket_ids(start_index) == bias_feature_id_, - errors::InvalidArgument("Bias feature ID missing.")); + int32 best_dimension_idx = 0; + bool default_right = false; + int32 best_element_idx = 0; + + NodeStats best_right_node_stats(0); + NodeStats best_left_node_stats(0); + + // For each partition, the first bucket is dummy catch all. 
+ int32 bias_start_index = dimension_boundaries[0].start_index; + + OP_REQUIRES( + context, + bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id_, + errors::InvalidArgument("Bias feature ID missing.")); + + // Dimension for bias feature is always 0 + OP_REQUIRES( + context, bucket_ids_and_dimensions(bias_start_index, 1) == 0, + errors::InvalidArgument("Bias feature ID must be with dimension 0.")); + // For each root, we do two passes over the quantized feature buckets // accumulating gradients on one side and using the root aggregate // gradients to get the gradients for the other side. // Split gains are evaluated for each pass at every threshold and the best // split is picked. - GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); + GradientStats root_gradient_stats(*gradients_t, *hessians_t, + bias_start_index); root_gradient_stats *= normalizer_ratio; NodeStats root_stats = ComputeNodeStats(root_gradient_stats); - GradientStats present_gradient_stats; - for (int64 bucket_idx = start_index + 1; bucket_idx < end_index; - ++bucket_idx) { - present_gradient_stats += - GradientStats(*gradients_t, *hessians_t, bucket_idx); - } - present_gradient_stats *= normalizer_ratio; - int32 best_bucket_idx = 0; - NodeStats best_right_node_stats(0); - NodeStats best_left_node_stats(0); - GradientStats left_gradient_stats; - bool default_right = false; - for (int64 bucket_idx = start_index + 1; bucket_idx < end_index; - ++bucket_idx) { - GradientStats g(*gradients_t, *hessians_t, bucket_idx); - g *= normalizer_ratio; - left_gradient_stats += g; - // We have the sum of all present gradients. Use that to compute the - // backward pass gradients. 
- GradientStats right_gradient_stats = - present_gradient_stats - left_gradient_stats; - { - NodeStats left_stats_default_left = - ComputeNodeStats(root_gradient_stats - right_gradient_stats); - NodeStats right_stats_default_left = - ComputeNodeStats(right_gradient_stats); - if (left_stats_default_left.gain + right_stats_default_left.gain > - best_gain) { - best_gain = - left_stats_default_left.gain + right_stats_default_left.gain; - best_left_node_stats = left_stats_default_left; - best_right_node_stats = right_stats_default_left; - best_bucket_idx = bucket_idx; - default_right = false; - } + + // Iterate through dimensions. + for (int j = 0; j < dimension_boundaries.size() - 1; ++j) { + const DimensionBoundary& dimension_and_start = dimension_boundaries[j]; + const int32 dimension_id = dimension_and_start.dimension_id; + + int start_index = dimension_and_start.start_index; + // Even for the last dimension, we always have additional dummy + // dimension that we can use to find the end index. + const int end_index = + partition_boundaries[non_empty_partitions[root_idx]][j + 1] + .start_index; + CHECK(bucket_ids_and_dimensions(start_index, 1) == + bucket_ids_and_dimensions(end_index - 1, 1)) + << "For bucket " << bucket_ids_and_dimensions(start_index, 0) + << " the dimension was " + << bucket_ids_and_dimensions(start_index, 1) << " and for " + << bucket_ids_and_dimensions(end_index - 1, 0) << " " + << bucket_ids_and_dimensions(end_index - 1, 1); + if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id_) { + // 0-dimension case which has a first bucket for catch all feature. 
+ CHECK(bucket_ids_and_dimensions(start_index, 1) == 0) + << "Dimension of bias feature should be 0"; + ++start_index; } - { - NodeStats left_stats_default_right = - ComputeNodeStats(left_gradient_stats); - NodeStats right_stats_default_right = - ComputeNodeStats(root_gradient_stats - left_gradient_stats); - if (left_stats_default_right.gain + right_stats_default_right.gain > - best_gain) { - best_gain = - left_stats_default_right.gain + right_stats_default_right.gain; - best_left_node_stats = left_stats_default_right; - best_right_node_stats = right_stats_default_right; - best_bucket_idx = bucket_idx; - default_right = true; + + GradientStats present_gradient_stats; + for (int64 bucket_idx = start_index; bucket_idx < end_index; + ++bucket_idx) { + present_gradient_stats += + GradientStats(*gradients_t, *hessians_t, bucket_idx); + } + present_gradient_stats *= normalizer_ratio; + + GradientStats left_gradient_stats; + for (int64 element_idx = start_index; element_idx < end_index; + ++element_idx) { + // Check that bucket ids are sorted. + if (element_idx != start_index) { + CHECK(bucket_ids_and_dimensions(element_idx - 1, 0) < + bucket_ids_and_dimensions(element_idx, 0)) + << "Bucket ids must be sorted." + << ", problem on " << element_idx << " and dimension is " << j; + } + + GradientStats g(*gradients_t, *hessians_t, element_idx); + g *= normalizer_ratio; + left_gradient_stats += g; + // We have the sum of all present gradients. Use that to compute the + // backward pass gradients. 
+ GradientStats right_gradient_stats = + present_gradient_stats - left_gradient_stats; + { + NodeStats left_stats_default_left = + ComputeNodeStats(root_gradient_stats - right_gradient_stats); + NodeStats right_stats_default_left = + ComputeNodeStats(right_gradient_stats); + if (left_stats_default_left.gain + right_stats_default_left.gain > + best_gain) { + best_gain = + left_stats_default_left.gain + right_stats_default_left.gain; + best_left_node_stats = left_stats_default_left; + best_right_node_stats = right_stats_default_left; + best_element_idx = element_idx; + default_right = false; + best_dimension_idx = dimension_id; + } + } + { + NodeStats left_stats_default_right = + ComputeNodeStats(left_gradient_stats); + NodeStats right_stats_default_right = + ComputeNodeStats(root_gradient_stats - left_gradient_stats); + if (left_stats_default_right.gain + right_stats_default_right.gain > + best_gain) { + best_gain = left_stats_default_right.gain + + right_stats_default_right.gain; + best_left_node_stats = left_stats_default_right; + best_right_node_stats = right_stats_default_right; + best_element_idx = element_idx; + default_right = true; + best_dimension_idx = dimension_id; + } } } } + SplitInfo split_info; boosted_trees::trees::DenseFloatBinarySplit* dense_split = nullptr; if (default_right) { @@ -393,8 +489,13 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { ->mutable_split(); } dense_split->set_feature_column(feature_column_group_id_); - dense_split->set_threshold( - bucket_boundaries(bucket_ids(best_bucket_idx))); + // Set the feature index for the best feature column. 
+ const int64 best_dimension_id = + bucket_ids_and_dimensions(best_element_idx, 1); + const int32 best_bucket_id = + bucket_ids_and_dimensions(best_element_idx, 0); + dense_split->set_dimension_id(best_dimension_id); + dense_split->set_threshold(bucket_boundaries(best_bucket_id)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); @@ -403,11 +504,23 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = best_gain - root_stats.gain - tree_complexity_regularization_; - output_partition_ids(root_idx) = partition_ids(start_index); + output_partition_ids(root_idx) = partition_ids(bias_start_index); } } private: + struct DimensionBoundary { + DimensionBoundary(const int32 dimension_id, const int32 start_index) + : dimension_id(dimension_id), start_index(start_index) {} + + int32 dimension_id; + int32 start_index; + }; + + // For each partition, store start indices of feature column dimensions. + typedef std::vector> + PartitionAndDimensionBoundaries; + int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER(Name("BuildSparseInequalitySplits").Device(DEVICE_CPU), @@ -434,7 +547,7 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { const Tensor* feature_ids_t; OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t)); - const auto& feature_ids = feature_ids_t->vec(); + const auto& feature_ids = feature_ids_t->matrix(); const Tensor* gradients_t; OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); @@ -491,7 +604,7 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { int start_index = partition_boundaries[non_empty_partitions[root_idx]]; int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1]; // First feature ID in each partition should be the bias feature. 
- OP_REQUIRES(context, feature_ids(start_index) == bias_feature_id_, + OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id_, errors::InvalidArgument("Bias feature ID missing.")); GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); root_gradient_stats *= normalizer_ratio; @@ -519,7 +632,7 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { auto* equality_split = split_info.mutable_split_node() ->mutable_categorical_id_binary_split(); equality_split->set_feature_column(feature_column_group_id_); - equality_split->set_feature_id(feature_ids(best_feature_idx)); + equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); FillLeaf(class_id, best_left_node_stats, left_child); diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc index cff75e71d93cb703d87bb09a4b32439e01d70f76..a9a229c8ae0c26bba5f0a684dad7e546298577bb 100644 --- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc @@ -39,13 +39,14 @@ const char* const kStampTokenName = "stamp_token"; const char* const kNextStampTokenName = "next_stamp_token"; struct PartitionKey { - PartitionKey() : partition_id(-1), feature_id(-1) {} + PartitionKey() : partition_id(-1), feature_id(-1), dimension(-1) {} - PartitionKey(int32 p, int64 f) : partition_id(p), feature_id(f) {} + PartitionKey(int32 p, int64 f, int32 d) + : partition_id(p), feature_id(f), dimension(d) {} bool operator==(const PartitionKey& other) const { - return (feature_id == other.feature_id) && - (partition_id == other.partition_id); + return (partition_id == other.partition_id) && + (dimension == other.dimension) && (feature_id == other.feature_id); } // Compare for PartitionKey. 
@@ -54,7 +55,11 @@ struct PartitionKey { if (a.partition_id < b.partition_id) { return true; } - if ((a.partition_id == b.partition_id) && (a.feature_id < b.feature_id)) { + if ((a.partition_id == b.partition_id) && (a.dimension < b.dimension)) { + return true; + } + if ((a.partition_id == b.partition_id) && (a.dimension == b.dimension) && + (a.feature_id < b.feature_id)) { return true; } return false; @@ -64,8 +69,11 @@ struct PartitionKey { // Tree partition defined by traversing the tree to the leaf. int32 partition_id; - // Feature Id within the feature column. + // Feature column id. int64 feature_id; + + // Dimension within feature column. + int32 dimension; }; template @@ -132,12 +140,12 @@ void SerializeScalarAccumulatorToOutput( &partition_ids_t)); auto partition_ids = partition_ids_t->vec(); + // Feature ids tensor has ids of feature columns and their dimensions. Tensor* feature_ids_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_feature_ids", TensorShape({num_slots}), - &feature_ids_t)); - auto feature_ids = feature_ids_t->vec(); + OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids", + TensorShape({num_slots, 2}), + &feature_ids_t)); + auto feature_ids = feature_ids_t->matrix(); Tensor* gradients_t = nullptr; OP_REQUIRES_OK( @@ -155,7 +163,9 @@ void SerializeScalarAccumulatorToOutput( int i = 0; for (const auto& iter : accumulator_resource.values()) { partition_ids(i) = iter.first.partition_id; - feature_ids(i) = iter.first.feature_id; + feature_ids(i, 0) = iter.first.feature_id; + feature_ids(i, 1) = iter.first.dimension; + gradients(i) = iter.second.first; hessians(i) = iter.second.second; ++i; @@ -174,11 +184,10 @@ void SerializeTensorAccumulatorToOutput( auto partition_ids = partition_ids_t->vec(); Tensor* feature_ids_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_feature_ids", TensorShape({num_slots}), - &feature_ids_t)); - auto feature_ids = feature_ids_t->vec(); + 
OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids", + TensorShape({num_slots, 2}), + &feature_ids_t)); + auto feature_ids = feature_ids_t->matrix(); TensorShape gradient_shape = accumulator_resource.gradient_shape(); int64 num_gradient_elements = gradient_shape.num_elements(); @@ -201,7 +210,9 @@ void SerializeTensorAccumulatorToOutput( int i = 0; for (const auto& iter : accumulator_resource.values()) { partition_ids(i) = iter.first.partition_id; - feature_ids(i) = iter.first.feature_id; + feature_ids(i, 0) = iter.first.feature_id; + feature_ids(i, 1) = iter.first.dimension; + for (int j = 0; j < num_gradient_elements; ++j) { gradients(i, j) = iter.second.first[j]; } @@ -220,14 +231,16 @@ void AddToScalarAccumulator( 1); const TensorShape& partition_ids_shape = partition_ids_t.shape(); const auto& partition_ids = partition_ids_t.vec(); - const auto& feature_ids = feature_ids_t.vec(); + const auto& feature_ids_and_dimensions = feature_ids_t.matrix(); const auto& gradients = gradients_t.vec(); const auto& hessians = hessians_t.vec(); int64 num_updates = partition_ids_shape.dim_size(0); auto stats_map = accumulator_resource->mutable_values(); for (int64 i = 0; i < num_updates; ++i) { - const auto key = PartitionKey(partition_ids(i), feature_ids(i)); + const auto key = + PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0), + feature_ids_and_dimensions(i, 1)); auto itr = stats_map->find(key); if (itr != stats_map->end()) { itr->second.first += gradients(i); @@ -263,7 +276,7 @@ void AddToTensorAccumulator( const TensorShape& partition_ids_shape = partition_ids_t.shape(); const auto& partition_ids = partition_ids_t.vec(); - const auto& feature_ids = feature_ids_t.vec(); + const auto& feature_ids_and_dimensions = feature_ids_t.matrix(); TensorShape gradients_shape = gradients_t.shape(); const auto& gradients = gradients_t.flat_outer_dims(); TensorShape hessians_shape = hessians_t.shape(); @@ -288,7 +301,9 @@ void AddToTensorAccumulator( int64 
num_updates = partition_ids_shape.dim_size(0); auto stats_map = accumulator_resource->mutable_values(); for (int64 i = 0; i < num_updates; ++i) { - const auto key = PartitionKey(partition_ids(i), feature_ids(i)); + const auto key = + PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0), + feature_ids_and_dimensions(i, 1)); auto itr = stats_map->find(key); if (itr == stats_map->end()) { std::vector new_gradients(gradients_shape.num_elements()); diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 2a5c7949f2d1f68eef1714c47446907038bd7216..c77d90e243c304ec8e9a10a0b63401f9bd825c3e 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -237,6 +237,7 @@ class CenterTreeEnsembleBiasOp : public OpKernel { VLOG(1) << "Continuing to center bias, delta=" << total_delta; } else { VLOG(1) << "Done centering bias, delta=" << total_delta; + ensemble_resource->LastTreeMetadata()->set_is_finalized(true); } Tensor* continue_centering_t = nullptr; OP_REQUIRES_OK( @@ -260,7 +261,6 @@ class CenterTreeEnsembleBiasOp : public OpKernel { for (size_t idx = 0; idx < logits_dimension; ++idx) { leaf->mutable_vector()->add_value(0.0); } - ensemble_resource->LastTreeMetadata()->set_is_finalized(true); return leaf; } else if (num_trees == 1) { // Confirms that the only tree is a bias and returns its leaf. 
diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD index 107ff0d295bee530c1711a97849fbd3c6cdb2f00..131bd48562a55a08981ac73277e93024db0d85d3 100644 --- a/tensorflow/contrib/boosted_trees/lib/BUILD +++ b/tensorflow/contrib/boosted_trees/lib/BUILD @@ -406,51 +406,9 @@ tf_cc_test( ) # Learner/stochastic - -cc_library( - name = "feature-column-handlers", - srcs = [ - "learner/stochastic/handlers/bias-feature-column-handler.cc", - "learner/stochastic/handlers/categorical-feature-column-handler.cc", - "learner/stochastic/handlers/dense-quantized-feature-column-handler.cc", - "learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc", - ], - hdrs = [ - "learner/stochastic/handlers/bias-feature-column-handler.h", - "learner/stochastic/handlers/categorical-feature-column-handler.h", - "learner/stochastic/handlers/dense-quantized-feature-column-handler.h", - "learner/stochastic/handlers/feature-column-handler.h", - "learner/stochastic/handlers/sparse-quantized-feature-column-handler.h", - ], - deps = [ - ":feature-split-candidate", - ":feature-stats-accumulator", - "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:protos_all_cc", - ], -) - -tf_cc_test( - name = "feature-column-handlers_test", - size = "small", - srcs = [ - "learner/stochastic/handlers/bias-feature-column-handler_test.cc", - "learner/stochastic/handlers/categorical-feature-column-handler_test.cc", - "learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc", - "learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc", - ], - deps = [ - ":feature-column-handlers", - "//tensorflow/core:tensor_testutil", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - cc_library( name = "gradient-stats", - hdrs = ["learner/stochastic/stats/gradient-stats.h"], + hdrs = ["learner/common/stats/gradient-stats.h"], deps = [ 
"//tensorflow/core:framework_headers_lib", "//third_party/eigen3", @@ -459,7 +417,7 @@ cc_library( cc_library( name = "node-stats", - hdrs = ["learner/stochastic/stats/node-stats.h"], + hdrs = ["learner/common/stats/node-stats.h"], deps = [ ":gradient-stats", "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", @@ -471,7 +429,7 @@ cc_library( cc_library( name = "split-stats", - hdrs = ["learner/stochastic/stats/split-stats.h"], + hdrs = ["learner/common/stats/split-stats.h"], deps = [ ":node-stats", ], @@ -479,7 +437,7 @@ cc_library( cc_library( name = "feature-split-candidate", - hdrs = ["learner/stochastic/stats/feature-split-candidate.h"], + hdrs = ["learner/common/stats/feature-split-candidate.h"], deps = [ ":split-stats", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", @@ -489,7 +447,7 @@ cc_library( tf_cc_test( name = "node-stats_test", size = "small", - srcs = ["learner/stochastic/stats/node-stats_test.cc"], + srcs = ["learner/common/stats/node-stats_test.cc"], deps = [ ":node-stats", "//tensorflow/core:tensor_testutil", diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index 83dad7e4b3301327bcbae5203e9d9330c9e0084d..9f78ab20242800fd8af7ad049d5970fbe26ec0ea 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -110,8 +110,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): def not_active_inputs(): return (constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([], dtype=dtypes.int64), empty_gradients, - empty_hessians) + constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]), + empty_gradients, empty_hessians) def active_inputs(): """The normal flow when the handler is active.""" @@ -154,7 +154,12 @@ class 
EqualitySplitHandler(base_split_handler.BaseSplitHandler): [per_partition_hessians, filtered_hessians], 0) feature_ids = array_ops.concat( [bias_feature_ids, self._sparse_int_column.values], 0) - return partition_ids, feature_ids, filtered_gradients, filtered_hessians + # Dimension is always zero for sparse int features. + dimension_ids = array_ops.zeros_like(feature_ids, dtype=dtypes.int64) + feature_ids_and_dimensions = array_ops.stack( + [feature_ids, dimension_ids], axis=1) + return (partition_ids, feature_ids_and_dimensions, filtered_gradients, + filtered_hessians) partition_ids, feature_ids, gradients_out, hessians_out = ( control_flow_ops.cond(is_active[0], active_inputs, not_active_inputs)) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 8c0a3f0d91e0fbd6b6ca02352c8b80b8485d029d..7df514cd207c5e781f3b4abaa2020016b197669d 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -257,6 +257,7 @@ class DenseSplitHandler(InequalitySplitHandler): # Put quantile and stats accumulator flushing in the dependency path. 
are_splits_ready = control_flow_ops.with_dependencies( [flush_quantiles, partition_ids], are_splits_ready) + partition_ids, gains, split_infos = ( split_handler_ops.build_dense_inequality_splits( num_minibatches=num_minibatches, @@ -433,14 +434,15 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, def ready_inputs_fn(): """Branch to execute when quantiles are ready.""" quantized_feature = quantile_ops.quantiles([float_column], [], - [quantile_buckets], []) + [quantile_buckets], [], []) quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64) - quantized_feature = array_ops.reshape(quantized_feature, [-1]) + quantized_feature = array_ops.squeeze(quantized_feature, axis=0) return (example_partition_ids, quantized_feature, gradients, hessians) def not_ready_inputs_fn(): - return (constant_op.constant([], dtype=dtypes.int32), constant_op.constant( - [], dtype=dtypes.int64), empty_gradients, empty_hessians) + return (constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([[]], dtype=dtypes.int64, shape=[1, 2]), + empty_gradients, empty_hessians) example_partition_ids, feature_ids, gradients, hessians = ( control_flow_ops.cond( @@ -461,10 +463,13 @@ def sparse_make_stats_update( def quantiles_ready(): """The subgraph for when the quantiles are ready.""" - quantized_feature = quantile_ops.quantiles([sparse_column_values], [], - [quantile_buckets], []) - quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64) - quantized_feature = array_ops.reshape(quantized_feature, [-1]) + quantized_feature = quantile_ops.quantiles([], [sparse_column_values], [], + [quantile_buckets], + [sparse_column_indices]) + + quantized_feature = math_ops.cast(quantized_feature[1], dtypes.int64) + quantized_feature = array_ops.squeeze(quantized_feature, axis=0) + example_indices, _ = array_ops.split( sparse_column_indices, num_or_size_splits=2, axis=1) example_indices = array_ops.squeeze(example_indices, [1]) @@ -486,19 +491,25 @@ def 
sparse_make_stats_update( bias_feature_ids = array_ops.fill( array_ops.shape(unique_partitions), _BIAS_FEATURE_ID) bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64) + zeros = array_ops.zeros_like(bias_feature_ids) + bias_feature_ids = array_ops.stack([bias_feature_ids, zeros], axis=1) + partition_ids = array_ops.concat( [unique_partitions, filtered_partition_ids], 0) filtered_gradients = array_ops.concat( [per_partition_gradients, filtered_gradients], 0) filtered_hessians = array_ops.concat( [per_partition_hessians, filtered_hessians], 0) + bucket_ids = array_ops.concat([bias_feature_ids, quantized_feature], 0) + return partition_ids, bucket_ids, filtered_gradients, filtered_hessians def quantiles_not_ready(): """The subgraph for when the quantiles are not ready.""" - return (constant_op.constant([], dtype=dtypes.int32), constant_op.constant( - [], dtype=dtypes.int64), empty_gradients, empty_hessians) + return (constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]), + empty_gradients, empty_hessians) empty_float = constant_op.constant([], dtype=dtypes.float32) handler_not_active = (constant_op.constant( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index ee16a5f838a65f20db4436eb86527518621b6d8d..54d03018d9e266beabbbabd78ebbb80cfe689c04 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -1121,6 +1121,87 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(gains), 0) self.assertEqual(len(splits), 0) + def testDegenerativeCase(self): + with self.test_session() as sess: + # One data example only, one leaf and thus one quantile bucket.The same + # situation is when all examples have the same values. 
This case was + # causing before a failure. + gradients = array_ops.constant([0.2]) + hessians = array_ops.constant([0.12]) + example_partitions = array_ops.constant([1], dtype=dtypes.int32) + indices = array_ops.constant([[0, 0]], dtype=dtypes.int64) + values = array_ops.constant([0.58]) + sparse_column = sparse_tensor.SparseTensor(indices, values, [1, 1]) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = ordinal_split_handler.SparseSplitHandler( + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + epsilon=0.01, + num_quantiles=2, + feature_column_group_id=0, + sparse_float_column=sparse_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([1, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(1, 2, class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality 
split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([1], partitions) + self.assertAllEqual([0.0], gains) + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + split_node = split_info.split_node.sparse_float_binary_split_default_left + + self.assertEqual(0, split_node.split.feature_column) + + self.assertAllClose(0.58, split_node.split.threshold) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/feature-split-candidate.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h similarity index 90% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/feature-split-candidate.h rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h index fe22691178213094b9affcdee06af98011f85bd2..339c2e0fded10e6a7b140da62e152e2868ffd164 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/feature-split-candidate.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h @@ -13,10 +13,10 @@ // limitations under the License. 
// // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" namespace tensorflow { @@ -58,4 +58,4 @@ struct FeatureSplitCandidate { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h similarity index 98% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h index dad64bf165a41bc4f32eea6b37e7afb569887a06..34e3ddb777242553d62035a51f1aec33d0f9ba54 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ #include @@ -190,4 +190,4 @@ inline GradientStats operator-(const GradientStats& a, const GradientStats& b) { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_ +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h similarity index 98% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h index 4e5f53874df2207ffa6664a33675f84ef055394b..642a183aec5c7e591579fa5ee91d45729bfb624d 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ #include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/Eigen/Eigenvalues" -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h" #include "tensorflow/contrib/boosted_trees/proto/learner.pb.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" #include "tensorflow/core/framework/shape_inference.h" @@ -298,4 +298,4 @@ struct NodeStats { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_ +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc similarity index 99% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats_test.cc rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc index ecb7a04efb96248210d9af770c8377b7f6906598..f867e77d3ef0609774628b2a9c36ca52bcf2a957 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h similarity index 94% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h index f700cbced833543227de39f54c9ecbb03a7ce7c9..054ccd9a8cd0be0c48b14cca013f15677deba900 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ #include -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h" namespace tensorflow { namespace boosted_trees { @@ -81,4 +81,4 @@ struct SplitStats { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_ +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.cc deleted file mode 100644 index b880cf2c47989b1434f17802befb7dd7c248b36f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -void BiasFeatureColumnHandler::AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const { - // Pass over all examples and aggregate gradient stats for each sub-root. - for (int64 example_idx = 0; example_idx < batch_size_; ++example_idx) { - auto partition_id = example_partition_ids[example_idx]; - gradient_stats_accumulator->AddStats( - slot_id_, class_id_, partition_id, kBiasFeatureId, - GradientStats(example_first_order_gradients, - example_second_order_gradients, example_idx)); - } -} - -void BiasFeatureColumnHandler::GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const { - split_candidates->clear(); - split_candidates->reserve(roots.size()); - boosted_trees::trees::TreeNode tree_node; - for (size_t root_idx = 0; root_idx < roots.size(); ++root_idx) { - const NodeStats& root_node_stats = root_stats[root_idx]; - tree_node.Clear(); - root_node_stats.FillLeaf(class_id_, tree_node.mutable_leaf()); - split_candidates->emplace_back(slot_id_, tree_node, - SplitStats(learner_config, root_node_stats)); - } -} - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h 
b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h deleted file mode 100644 index 5c0f99185a63db33a391a98fa16f37bef99507c9..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_H_ // NOLINT -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_H_ // NOLINT - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler for a bias feature column in the single class case. -// This handler is useful even if we don't introduce a bias feature because -// it allows us to aggregate stats per partition which in turn allows us -// to compute node stats for each root to split. 
-class BiasFeatureColumnHandler : public FeatureColumnHandler { - public: - BiasFeatureColumnHandler(const uint32 class_id, const uint32 slot_id, - const int64 batch_size) - : FeatureColumnHandler(class_id, slot_id, batch_size) {} - - void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const override; - - void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const override; - - static constexpr auto kBiasFeatureId = 0; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_H_ // NOLINT diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler_test.cc deleted file mode 100644 index f4c7df7fabda1a38d7e6cca4c5c8bc81cb7551b1..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler_test.cc +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h" - -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { -namespace { - -using boosted_trees::learner::LearnerConfig; - -const auto kClassId = 7; -const auto kSlotId = 0; -const auto kBatchSize = 4; - -using FeatureStatsAccumulator = - FeatureStatsAccumulator; - -class BiasFeatureColumnHandlerTest : public ::testing::Test { - protected: - BiasFeatureColumnHandlerTest() - : example_first_order_gradients_( - test::AsTensor({0.2f, -0.5f, 1.2f, 4.0f}, {4})), - example_second_order_gradients_( - test::AsTensor({0.12f, 0.07f, 0.2f, 0.13f}, {4})), - example_partitions_({0, 0, 1, 3}) { - // Set L2 regularization. - learner_config_.mutable_regularization()->set_l2(2.0f); - learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); - - // Create handler. - handler_.reset(new BiasFeatureColumnHandler(kClassId, kSlotId, kBatchSize)); - } - - LearnerConfig learner_config_; - const Tensor example_first_order_gradients_; - const Tensor example_second_order_gradients_; - const std::vector example_partitions_; - std::unique_ptr handler_; -}; - -TEST_F(BiasFeatureColumnHandlerTest, AggregateGradientStats) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Check stats for each partition. - // Partition 0. - EXPECT_GRADIENT_STATS_EQ( - GradientStats(-0.3f, 0.19f), - accumulator.GetStats(kSlotId, kClassId, 0, - BiasFeatureColumnHandler::kBiasFeatureId)); - // Partition 1. 
- EXPECT_GRADIENT_STATS_EQ( - GradientStats(1.2f, 0.2f), - accumulator.GetStats(kSlotId, kClassId, 1, - BiasFeatureColumnHandler::kBiasFeatureId)); - // Partition 2. - EXPECT_GRADIENT_STATS_EQ( - GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 2, - BiasFeatureColumnHandler::kBiasFeatureId)); - // Partition 3. - EXPECT_GRADIENT_STATS_EQ( - GradientStats(4.0f, 0.13f), - accumulator.GetStats(kSlotId, kClassId, 3, - BiasFeatureColumnHandler::kBiasFeatureId)); -} - -TEST_F(BiasFeatureColumnHandlerTest, GenerateFeatureSplitCandidates) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Get feature split candidates for two roots 0 and 3. - // Root 0 has zero gain and root 3 has the same gain as the leaf. - const std::vector roots = {0, 3}; - const std::vector& root_stats = { - NodeStats(1), NodeStats(learner_config_, GradientStats(4.0f, 0.13f))}; - std::vector split_candidates; - handler_->GenerateFeatureSplitCandidates(learner_config_, roots, root_stats, - accumulator, &split_candidates); - // Expect two candidate splits (one per root). - EXPECT_EQ(2, split_candidates.size()); - - // Verify first candidate for root 0, gain is expected to be the same as - // the left child since the root node gain is zero. 
- const SplitStats expected_split_stats0(learner_config_, root_stats[0]); - EXPECT_SPLIT_STATS_EQ(expected_split_stats0, split_candidates[0].split_stats); - const auto& tree_node0 = split_candidates[0].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kLeaf, tree_node0.node_case()); - EXPECT_EQ(1, tree_node0.leaf().sparse_vector().index_size()); - EXPECT_EQ(kClassId, tree_node0.leaf().sparse_vector().index(0)); - EXPECT_EQ(1, tree_node0.leaf().sparse_vector().value_size()); - EXPECT_EQ(root_stats[0].weight_contribution[0], - tree_node0.leaf().sparse_vector().value(0)); - - // Verify second candidate for root 3, gain is expected to be zero as - // the left child gain is equal to the parent gain. - const SplitStats expected_split_stats1(learner_config_, root_stats[1]); - EXPECT_SPLIT_STATS_EQ(expected_split_stats1, split_candidates[1].split_stats); - const auto& tree_node1 = split_candidates[1].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kLeaf, tree_node1.node_case()); - EXPECT_EQ(1, tree_node1.leaf().sparse_vector().index_size()); - EXPECT_EQ(kClassId, tree_node1.leaf().sparse_vector().index(0)); - EXPECT_EQ(1, tree_node1.leaf().sparse_vector().value_size()); - EXPECT_EQ(root_stats[1].weight_contribution[0], - tree_node1.leaf().sparse_vector().value(0)); -} - -} // namespace -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.cc deleted file mode 100644 index 3a6c409f846c9ca0bd6b5101e96447642b949978..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h" - -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -namespace { - -// Creates a categorical Id split node without assigning children. -boosted_trees::trees::TreeNode CreateCategoricalIdNode( - const int32 feature_column, const int32 id) { - boosted_trees::trees::TreeNode split_node; - auto* split = split_node.mutable_categorical_id_binary_split(); - split->set_feature_column(feature_column); - split->set_feature_id(id); - return split_node; -} - -} // namespace - -void CategoricalFeatureColumnHandler::AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const { - // Pass over all rows and aggregate gradient stats for each feature id. 
- const int64 num_rows = indices_.dimension(0); - for (int64 row_idx = 0; row_idx < num_rows; ++row_idx) { - auto example_idx = indices_(row_idx, 0); - auto feature_id = values_(row_idx); - const GradientStats norm_gradient_stats(example_first_order_gradients, - example_second_order_gradients, - example_idx); - auto partition_id = example_partition_ids[example_idx]; - gradient_stats_accumulator->AddStats(slot_id_, class_id_, partition_id, - feature_id, norm_gradient_stats); - } -} - -void CategoricalFeatureColumnHandler::GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const { - // Build a reverse lookup of partition id to root idx. - std::unordered_map partition_id_to_root_idx; - partition_id_to_root_idx.reserve(roots.size()); - for (size_t root_idx = 0; root_idx < roots.size(); ++root_idx) { - partition_id_to_root_idx[roots[root_idx]] = root_idx; - } - - // Initialize split candidates. - split_candidates->clear(); - if (!roots.empty()) { - FeatureSplitCandidate empty_candidate( - root_stats[0].weight_contribution.size()); - split_candidates->resize(roots.size(), empty_candidate); - } - for (auto& split_candidate : *split_candidates) { - split_candidate.split_stats.gain = std::numeric_limits::lowest(); - } - - // Evaluate split candidates for every root as each is a separate - // logical partition over the examples. - // Then for each root, we evaluate every feature id as an equality split - // and pick the highest split gain. - for (const auto& entry : - gradient_stats_accumulator.GetFeatureStats(slot_id_)) { - DCHECK_EQ(entry.first.class_id, class_id_); - - // Get partition id and root node stats. 
- const int32 partition_id = entry.first.partition_id; - auto root_idx_it = partition_id_to_root_idx.find(partition_id); - if (root_idx_it == partition_id_to_root_idx.end()) { - // Inactive partition. - continue; - } - size_t root_idx = root_idx_it->second; - const NodeStats& root_node_stats = root_stats[root_idx]; - - // Get gradient stats. - const auto& left_gradient_stats = entry.second; - auto right_gradient_stats = - root_node_stats.gradient_stats - left_gradient_stats; - - // Get node stats. - NodeStats left_node_stats(learner_config, left_gradient_stats); - NodeStats right_node_stats(learner_config, right_gradient_stats); - - // Generate split candidate and update best split candidate for the - // current root if needed. - FeatureSplitCandidate split_candidate( - slot_id_, - CreateCategoricalIdNode(feature_column_, entry.first.feature_id), - SplitStats(learner_config, root_node_stats, left_node_stats, - right_node_stats)); - FeatureSplitCandidate& best_split_candidate = (*split_candidates)[root_idx]; - if (TF_PREDICT_FALSE(best_split_candidate.tree_node.node_case() == - boosted_trees::trees::TreeNode::NODE_NOT_SET)) { - // Always replace candidates with no node set. - best_split_candidate = std::move(split_candidate); - } else if (TF_PREDICT_FALSE(split_candidate.split_stats.gain == - best_split_candidate.split_stats.gain)) { - // Tie break on feature id. 
- auto best_split_feature_id = - best_split_candidate.tree_node.categorical_id_binary_split() - .feature_id(); - if (entry.first.feature_id < best_split_feature_id) { - best_split_candidate = std::move(split_candidate); - } - } else if (split_candidate.split_stats.gain > - best_split_candidate.split_stats.gain) { - best_split_candidate = std::move(split_candidate); - } - } -} - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h deleted file mode 100644 index ef964ba716c6adf9cf9c291cca5f52f7a6efe26f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= - -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_CATEGORICAL_FEATURE_COLUMN_HANDLER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_CATEGORICAL_FEATURE_COLUMN_HANDLER_H_ - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler for a categorical feature column in the single class case. -class CategoricalFeatureColumnHandler : public FeatureColumnHandler { - public: - CategoricalFeatureColumnHandler(const int32 class_id, const int32 slot_id, - const int64 batch_size, - const int32 feature_column, - TTypes::ConstMatrix indices, - TTypes::ConstVec values) - : FeatureColumnHandler(class_id, slot_id, batch_size), - feature_column_(feature_column), - indices_(indices), - values_(values) {} - - void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const override; - - void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const override; - - protected: - const int32 feature_column_; - TTypes::ConstMatrix indices_; - TTypes::ConstVec values_; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_CATEGORICAL_FEATURE_COLUMN_HANDLER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler_test.cc 
b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler_test.cc deleted file mode 100644 index ea82b3f086d24dc1f9ceb4783abd68be35b34b00..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler_test.cc +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h" - -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { -namespace { - -using boosted_trees::learner::LearnerConfig; - -const auto kClassId = 7; -const auto kSlotId = 0; -const auto kBatchSize = 4; -const auto kFeatureColumn = 3; - -using FeatureStatsAccumulator = - FeatureStatsAccumulator; - -class CategoricalFeatureColumnHandlerTest : public ::testing::Test { - protected: - // The data looks like the following: - // Example | Gradients | Partition | Feature Id | - // i0 | (0.2, 0.12) | 0 | 1,2 | - // i1 | (-0.5, 0.07) | 0 | | - // i2 | (1.2, 0.2) | 0 | 2 | - // i3 | (4.0, 0.13) | 1 | 0 | - CategoricalFeatureColumnHandlerTest() - : 
example_first_order_gradients_( - test::AsTensor({0.2f, -0.5f, 1.2f, 4.0f}, {4})), - example_second_order_gradients_( - test::AsTensor({0.12f, 0.07f, 0.2f, 0.13f}, {4})), - example_partitions_({0, 0, 0, 1}), - indices_(test::AsTensor({0, 0, 0, 1, 2, 0, 3, 0}, {4, 2})), - values_(test::AsTensor({1, 2, 2, 0}, {4})) { - // Set L2 regularization. - learner_config_.mutable_regularization()->set_l2(2.0f); - learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); - // Create handler. - handler_.reset(new CategoricalFeatureColumnHandler( - kClassId, kSlotId, kBatchSize, kFeatureColumn, indices_.matrix(), - values_.vec())); - } - - LearnerConfig learner_config_; - const Tensor example_first_order_gradients_; - const Tensor example_second_order_gradients_; - const std::vector example_partitions_; - const Tensor indices_; - const Tensor values_; - std::unique_ptr handler_; -}; - -TEST_F(CategoricalFeatureColumnHandlerTest, AggregateGradientStats) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Check stats for each partition and feature. - // Partition 0, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 0, 0)); - // Partition 0, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.2f, 0.12f), - accumulator.GetStats(kSlotId, kClassId, 0, 1)); - // Partition 0, Feature 2. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.2f + 1.2f, 0.12f + 0.2f), - accumulator.GetStats(kSlotId, kClassId, 0, 2)); - - // Partition 1, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(4.0f, 0.13f), - accumulator.GetStats(kSlotId, kClassId, 1, 0)); - // Partition 1, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 1, 1)); - // Partition 1, Feature 2. 
- EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 1, 2)); -} - -TEST_F(CategoricalFeatureColumnHandlerTest, GenerateFeatureSplitCandidates) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Get feature split candidates for two roots 0 and 1. - // The root stats are derived from the per-partition total gradient stats. - const std::vector roots = {0, 1, 5}; - const std::vector& root_stats = { - NodeStats(learner_config_, GradientStats(0.9f, 0.39f)), - NodeStats(learner_config_, GradientStats(4.0f, 0.13f)), NodeStats(1)}; - std::vector split_candidates; - handler_->GenerateFeatureSplitCandidates(learner_config_, roots, root_stats, - accumulator, &split_candidates); - // Expect three candidate splits (one per root). - EXPECT_EQ(3, split_candidates.size()); - - // Verify candidate for root 0, the best split occurs when we route - // example i0, i2 left and i1 right. - const NodeStats expected_left_node0(learner_config_, - GradientStats(0.2f + 1.2f, 0.12f + 0.2f)); - const NodeStats expected_right_node0( - learner_config_, - root_stats[0].gradient_stats - expected_left_node0.gradient_stats); - const SplitStats expected_split_stats0(learner_config_, root_stats[0], - expected_left_node0, - expected_right_node0); - EXPECT_SPLIT_STATS_EQ(expected_split_stats0, split_candidates[0].split_stats); - - const auto& tree_node0 = split_candidates[0].tree_node; - EXPECT_EQ( - boosted_trees::trees::TreeNode::kCategoricalIdBinarySplitFieldNumber, - tree_node0.node_case()); - const auto& split0 = tree_node0.categorical_id_binary_split(); - EXPECT_EQ(2, split0.feature_id()); - EXPECT_EQ(kFeatureColumn, split0.feature_column()); - - // Verify candidate for root 1, there's only one active feature here - // so zero gain is expected. 
- const NodeStats expected_left_node1(learner_config_, - root_stats[1].gradient_stats); - const NodeStats expected_right_node1(learner_config_, GradientStats(0, 0)); - const SplitStats expected_split_stats1(learner_config_, root_stats[1], - expected_left_node1, - expected_right_node1); - EXPECT_SPLIT_STATS_EQ(expected_split_stats1, split_candidates[1].split_stats); - const auto& tree_node1 = split_candidates[1].tree_node; - EXPECT_EQ( - boosted_trees::trees::TreeNode::kCategoricalIdBinarySplitFieldNumber, - tree_node1.node_case()); - const auto& split1 = tree_node1.categorical_id_binary_split(); - EXPECT_EQ(0, split1.feature_id()); - EXPECT_EQ(kFeatureColumn, split1.feature_column()); - - // Verify there are no candidate splits for root 5. - const auto& tree_node2 = split_candidates[2].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::NODE_NOT_SET, - tree_node2.node_case()); -} - -} // namespace -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.cc deleted file mode 100644 index ca7bb71e7d0b0fc945ee29092b1e36022d4c0943..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.cc +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -namespace { - -// Creates a dense split node without assigning children. -boosted_trees::trees::TreeNode CreateDenseSplitNode(const int32 feature_column, - const float threshold) { - boosted_trees::trees::TreeNode split_node; - auto* split = split_node.mutable_dense_float_binary_split(); - split->set_feature_column(feature_column); - split->set_threshold(threshold); - return split_node; -} - -} // namespace - -void DenseQuantizedFeatureColumnHandler::AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const { - // Pass over all examples and aggregate gradient stats for each partition - // and quantized feature bucket. 
- for (int64 example_idx = 0; example_idx < batch_size_; ++example_idx) { - auto partition_id = example_partition_ids[example_idx]; - auto feature_id = dense_quantized_values_(example_idx); - gradient_stats_accumulator->AddStats( - slot_id_, class_id_, partition_id, feature_id, - GradientStats(example_first_order_gradients, - example_second_order_gradients, example_idx)); - } -} - -void DenseQuantizedFeatureColumnHandler::GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const { - // Evaluate split candidates for every root as each is a separate - // logical partition over the examples. - // Then for each root, we do a forward-only pass over the quantized - // feature buckets accumulating gradients from left to right. - // Split gains are evaluated at every threshold and the best split is picked. - split_candidates->clear(); - split_candidates->reserve(roots.size()); - for (size_t root_idx = 0; root_idx < roots.size(); ++root_idx) { - // Get partition Id and root node stats. - const int32 partition_id = roots[root_idx]; - const NodeStats& root_node_stats = root_stats[root_idx]; - - // Forward left to right pass over quantiles. - GradientStats left_gradient_stats; - GradientStats right_gradient_stats(root_node_stats.gradient_stats); - FeatureSplitCandidate best_split_candidate( - root_node_stats.weight_contribution.size()); - best_split_candidate.split_stats.gain = - std::numeric_limits::lowest(); - for (int bucket_id = 0; bucket_id < dense_quantiles_.size(); ++bucket_id) { - // Get gradient stats. - auto gradient_stats = gradient_stats_accumulator.GetStats( - slot_id_, class_id_, partition_id, bucket_id); - if (gradient_stats.IsZero()) { - continue; - } - - // Update gradient stats. 
- left_gradient_stats += gradient_stats; - right_gradient_stats = - root_node_stats.gradient_stats - left_gradient_stats; - - // Get node stats - NodeStats left_node_stats(learner_config, left_gradient_stats); - NodeStats right_node_stats(learner_config, right_gradient_stats); - - // Generate split candidate. - const float threshold = dense_quantiles_(bucket_id); - FeatureSplitCandidate split_candidate( - slot_id_, CreateDenseSplitNode(dense_feature_column_, threshold), - SplitStats(learner_config, root_node_stats, left_node_stats, - right_node_stats)); - if (split_candidate.split_stats.gain > - best_split_candidate.split_stats.gain) { - best_split_candidate = std::move(split_candidate); - } - } - - // Add best candidate for partition. - split_candidates->push_back(std::move(best_split_candidate)); - } -} - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h deleted file mode 100644 index 0f3858e4d8c406e9ec3ae7079b241e94ef4aa35c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_DENSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_DENSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler for a dense quantized feature column in the single class case. -class DenseQuantizedFeatureColumnHandler : public FeatureColumnHandler { - public: - DenseQuantizedFeatureColumnHandler( - const int32 class_id, const int32 slot_id, const int64 batch_size, - const int32 dense_feature_column, TTypes::ConstVec dense_quantiles, - TTypes::ConstVec dense_quantized_values) - : FeatureColumnHandler(class_id, slot_id, batch_size), - dense_feature_column_(dense_feature_column), - dense_quantiles_(dense_quantiles), - dense_quantized_values_(dense_quantized_values) {} - - void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const override; - - void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const override; - - protected: - const int32 dense_feature_column_; - TTypes::ConstVec dense_quantiles_; - TTypes::ConstVec dense_quantized_values_; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // 
THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_DENSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc deleted file mode 100644 index 1bc9d733ad3090f1cfc9547644061f54d7d2c8c6..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h" - -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { -namespace { - -using boosted_trees::learner::LearnerConfig; - -const auto kClassId = 1; -const auto kSlotId = 0; -const auto kBatchSize = 4; -const auto kFeatureColumn = 2; - -using FeatureStatsAccumulator = - FeatureStatsAccumulator; - -class DenseQuantizedFeatureColumnHandlerTest : public ::testing::Test { - protected: - // The data looks like the following: - // Example | Gradients | Partition | Dense Quantile | - // i0 | (0.2, 0.12) | 0 | 1 | - // i1 | (-0.5, 0.07) | 0 | 1 | - // i2 | (1.2, 0.2) | 0 | 0 | - // i3 | (4.0, 0.13) | 1 | 1 | - DenseQuantizedFeatureColumnHandlerTest() - : example_first_order_gradients_( - test::AsTensor({0.2f, -0.5f, 1.2f, 4.0f}, {4})), - example_second_order_gradients_( - test::AsTensor({0.12f, 0.07f, 0.2f, 0.13f}, {4})), - example_partitions_({0, 0, 0, 1}), - dense_quantiles_(test::AsTensor({0.3f, 0.52f}, {2})), - dense_quantized_values_(test::AsTensor({1, 1, 0, 1}, {4})) { - // Set L2 regularization. - learner_config_.mutable_regularization()->set_l2(2.0f); - learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); - // Create handler. 
- handler_.reset(new DenseQuantizedFeatureColumnHandler( - kClassId, kSlotId, kBatchSize, kFeatureColumn, - dense_quantiles_.vec(), dense_quantized_values_.vec())); - } - - LearnerConfig learner_config_; - const Tensor example_first_order_gradients_; - const Tensor example_second_order_gradients_; - const std::vector example_partitions_; - const Tensor dense_quantiles_; - const Tensor dense_quantized_values_; - std::unique_ptr handler_; -}; - -TEST_F(DenseQuantizedFeatureColumnHandlerTest, AggregateGradientStats) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Check stats for each partition and feature. - // Partition 0, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(1.2f, 0.2f), - accumulator.GetStats(kSlotId, kClassId, 0, 0)); - // Partition 0, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(-0.3f, 0.19f), - accumulator.GetStats(kSlotId, kClassId, 0, 1)); - // Partition 1, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 1, 0)); - // Partition 1, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(4.0f, 0.13f), - accumulator.GetStats(kSlotId, kClassId, 1, 1)); -} - -TEST_F(DenseQuantizedFeatureColumnHandlerTest, GenerateFeatureSplitCandidates) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Get feature split candidates for two roots 0 and 1. - // The root stats are derived from the per-partition total gradient stats. 
- const std::vector roots = {0, 1, 5}; - const std::vector& root_stats = { - NodeStats(learner_config_, GradientStats(0.9f, 0.39f)), - NodeStats(learner_config_, GradientStats(4.0f, 0.13f)), NodeStats(1)}; - std::vector split_candidates; - handler_->GenerateFeatureSplitCandidates(learner_config_, roots, root_stats, - accumulator, &split_candidates); - // Expect three candidate splits (one per root). - EXPECT_EQ(3, split_candidates.size()); - - // Verify candidate for root 0, the best split occurs when we route - // example i2 left and i0, i1 right. - const NodeStats expected_left_node0(learner_config_, - GradientStats(1.2f, 0.2f)); - const NodeStats expected_right_node0( - learner_config_, - root_stats[0].gradient_stats - expected_left_node0.gradient_stats); - const SplitStats expected_split_stats0(learner_config_, root_stats[0], - expected_left_node0, - expected_right_node0); - EXPECT_SPLIT_STATS_EQ(expected_split_stats0, split_candidates[0].split_stats); - const auto& tree_node0 = split_candidates[0].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kDenseFloatBinarySplit, - tree_node0.node_case()); - const auto& split0 = tree_node0.dense_float_binary_split(); - EXPECT_FLOAT_EQ(dense_quantiles_.vec()(0), split0.threshold()); - EXPECT_EQ(kFeatureColumn, split0.feature_column()); - - // Verify candidate for root 1, there's only one active bucket here - // so zero gain is expected. 
- const NodeStats expected_left_node1(learner_config_, - root_stats[1].gradient_stats); - const NodeStats expected_right_node1(learner_config_, GradientStats(0, 0)); - const SplitStats expected_split_stats1(learner_config_, root_stats[1], - expected_left_node1, - expected_right_node1); - EXPECT_SPLIT_STATS_EQ(expected_split_stats1, split_candidates[1].split_stats); - const auto& tree_node1 = split_candidates[1].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kDenseFloatBinarySplit, - tree_node1.node_case()); - const auto& split1 = tree_node1.dense_float_binary_split(); - EXPECT_FLOAT_EQ(dense_quantiles_.vec()(1), split1.threshold()); - EXPECT_EQ(kFeatureColumn, split1.feature_column()); - - // Verify there are no candidate splits for root 5. - const auto& tree_node2 = split_candidates[2].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::NODE_NOT_SET, - tree_node2.node_case()); -} - -} // namespace -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h deleted file mode 100644 index 8bd2092f9609cb684b89f70cab35a92789fb39a4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_FEATURE_COLUMN_HANDLER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_FEATURE_COLUMN_HANDLER_H_ - -#include -#include "tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h" -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/feature-split-candidate.h" -#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_types.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler interface for feature columns. Each feature column type may -// have its own handler which encapsulates the logic of aggregating gradient -// stats as well as generating split candidates for each partition. -// Handlers can be stateful and must be thread compatible. -class FeatureColumnHandler { - public: - FeatureColumnHandler(const int32 class_id, const int32 slot_id, - const int64 batch_size) - : class_id_(class_id), slot_id_(slot_id), batch_size_(batch_size) {} - - virtual ~FeatureColumnHandler() {} - FeatureColumnHandler(const FeatureColumnHandler& other) = delete; - FeatureColumnHandler& operator=(const FeatureColumnHandler& other) = delete; - - // Aggregates example gradient stats for the feature column. 
- virtual void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const = 0; - - // Generates feature column split candidates for the specified roots. - virtual void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const = 0; - - // Accessors. - int32 class_id() const { return class_id_; } - int32 slot_id() const { return slot_id_; } - int64 batch_size() const { return batch_size_; } - - protected: - // The class Id. - const int32 class_id_; - - // The slod Id for use as a unique Id across all feature columns. - const int32 slot_id_; - - // Size of the batch of examples. - const int64 batch_size_; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_FEATURE_COLUMN_HANDLER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc deleted file mode 100644 index a0e9efbbc5030e8c2e25fafab98271337a2e582a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -namespace { - -// Creates a sparse default right split node without assigning children. -boosted_trees::trees::TreeNode CreateSparseSplitNodeDefaultRight( - int32 feature_column, float threshold) { - boosted_trees::trees::TreeNode split_node; - auto* split = split_node.mutable_sparse_float_binary_split_default_right() - ->mutable_split(); - split->set_feature_column(feature_column); - split->set_threshold(threshold); - return split_node; -} - -// Creates a sparse default left split node without assigning children. -boosted_trees::trees::TreeNode CreateSparseSplitNodeDefaultLeft( - int32 feature_column, float threshold) { - boosted_trees::trees::TreeNode split_node; - auto* split = split_node.mutable_sparse_float_binary_split_default_left() - ->mutable_split(); - split->set_feature_column(feature_column); - split->set_threshold(threshold); - return split_node; -} - -} // namespace - -void SparseQuantizedFeatureColumnHandler::AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const { - // Pass over all rows and aggregate gradient stats for each partition - // and quantized feature bucket. 
- const int64 num_rows = sparse_indices_.dimension(0); - for (int64 row_idx = 0; row_idx < num_rows; ++row_idx) { - auto example_idx = sparse_indices_(row_idx, 0); - auto partition_id = example_partition_ids[example_idx]; - auto feature_id = sparse_quantized_values_(row_idx); - gradient_stats_accumulator->AddStats( - slot_id_, class_id_, partition_id, feature_id, - GradientStats(example_first_order_gradients, - example_second_order_gradients, example_idx)); - } -} - -void SparseQuantizedFeatureColumnHandler::GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const { - // Evaluate split candidates for every root as each is a separate - // logical partition over the examples. - // Then for each root, we do both a forward left to right pass and a backward - // right to left pass over the quantized feature buckets accumulating - // gradients on one side and using the root aggregate gradients to get the - // gradients for the other side. Split gains are evaluated for each pass at - // every threshold and the best split is picked. - split_candidates->clear(); - split_candidates->reserve(roots.size()); - for (size_t root_idx = 0; root_idx < roots.size(); ++root_idx) { - // Get partition Id and root node stats. - const int32 partition_id = roots[root_idx]; - const NodeStats& root_node_stats = root_stats[root_idx]; - - // Forward pass with right default direction. - GradientStats left_gradient_stats; - GradientStats right_gradient_stats(root_node_stats.gradient_stats); - FeatureSplitCandidate best_split_candidate( - root_node_stats.weight_contribution.size()); - best_split_candidate.split_stats.gain = - std::numeric_limits::lowest(); - for (int bucket_id = 0; bucket_id < sparse_quantiles_.size(); ++bucket_id) { - // Get gradient stats. 
- auto gradient_stats = gradient_stats_accumulator.GetStats( - slot_id_, class_id_, partition_id, bucket_id); - if (gradient_stats.IsZero()) { - continue; - } - - // Update gradient stats. - left_gradient_stats += gradient_stats; - right_gradient_stats = - root_node_stats.gradient_stats - left_gradient_stats; - - // Get node stats - NodeStats left_node_stats(learner_config, left_gradient_stats); - NodeStats right_node_stats(learner_config, right_gradient_stats); - - // Generate split candidate. - const float threshold = sparse_quantiles_(bucket_id); - FeatureSplitCandidate split_candidate( - slot_id_, - CreateSparseSplitNodeDefaultRight(sparse_feature_column_, threshold), - SplitStats(learner_config, root_node_stats, left_node_stats, - right_node_stats)); - if (split_candidate.split_stats.gain > - best_split_candidate.split_stats.gain) { - best_split_candidate = std::move(split_candidate); - } - } - - // Determine if we need a backward pass by checking if the residual gradient - // after forward aggregation is almost the same as the aggregated gradient. - // for the current root. This helps avoid unnecessary computation as well - // as consistency due to floating point precision. - if (!right_gradient_stats.IsAlmostZero()) { - // Backward pass with left default direction. - right_gradient_stats = GradientStats(); - left_gradient_stats = root_node_stats.gradient_stats; - for (int bucket_id = sparse_quantiles_.size() - 1; bucket_id > 0; - --bucket_id) { - // Get gradient stats. - auto gradient_stats = gradient_stats_accumulator.GetStats( - slot_id_, class_id_, partition_id, bucket_id); - if (gradient_stats.IsZero()) { - continue; - } - - // Update gradient stats. - right_gradient_stats += gradient_stats; - left_gradient_stats = root_node_stats.gradient_stats - gradient_stats; - - // Get node stats - NodeStats left_node_stats(learner_config, left_gradient_stats); - NodeStats right_node_stats(learner_config, right_gradient_stats); - - // Generate split candidate. 
- const float threshold = sparse_quantiles_(bucket_id - 1); - FeatureSplitCandidate split_candidate( - slot_id_, - CreateSparseSplitNodeDefaultLeft(sparse_feature_column_, threshold), - SplitStats(learner_config, root_node_stats, left_node_stats, - right_node_stats)); - if (split_candidate.split_stats.gain > - best_split_candidate.split_stats.gain) { - best_split_candidate = std::move(split_candidate); - } - } - } - - // Add best candidate for partition. - split_candidates->push_back(std::move(best_split_candidate)); - } -} - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h deleted file mode 100644 index eb63e705471a65e8448bda38b2e31eb971d5c1bb..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= - -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_SPARSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_SPARSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler for a sparse quantized feature column in the single class case. -class SparseQuantizedFeatureColumnHandler : public FeatureColumnHandler { - public: - SparseQuantizedFeatureColumnHandler( - const int32 class_id, const int32 slot_id, const int64 batch_size, - const int32 sparse_feature_column, - TTypes::ConstVec sparse_quantiles, - TTypes::ConstMatrix sparse_indices, - TTypes::ConstVec sparse_quantized_values) - : FeatureColumnHandler(class_id, slot_id, batch_size), - sparse_feature_column_(sparse_feature_column), - sparse_quantiles_(sparse_quantiles), - sparse_indices_(sparse_indices), - sparse_quantized_values_(sparse_quantized_values) {} - - void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const override; - - void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const override; - - protected: - const int32 sparse_feature_column_; - TTypes::ConstVec sparse_quantiles_; - TTypes::ConstMatrix sparse_indices_; - TTypes::ConstVec sparse_quantized_values_; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - 
-#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_SPARSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc deleted file mode 100644 index 643d936ad23850e601bc5518d69c8637011f53c0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// ============================================================================= - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h" - -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { -namespace { - -using boosted_trees::learner::LearnerConfig; - -const auto kClassId = 3; -const auto kSlotId = 0; -const auto kBatchSize = 4; -const auto kFeatureColumn = 4; - -using FeatureStatsAccumulator = - FeatureStatsAccumulator; - -class SparseQuantizedFeatureColumnHandlerTest : public ::testing::Test { - protected: - // The data looks like the following: - // Example | Gradients | Partition | Sparse Quantile | - // i0 | (0.2, 0.12) | 0 | 1 | - // i1 | (-0.5, 0.07) | 0 | N/A | - // i2 | (1.2, 0.2) | 0 | 0 | - // i3 | (4.0, 0.13) | 1 | 1 | - SparseQuantizedFeatureColumnHandlerTest() - : example_first_order_gradients_( - test::AsTensor({0.2f, -0.5f, 1.2f, 4.0f}, {4})), - example_second_order_gradients_( - test::AsTensor({0.12f, 0.07f, 0.2f, 0.13f}, {4})), - example_partitions_({0, 0, 0, 1}), - sparse_quantiles_(test::AsTensor({0.3f, 0.52f}, {2})), - sparse_indices_(test::AsTensor({0, 0, 2, 0, 3, 0}, {3, 2})), - sparse_quantized_values_(test::AsTensor({1, 0, 1}, {3})) { - // Set L2 regularization. - learner_config_.mutable_regularization()->set_l2(2.0f); - learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); - // Create handler. 
- handler_.reset(new SparseQuantizedFeatureColumnHandler( - kClassId, kSlotId, kBatchSize, kFeatureColumn, - sparse_quantiles_.vec(), sparse_indices_.matrix(), - sparse_quantized_values_.vec())); - } - - LearnerConfig learner_config_; - const Tensor example_first_order_gradients_; - const Tensor example_second_order_gradients_; - const std::vector example_partitions_; - const Tensor sparse_quantiles_; - const Tensor sparse_indices_; - const Tensor sparse_quantized_values_; - std::unique_ptr handler_; -}; - -TEST_F(SparseQuantizedFeatureColumnHandlerTest, AggregateGradientStats) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Check stats for each partition and feature. - // Partition 0, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(1.2f, 0.2f), - accumulator.GetStats(kSlotId, kClassId, 0, 0)); - // Partition 0, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.2f, 0.12f), - accumulator.GetStats(kSlotId, kClassId, 0, 1)); - // Partition 1, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 1, 0)); - // Partition 1, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(4.0f, 0.13f), - accumulator.GetStats(kSlotId, kClassId, 1, 1)); -} - -TEST_F(SparseQuantizedFeatureColumnHandlerTest, - GenerateFeatureSplitCandidates) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Get feature split candidates for two roots 0 and 1. - // The root stats are derived from the per-partition total gradient stats. 
- const std::vector roots = {0, 1, 9}; - const std::vector& root_stats = { - NodeStats(learner_config_, GradientStats(0.9f, 0.39f)), - NodeStats(learner_config_, GradientStats(4.0f, 0.13f)), NodeStats(1)}; - std::vector split_candidates; - handler_->GenerateFeatureSplitCandidates(learner_config_, roots, root_stats, - accumulator, &split_candidates); - // Expect three candidate splits (one per root). - EXPECT_EQ(3, split_candidates.size()); - - // Verify candidate for root 0, the best split occurs when we route - // example i0 and i2 to the left and i1 to the right (by default direction). - const NodeStats expected_left_node0(learner_config_, - GradientStats(0.2f + 1.2f, 0.12f + 0.2f)); - const NodeStats expected_right_node0( - learner_config_, - root_stats[0].gradient_stats - expected_left_node0.gradient_stats); - const SplitStats expected_split_stats0(learner_config_, root_stats[0], - expected_left_node0, - expected_right_node0); - EXPECT_SPLIT_STATS_EQ(expected_split_stats0, split_candidates[0].split_stats); - const auto& tree_node0 = split_candidates[0].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kSparseFloatBinarySplitDefaultRight, - tree_node0.node_case()); - const auto& split0 = - tree_node0.sparse_float_binary_split_default_right().split(); - EXPECT_FLOAT_EQ(sparse_quantiles_.vec()(1), split0.threshold()); - EXPECT_EQ(kFeatureColumn, split0.feature_column()); - - // Verify candidate for root 1, there's only one active bucket here - // so zero gain is expected. 
- const NodeStats expected_left_node1(learner_config_, - root_stats[1].gradient_stats); - const NodeStats expected_right_node1(learner_config_, GradientStats(0, 0)); - const SplitStats expected_split_stats1(learner_config_, root_stats[1], - expected_left_node1, - expected_right_node1); - EXPECT_SPLIT_STATS_EQ(expected_split_stats1, split_candidates[1].split_stats); - const auto& tree_node1 = split_candidates[1].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kSparseFloatBinarySplitDefaultRight, - tree_node1.node_case()); - const auto& split1 = - tree_node1.sparse_float_binary_split_default_right().split(); - EXPECT_FLOAT_EQ(sparse_quantiles_.vec()(1), split1.threshold()); - EXPECT_EQ(kFeatureColumn, split1.feature_column()); - - // Verify there are no candidate splits for root 9. - const auto& tree_node2 = split_candidates[2].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::NODE_NOT_SET, - tree_node2.node_case()); -} - -} // namespace -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h index 5e316538cefed30b2867252c9ebc4754216db329..70037d5bd8f446bdbbfcc468edb8a76c05e4fab7 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h @@ -33,9 +33,9 @@ template mutable_leaf(); // Split on first column - split_node->set_feature_id(0); + split_node->set_dimension_id(0); split_node->set_threshold(2.0f); // Both instances have this feature value. 
@@ -199,7 +199,7 @@ TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) { EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); // Split on second column - split_node->set_feature_id(1); + split_node->set_dimension_id(1); split_node->set_threshold(5.0f); // First instance does not have it (default right), second does have it. @@ -208,7 +208,7 @@ TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) { EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); // Split on third column - split_node->set_feature_id(2); + split_node->set_dimension_id(2); split_node->set_threshold(3.0f); example_it = example_iterable.begin(); diff --git a/tensorflow/contrib/boosted_trees/lib/utils/example.h b/tensorflow/contrib/boosted_trees/lib/utils/example.h index e388cf332c3ff327f79ea57e3a0bccbbaa1b5e45..54f60e1dee49a4a40b84fcc6e042fac1858aa187 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/example.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/example.h @@ -63,7 +63,7 @@ class SparseFloatFeatureColumn { public: void Reserve(const int32 size) { if (!single_dimensional_) { - mutlidimensional_values.Reserve(size); + multidimensional_values.Reserve(size); } } @@ -76,7 +76,7 @@ class SparseFloatFeatureColumn { DCHECK_EQ(0, feature_idx); single_value_ = value; } else { - mutlidimensional_values.Add(feature_idx, value); + multidimensional_values.Add(feature_idx, value); } initialized_ = true; } @@ -84,7 +84,7 @@ class SparseFloatFeatureColumn { void Clear() { single_dimensional_ = false; initialized_ = false; - mutlidimensional_values.Clear(); + multidimensional_values.Clear(); } OptionalValue operator[](int feature_idx) const { @@ -94,7 +94,7 @@ class SparseFloatFeatureColumn { if (single_dimensional_) { return OptionalValue(single_value_); } else { - return mutlidimensional_values[feature_idx]; + return multidimensional_values[feature_idx]; } } @@ -102,7 +102,7 @@ class SparseFloatFeatureColumn { bool single_dimensional_; bool initialized_; T 
single_value_; - SparseMultidimensionalValues mutlidimensional_values; + SparseMultidimensionalValues multidimensional_values; }; // Holds data for one example and enables lookup by feature column. diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc index bc0a93db8c39abf737d11682088233e2fd88e868..ccee9530b6897924453461c13b1238402c0f6cfa 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc @@ -96,6 +96,10 @@ class IndicesRowIterator return (row_idx_ != other.row_idx_); } + bool operator<(const IndicesRowIterator& other) const { + return (row_idx_ < other.row_idx_); + } + bool operator==(const IndicesRowIterator& other) const { QCHECK_EQ(iter_, other.iter_); return (row_idx_ == other.row_idx_); diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc index 82b8e8c1c272ca415b5841f5ba9433e00173f8fa..d66f645f62aba84261337eb37d6e3204930f8f15 100644 --- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc @@ -36,7 +36,7 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) { c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim, reduce_dim ? 
learner_config.num_classes() - 1 : learner_config.num_classes())}); - c->set_output(1, {c->Vector(InferenceContext::kUnknownDim)}); + c->set_output(1, {c->UnknownShape()}); return Status::OK(); } diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc index 4ca73ef6e3301aadda48d5c971c31b57b7925614..1fa70bafddb0c94f47d006d5694bea941edaddf9 100644 --- a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -268,6 +268,7 @@ REGISTER_OP("Quantiles") .Input("sparse_values: num_sparse_features * float") .Input("dense_buckets: num_dense_features * float") .Input("sparse_buckets: num_sparse_features * float") + .Input("sparse_indices: num_sparse_features * int64") .Output("dense_quantiles: num_dense_features * int32") .Output("sparse_quantiles: num_sparse_features * int32") .Doc(R"doc( @@ -280,10 +281,13 @@ dense_values: List of rank 1 tensors containing the dense values. sparse_values: List of rank 1 tensors containing the sparse feature values. dense_buckets: Quantile summary for each of the dense float tensor. sparse_buckets: Quantile summary for each of the sparse feature float tensor. -dense_quantiles: Rank 1 tensors representing associated quantiles for each of -dense float tensors. -sparse_quantiles: Rank 1 tensors representing associated quantiles for each of -the sparse feature tensors. +sparse_indices: List of rank 2 tensors with indices for sparse float +tensors. +dense_quantiles: Rank 2 tensors representing associated quantiles for each of +dense float tensors and the dimension. +sparse_quantiles: Rank 2 tensors representing associated quantiles for each of +the sparse feature tensors for each of sparse feature dimensions: +[quantile id, dimension id]. 
)doc"); REGISTER_OP("BucketizeWithInputBoundaries") diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index 07cfd413bbd389053ff52ca65693445ef28e8ede..0d27ddaf3a1d540efee268c2bcca217077ff5871 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -47,9 +47,7 @@ REGISTER_OP("BuildDenseInequalitySplits") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape)); ShapeHandle bucket_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bucket_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(bucket_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -71,7 +69,7 @@ Find the split that has the best gain for the accumulated stats. num_minibatches: A scalar, the number of times per example gradients & hessians were accumulated. The stats are divided by this to get per example stats. partition_ids: A rank 1 tensor of partition IDs. -bucket_ids: A rank 1 tensor of buckets IDs. +bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. 
@@ -108,9 +106,7 @@ REGISTER_OP("BuildSparseInequalitySplits") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape)); ShapeHandle bucket_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bucket_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(bucket_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -127,12 +123,13 @@ REGISTER_OP("BuildSparseInequalitySplits") return Status::OK(); }) .Doc(R"doc( -Find the split that has the best gain for the accumulated stats. +Find the split that has the best gain for the accumulated stats for a particular +feature column. num_minibatches: A scalar, the number of times per example gradients & hessians were accumulated. The stats are divided by this to get per example stats. -partition_ids: A rank 1 tensor of partition IDs. -bucket_ids: A rank 1 tensor of buckets IDs. +partition_ids: A rank 2 tensor of partition IDs for each dimension of feature column. +bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. 
@@ -168,9 +165,7 @@ REGISTER_OP("BuildCategoricalEqualitySplits") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape)); ShapeHandle bucket_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bucket_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(bucket_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -190,7 +185,7 @@ Find the split that has the best gain for the accumulated stats. num_minibatches: A scalar, the number of times per example gradients & hessians were accumulated. The stats are divided by this to get per example stats. partition_ids: A rank 1 tensor of partition IDs. -feature_ids: A rank 1 tensor of feature IDs. +feature_ids: A rank 2 tensor of feature IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. 
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits diff --git a/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc index f988755de021034fc0d33529286dd3b508d746ed..0354f7853cbedf22d0a299273b4dbd225b3121ab 100644 --- a/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc @@ -73,9 +73,7 @@ REGISTER_OP("StatsAccumulatorScalarAdd") 1, &partition_ids_shape)); ShapeHandle feature_ids_shape; TF_RETURN_IF_ERROR(c->WithRank( - c->input(num_resource_handles * 2 + i + 1), 1, &feature_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(feature_ids_shape, 0), &unused_dim)); + c->input(num_resource_handles * 2 + i + 1), 2, &feature_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRank( c->input(num_resource_handles * 3 + i + 1), 1, &gradients_shape)); @@ -96,11 +94,11 @@ stamp_token: Stamp token for Read/Write operations. Any operation with a mismatching token will be dropped. stats_accumulator_handles: A list of handles to the stats accumulator. partition_ids: A list of vectors of partition_ids. -feature_ids: A list of vectors of feature_ids. +feature_ids: Rank 2 tensor of feature id and feature dimension ids. gradients: A list of vectors of gradients for each slot in - . + . hessians: A list of vectors of hessians for each slot in - . + . 
)doc"); REGISTER_OP("StatsAccumulatorScalarFlush") @@ -119,7 +117,7 @@ REGISTER_OP("StatsAccumulatorScalarFlush") TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input)); c->set_output(0, c->Scalar()); c->set_output(1, c->Vector(c->UnknownDim())); - c->set_output(2, c->Vector(c->UnknownDim())); + c->set_output(2, c->UnknownShape()); c->set_output(3, c->Vector(c->UnknownDim())); c->set_output(4, c->Vector(c->UnknownDim())); return Status::OK(); @@ -134,7 +132,7 @@ next_stamp_token: Stamp token for the next iteration. num_updates: Number of times stats were added to this accumulator since last flush. output_partition_ids A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: Rank 2 tensor of feature id and feature dimension ids. output_gradients: A vector of gradients, with a value for each slot in . output_hessians: A vector of hessians, with a value for each slot @@ -161,9 +159,7 @@ REGISTER_OP("StatsAccumulatorScalarDeserialize") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &partition_ids_shape)); ShapeHandle feature_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &feature_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(feature_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &feature_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -183,9 +179,11 @@ stamp_token: Stamp token for Read/Write operations. num_updates: Number of times stats were added to this accumulator since last flush. partition_ids: A vector of partition_ids. -feature_ids: A vector of feature_ids. -gradients: A vector of gradients for each slot in . -hessians: A vector of hessians for each slot in . +feature_ids: Rank 2 tensor of feature id and feature dimension ids. 
+gradients: A vector of gradients for each slot in . +hessians: A vector of hessians for each slot in )doc"); REGISTER_OP("StatsAccumulatorScalarSerialize") @@ -204,7 +202,7 @@ REGISTER_OP("StatsAccumulatorScalarSerialize") // num_updates c->set_output(1, c->Scalar()); c->set_output(2, c->Vector(c->UnknownDim())); - c->set_output(3, c->Vector(c->UnknownDim())); + c->set_output(3, c->UnknownShape()); c->set_output(4, c->Vector(c->UnknownDim())); c->set_output(5, c->Vector(c->UnknownDim())); return Status::OK(); @@ -217,7 +215,7 @@ stamp_token: The current stamp token for the resource. num_updates: Number of times stats were added to this accumulator since last flush. output_partition_ids A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: Rank 2 tensor of feature id and feature dimension ids. output_gradients: A vector of gradients, with a value for each slot in . output_hessians: A vector of hessians, with a value for each slot @@ -293,9 +291,7 @@ REGISTER_OP("StatsAccumulatorTensorAdd") 1, &partition_ids_shape)); ShapeHandle feature_ids_shape; TF_RETURN_IF_ERROR(c->WithRank( - c->input(num_resource_handles * 2 + i + 1), 1, &feature_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(feature_ids_shape, 0), &unused_dim)); + c->input(num_resource_handles * 2 + i + 1), 2, &feature_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast( c->input(num_resource_handles * 3 + i + 1), 2, &gradients_shape)); @@ -316,11 +312,11 @@ stats_accumulator_handles: A list of handles to the stats accumulator. stamp_token: Stamp token for Read/Write operations. Any operation with a mismatching token will be dropped. partition_ids: A list of vectors of partition_ids. -feature_ids: A list of vectors of feature_ids. +feature_ids: Rank 2 tensor of feature id and feature dimension ids. gradients: A list of vectors of gradients for each slot in - . + . 
hessians: A list of vectors of hessians for each slot in - . + . )doc"); REGISTER_OP("StatsAccumulatorTensorFlush") @@ -340,7 +336,7 @@ REGISTER_OP("StatsAccumulatorTensorFlush") // num_updates c->set_output(0, c->Scalar()); c->set_output(1, c->Vector(c->UnknownDim())); - c->set_output(2, c->Vector(c->UnknownDim())); + c->set_output(2, c->UnknownShape()); c->set_output(3, c->UnknownShape()); c->set_output(4, c->UnknownShape()); return Status::OK(); @@ -355,11 +351,11 @@ next_stamp_token: Stamp token to be used for the next iteration. num_updates: Number of times stats were added to this accumulator since last flush. output_partition_ids: A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: Rank 2 tensor of feature id and feature dimension ids. output_gradients: A tensor of gradients, first dimension matches slots - in . + in . output_hessians: A tensor of hessians, first dimension matches slots - in . + in >. )doc"); REGISTER_OP("StatsAccumulatorTensorDeserialize") @@ -382,9 +378,7 @@ REGISTER_OP("StatsAccumulatorTensorDeserialize") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &partition_ids_shape)); ShapeHandle feature_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &feature_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(feature_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &feature_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(5), 2, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -405,9 +399,11 @@ stamp_token: Stamp token for Read/Write operations. num_updates: Number of times stats were added to this accumulator since last flush. partition_ids: A vector of partition_ids. -feature_ids: A vector of feature_ids. -gradients: A vector of gradients for each slot in . 
-hessians: A vector of hessians for each slot in . +feature_ids: Rank 2 tensor of feature id and feature dimension ids. +gradients: A vector of gradients for each slot in +hessians: A vector of hessians for each slot in . )doc"); REGISTER_OP("StatsAccumulatorTensorSerialize") @@ -426,7 +422,7 @@ REGISTER_OP("StatsAccumulatorTensorSerialize") // num_updates c->set_output(1, c->Scalar()); c->set_output(2, c->Vector(c->UnknownDim())); - c->set_output(3, c->Vector(c->UnknownDim())); + c->set_output(3, c->UnknownShape()); c->set_output(4, c->UnknownShape()); c->set_output(5, c->UnknownShape()); return Status::OK(); @@ -440,11 +436,11 @@ stamp_token: Stamp token for Read/Write operations. num_updates: Number of times stats were added to this accumulator since last flush. output_partition_ids: A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: Rank 2 tensor of feature id and feature dimension ids. output_gradients: A tensor of gradients, first dimension matches slots - in . + in . output_hessians: A tensor of hessians, first dimension matches slots - in . + in . )doc"); REGISTER_OP("StatsAccumulatorTensorMakeSummary") @@ -458,18 +454,20 @@ REGISTER_OP("StatsAccumulatorTensorMakeSummary") .Output("output_hessians: float") .Doc(R"doc( Summarizes the stats by summing the that are for the same -. +. partition_ids: A vector of partition_ids. -feature_ids: A vector of feature_ids. -gradients: A vector of gradients for each slot in . -hessians: A vector of hessians for each slot in . +feature_ids: Rank 2 tensor of feature id and feature dimension ids. +gradients: A vector of gradients for each slot in . +hessians: A vector of hessians for each slot in . output_partition_ids: A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: A rank2 tensor of feature_ids and dimensions for the slots. 
output_gradients: A tensor of gradients, first dimension matches slots - in . + in . output_hessians: A tensor of hessians, first dimension matches slots - in . + in . )doc"); } // namespace boosted_trees } // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto index f14abf45a517ad7c4c6d7bb1ab88b7a1d47d6fb6..fc570c1083d01a65760a456c109dad93afd9f62a 100644 --- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto +++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto @@ -53,9 +53,9 @@ message DenseFloatBinarySplit { // Float feature column and split threshold describing // the rule feature <= threshold. int32 feature_column = 1; - // If feature column is multivalent, this holds the index of the feature for - // the split. Defaults to 0. - int32 feature_id = 5; + // If feature column is multivalent, this holds the index of the dimensiong + // for the split. Defaults to 0. + int32 dimension_id = 5; float threshold = 2; // Node children indexing into a contiguous diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index 79802922ca1b59789069a0249cee163cdd3f607a..c1acf351603dd80c2d14c7ee0a5b4c89706bc1bf 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -75,7 +75,7 @@ def _append_multi_values_to_dense_leaf(leaf, w): leaf.vector.value.append(x) -def _set_float_split(split, feat_col, thresh, l_id, r_id): +def _set_float_split(split, feat_col, thresh, l_id, r_id, feature_dim_id=None): """Helper method for building tree float splits. Sets split feature column, threshold and children. @@ -86,11 +86,14 @@ def _set_float_split(split, feat_col, thresh, l_id, r_id): thresh: threshold to split on forming rule x <= thresh. 
l_id: left child Id. r_id: right child Id. + feature_dim_id: dimension of the feature column to be used in the split. """ split.feature_column = feat_col split.threshold = thresh split.left_id = l_id split.right_id = r_id + if feature_dim_id is not None: + split.dimension_id = feature_dim_id def _set_categorical_id_split(split, feat_col, feat_id, l_id, r_id): @@ -116,12 +119,12 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): def setUp(self): """Sets up the prediction tests. - Create a batch of two examples having one dense float, two sparse float and - one sparse int features. + Create a batch of two examples having one dense float, two sparse float + single valued, one sparse float multidimensionl and one sparse int features. The data looks like the following: - | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | - | 0 | 7 | -3 | | 9,1 | - | 1 | -2 | | 4 | | + | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | SparseM + | 0 | 7 | -3 | | 9,1 | __, 5.0 + | 1 | -2 | | 4 | | 3, ___ """ super(PredictionOpsTest, self).setUp() self._dense_float_tensor = np.array([[7.0], [-2.0]]) @@ -131,6 +134,11 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self._sparse_float_indices2 = np.array([[1, 0]]) self._sparse_float_values2 = np.array([4.0]) self._sparse_float_shape2 = np.array([2, 1]) + # Multi dimensional sparse float + self._sparse_float_indices_m = np.array([[0, 1], [1, 0]]) + self._sparse_float_values_m = np.array([5.0, 3.0]) + self._sparse_float_shape_m = np.array([2, 2]) + self._sparse_int_indices1 = np.array([[0, 0], [0, 1]]) self._sparse_int_values1 = np.array([9, 1]) self._sparse_int_shape1 = np.array([2, 2]) @@ -287,6 +295,94 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) + def testFullEnsembleWithMultidimensionalSparseSingleClass(self): + with self.test_session(): + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + # Bias tree. 
+ tree1 = tree_ensemble_config.trees.add() + tree_ensemble_config.tree_metadata.add().is_finalized = True + _append_to_leaf(tree1.nodes.add().leaf, 0, -0.4) + + # Depth 3 tree. + tree2 = tree_ensemble_config.trees.add() + tree_ensemble_config.tree_metadata.add().is_finalized = True + # Use feature column 2 (sparse multidimensional), split on first value + # node 0. + _set_float_split( + tree2.nodes.add().sparse_float_binary_split_default_right.split, + 2, + 7.0, + 1, + 2, + feature_dim_id=0) + # Leafs split on second dimension of sparse multidimensional feature. + # Node 1. + _set_float_split( + tree2.nodes.add().sparse_float_binary_split_default_left.split, + 2, + 4.5, + 3, + 4, + feature_dim_id=1) + # Node 2. + _set_float_split( + tree2.nodes.add().sparse_float_binary_split_default_right.split, + 2, + 9, + 5, + 6, + feature_dim_id=1) + + # Node 3. + _append_to_leaf(tree2.nodes.add().leaf, 0, 0.6) + # Node 4. + _append_to_leaf(tree2.nodes.add().leaf, 0, 1.3) + + # Node 5. + _append_to_leaf(tree2.nodes.add().leaf, 0, -0.1) + # Node 6. + _append_to_leaf(tree2.nodes.add().leaf, 0, 0.8) + + tree_ensemble_config.tree_weights.append(1.0) + tree_ensemble_config.tree_weights.append(1.0) + + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="full_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. 
+ learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + + result, dropout_info = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2, + self._sparse_float_indices_m + ], [ + self._sparse_float_values1, self._sparse_float_values2, + self._sparse_float_values_m + ], [ + self._sparse_float_shape1, self._sparse_float_shape2, + self._sparse_float_shape_m + ], [self._sparse_int_indices1], [self._sparse_int_values1], + [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=True) + + # The first example will get bias -0.4 from first tree and + # leaf 5 payload of -0.1 hence -0.5, the second example will + # get the same bias -0.4 and leaf 3 payload (0.6) hence 0.2 + self.assertAllClose([[-0.5], [0.2]], result.eval()) + + # Empty dropout. + self.assertAllEqual([[], []], dropout_info.eval()) + def testExcludeNonFinalTree(self): with self.test_session(): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() @@ -322,7 +418,6 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE - result, dropout_info = self._get_predictions( tree_ensemble_handle, learner_config=learner_config.SerializeToString(), @@ -370,7 +465,6 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER - result, dropout_info = self._get_predictions( tree_ensemble_handle, learner_config=learner_config.SerializeToString(), @@ -420,7 +514,6 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Prepare learner config. 
learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, dropout_info = self._get_predictions( tree_ensemble_handle, learner_config=learner_config.SerializeToString(), diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py index 1513c11c33d538dedabe10e4411bdd1373b16c7f..888d5c57ed33446c8b6f18d2d1e393647613d132 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py @@ -48,15 +48,16 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): def testBasicQuantileBuckets(self): """Sets up the quantile summary op test as follows. - Create a batch of 6 examples having a dense and sparse features. + Create a batch of 6 examples having a dense and sparse features. SparseM is + a sparse multi-dimensional (multivalent) feature. The data looks like this - | Instance | instance weights | Dense 0 | Sparse 0 - | 0 | 10 | 1 | - | 1 | 1 | 2 | 2 - | 2 | 1 | 3 | 3 - | 3 | 1 | 4 | 4 - | 4 | 1 | 4 | 5 - | 5 | 1 | 5 | 6 + | Instance | instance weights | Dense 0 | Sparse 0 | SparseM + | 0 | 10 | 1 | | | | + | 1 | 1 | 2 | 2 | 2 | | + | 2 | 1 | 3 | 3 | 3 | | + | 3 | 1 | 4 | 4 | | 4 | + | 4 | 1 | 4 | 5 | | 5 | + | 5 | 1 | 5 | 6 | | 6 | """ dense_float_tensor_0 = constant_op.constant( @@ -66,20 +67,29 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): sparse_values_0 = constant_op.constant( [2, 3, 4, 5, 6], dtype=dtypes.float32) sparse_shape_0 = constant_op.constant([6, 1], dtype=dtypes.int64) + # Multi-dimensional feature that should have the same quantiles as Sparse 0. 
+ sparse_indices_m = constant_op.constant( + [[1, 1], [2, 0], [3, 1], [4, 1], [5, 1]], dtype=dtypes.int64) + sparse_values_m = constant_op.constant( + [2, 3, 4, 5, 6], dtype=dtypes.float32) + sparse_shape_m = constant_op.constant([6, 2], dtype=dtypes.int64) + example_weights = constant_op.constant( [10, 1, 1, 1, 1, 1], dtype=dtypes.float32) with self.test_session(): config = self._gen_config(0.33, 3) dense_buckets, sparse_buckets = quantile_ops.quantile_buckets( - [dense_float_tensor_0], [sparse_indices_0], [sparse_values_0], - [sparse_shape_0], + [dense_float_tensor_0], [sparse_indices_0, sparse_indices_m], + [sparse_values_0, sparse_values_m], [sparse_shape_0, sparse_shape_m], example_weights=example_weights, dense_config=[config], - sparse_config=[config]) + sparse_config=[config, config]) self.assertAllEqual([1, 3, 5], dense_buckets[0].eval()) self.assertAllEqual([2, 4, 6.], sparse_buckets[0].eval()) + # Multidimensional sparse. + self.assertAllEqual([2, 4, 6.], sparse_buckets[1].eval()) def testStreamingQuantileBucketsWithVaryingBatch(self): """Sets up the quantile summary op test as follows. @@ -214,10 +224,10 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): resources.initialize_resources(resources.shared_resources()).run() sparse_indices_0 = constant_op.constant( - [[1, 0], [2, 0], [3, 0], [4, 0], [5, 0]], dtype=dtypes.int64) + [[1, 0], [2, 1], [3, 0], [4, 2], [5, 0]], dtype=dtypes.int64) sparse_values_0 = constant_op.constant( [2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtypes.float32) - sparse_shape_0 = constant_op.constant([6, 1], dtype=dtypes.int64) + sparse_shape_0 = constant_op.constant([6, 3], dtype=dtypes.int64) example_weights = constant_op.constant( [10, 1, 1, 1, 1, 1], dtype=dtypes.float32, shape=[6, 1]) update = accumulator.add_summary( @@ -349,19 +359,21 @@ class QuantilesOpTest(test_util.TensorFlowTestCase): def setUp(self): """Sets up the quantile op tests. - Create a batch of 4 examples having 2 dense and 3 sparse features. 
+ Create a batch of 4 examples having 2 dense and 4 sparse features. + Forth sparse feature is multivalent (3 dimensional) The data looks like this - | Instance | Dense 0 | Dense 1 | Sparse 0 | Sparse 1 | Sparse 2 - | 0 | -0.1 | -1 | -2 | 0.1 | - | 1 | 0.4 | -15 | 5.5 | | 2 - | 2 | 3.2 | 18 | 16 | 3 | - | 3 | 190 | 1000 | 17.5 | -3 | 4 + | Instance | Dense 0 | Dense 1 | Sparse 0 | Sparse 1 |Sparse 2| SparseM + | 0 | -0.1 | -1 | -2 | 0.1 | |_ ,1,_ + | 1 | 0.4 | -15 | 5.5 | | 2 |2 ,_,_ + | 2 | 3.2 | 18 | 16 | 3 | |__,_,_ + | 3 | 190 | 1000 | 17.5 | -3 | 4 |1 ,8,1 Quantiles are: Dense 0: (-inf,0.4], (0.4,5], (5, 190] Dense 1: (-inf, -9], (-9,15], (15, 1000) Sparse 0: (-inf, 5], (5,16], (16, 100] Sparse 1: (-inf, 2], (2, 5] Sparse 2: (-inf, 100] + SparseM: (-inf, 1], (1,2], (2,1000] """ super(QuantilesOpTest, self).setUp() self._dense_float_tensor_0 = constant_op.constant( @@ -369,18 +381,26 @@ class QuantilesOpTest(test_util.TensorFlowTestCase): self._dense_float_tensor_1 = constant_op.constant( [[-1], [-15], [18], [1000]], dtype=dtypes.float32) # Sparse feature 0 - self._sparse_indices_0 = constant_op.constant([[0, 0], [1, 0], [2, 0], - [3, 0]]) + self._sparse_indices_0 = constant_op.constant( + [[0, 0], [1, 0], [2, 0], [3, 0]], dtype=dtypes.int64) self._sparse_values_0 = constant_op.constant([-2, 5.5, 16, 17.5]) self._sparse_shape_0 = constant_op.constant([4, 1]) # Sprase feature 1 - self._sparse_indices_1 = constant_op.constant([[0, 0], [2, 0], [3, 0]]) + self._sparse_indices_1 = constant_op.constant( + [[0, 0], [2, 0], [3, 0]], dtype=dtypes.int64) self._sparse_values_1 = constant_op.constant([0.1, 3, -3]) self._sparse_shape_1 = constant_op.constant([4, 1]) # Sprase feature 2 - self._sparse_indices_2 = constant_op.constant([[1, 0], [3, 0]]) + self._sparse_indices_2 = constant_op.constant( + [[1, 0], [3, 0]], dtype=dtypes.int64) self._sparse_values_2 = constant_op.constant([2, 4], dtype=dtypes.float32) self._sparse_shape_2 = constant_op.constant([4, 1]) + # Sprase 
feature M + self._sparse_indices_m = constant_op.constant( + [[0, 1], [1, 0], [3, 0], [3, 1], [3, 2]], dtype=dtypes.int64) + self._sparse_values_m = constant_op.constant( + [1, 2, 1, 8, 1], dtype=dtypes.float32) + self._sparse_shape_m = constant_op.constant([4, 1]) # Quantiles self._dense_thresholds_0 = [0.4, 5, 190] self._dense_thresholds_1 = [-9, 15, 1000] @@ -388,52 +408,76 @@ class QuantilesOpTest(test_util.TensorFlowTestCase): self._sparse_thresholds_0 = [5, 16, 100] self._sparse_thresholds_1 = [2, 5] self._sparse_thresholds_2 = [100] + self._sparse_thresholds_m = [1, 2, 1000] def testDenseFeaturesOnly(self): with self.test_session(): dense_quantiles, _ = quantile_ops.quantiles( [self._dense_float_tensor_0, self._dense_float_tensor_1], [], - [self._dense_thresholds_0, self._dense_thresholds_1], []) + [self._dense_thresholds_0, self._dense_thresholds_1], [], []) # Dense feature 0 - self.assertAllEqual([0, 0, 1, 2], dense_quantiles[0].eval()) + self.assertAllEqual([[0, 0], [0, 0], [1, 0], [2, 0]], + dense_quantiles[0].eval()) # Dense feature 1 - self.assertAllEqual([1, 0, 2, 2], dense_quantiles[1].eval()) + self.assertAllEqual([[1, 0], [0, 0], [2, 0], [2, 0]], + dense_quantiles[1].eval()) def testSparseFeaturesOnly(self): with self.test_session(): - _, sparse_quantiles = quantile_ops.quantiles( - [], - [self._sparse_values_0, self._sparse_values_1, self._sparse_values_2], - [], [self._sparse_thresholds_0, self._sparse_thresholds_1, - self._sparse_thresholds_2]) - + _, sparse_quantiles = quantile_ops.quantiles([], [ + self._sparse_values_0, self._sparse_values_1, self._sparse_values_2, + self._sparse_values_m + ], [], [ + self._sparse_thresholds_0, self._sparse_thresholds_1, + self._sparse_thresholds_2, self._sparse_thresholds_m + ], [ + self._sparse_indices_0, self._sparse_indices_1, + self._sparse_indices_2, self._sparse_indices_m + ]) + + self.assertAllEqual(4, len(sparse_quantiles)) # Sparse feature 0 - self.assertAllEqual([0, 1, 1, 2], 
sparse_quantiles[0].eval()) + self.assertAllEqual([[0, 0], [1, 0], [1, 0], [2, 0]], + sparse_quantiles[0].eval()) # Sparse feature 1 - self.assertAllEqual([0, 1, 0], sparse_quantiles[1].eval()) + self.assertAllEqual([[0, 0], [1, 0], [0, 0]], sparse_quantiles[1].eval()) # Sparse feature 2 - self.assertAllEqual([0, 0], sparse_quantiles[2].eval()) + self.assertAllEqual([[0, 0], [0, 0]], sparse_quantiles[2].eval()) + # Multidimensional feature. + self.assertAllEqual([[0, 1], [1, 0], [0, 0], [2, 1], [0, 2]], + sparse_quantiles[3].eval()) def testDenseAndSparseFeatures(self): with self.test_session(): dense_quantiles, sparse_quantiles = quantile_ops.quantiles( - [self._dense_float_tensor_0, self._dense_float_tensor_1], - [self._sparse_values_0, self._sparse_values_1, self._sparse_values_2], - [self._dense_thresholds_0, self._dense_thresholds_1], - [self._sparse_thresholds_0, self._sparse_thresholds_1, - self._sparse_thresholds_2]) + [self._dense_float_tensor_0, self._dense_float_tensor_1], [ + self._sparse_values_0, self._sparse_values_1, + self._sparse_values_2, self._sparse_values_m + ], [self._dense_thresholds_0, self._dense_thresholds_1], [ + self._sparse_thresholds_0, self._sparse_thresholds_1, + self._sparse_thresholds_2, self._sparse_thresholds_m + ], [ + self._sparse_indices_0, self._sparse_indices_1, + self._sparse_indices_2, self._sparse_indices_m + ]) # Dense feature 0 - self.assertAllEqual([0, 0, 1, 2], dense_quantiles[0].eval()) + self.assertAllEqual([[0, 0], [0, 0], [1, 0], [2, 0]], + dense_quantiles[0].eval()) # Dense feature 1 - self.assertAllEqual([1, 0, 2, 2], dense_quantiles[1].eval()) + self.assertAllEqual([[1, 0], [0, 0], [2, 0], [2, 0]], + dense_quantiles[1].eval()) # Sparse feature 0 - self.assertAllEqual([0, 1, 1, 2], sparse_quantiles[0].eval()) + self.assertAllEqual([[0, 0], [1, 0], [1, 0], [2, 0]], + sparse_quantiles[0].eval()) # Sparse feature 1 - self.assertAllEqual([0, 1, 0], sparse_quantiles[1].eval()) + self.assertAllEqual([[0, 0], [1, 0], 
[0, 0]], sparse_quantiles[1].eval()) # Sparse feature 2 - self.assertAllEqual([0, 0], sparse_quantiles[2].eval()) + self.assertAllEqual([[0, 0], [0, 0]], sparse_quantiles[2].eval()) + # Multidimensional feature. + self.assertAllEqual([[0, 1], [1, 0], [0, 0], [2, 1], [0, 2]], + sparse_quantiles[3].eval()) def testBucketizeWithInputBoundaries(self): with self.test_session(): diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index edf088b5fa28d3e465d4e3d8ea7cf6745d48a91f..28834ef55bf8e1f32cc8f2380a4be3bf3824d8e1 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -38,7 +38,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): # (-0.3, 0.19) | 0 | 1 | # (4.0, 0.13) | 1 | 1 | partition_ids = array_ops.constant([0, 0, 1], dtype=dtypes.int32) - bucket_ids = array_ops.constant([0, 1, 1], dtype=dtypes.int64) + bucket_ids = array_ops.constant( + [[0, 0], [1, 0], [1, 0]], dtype=dtypes.int64) gradients = array_ops.constant([2.4, -0.6, 8.0]) hessians = array_ops.constant([0.4, 0.38, 0.26]) bucket_boundaries = [0.3, 0.52] @@ -109,7 +110,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): """Tests split handler op.""" with self.test_session() as sess: partition_ids = array_ops.constant([0, 0, 1], dtype=dtypes.int32) - bucket_ids = array_ops.constant([0, 1, 1], dtype=dtypes.int64) + bucket_ids = array_ops.constant( + [[0, 0], [1, 0], [1, 0]], dtype=dtypes.int64) gradients = array_ops.constant([[2.4, 3.0], [-0.6, 0.1], [8.0, 1.0]]) hessians = array_ops.constant([[[0.4, 1], [1, 1]], [[0.38, 1], [1, 1]], [[0.26, 1], [1, 1]]]) @@ -149,7 +151,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): """Tests empty inputs op.""" with self.test_session() as sess: partition_ids = array_ops.constant([], 
dtype=dtypes.int32) - bucket_ids = array_ops.constant([], dtype=dtypes.int64) + bucket_ids = array_ops.constant([[]], dtype=dtypes.int64) gradients = array_ops.constant([]) hessians = array_ops.constant([]) bucket_boundaries = [0.3, 0.52] @@ -185,7 +187,11 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): # (4.0, 0.13) | 1 | -1 | # (4.0, 0.13) | 1 | 1 | partition_ids = array_ops.constant([0, 0, 0, 1, 1], dtype=dtypes.int32) + # We have only 1 dimension in our sparse feature column. bucket_ids = array_ops.constant([-1, 0, 1, -1, 1], dtype=dtypes.int64) + dimension_ids = array_ops.constant([0, 0, 0, 0, 0], dtype=dtypes.int64) + bucket_ids = array_ops.stack([bucket_ids, dimension_ids], axis=1) + gradients = array_ops.constant([1.8, 2.4, 0.4, 8.0, 8.0]) hessians = array_ops.constant([0.78, 0.4, 0.24, 0.26, 0.26]) bucket_boundaries = array_ops.constant([0.3, 0.52]) @@ -207,6 +213,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) partitions, gains, splits = (sess.run([partitions, gains, splits])) self.assertAllEqual([0, 1], partitions) + self.assertEqual(2, len(splits)) # Check the split on partition 0. # -(0.2 + 1.2) / (0.12 + 0.2 + 2) expected_left_weight = -0.603448275862069 @@ -232,6 +239,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([expected_right_weight], right_child.value) self.assertEqual(0, split_node.split.feature_column) + # Sparse is one dimensional. + self.assertEqual(0, split_node.split.dimension_id) self.assertAllClose(0.52, split_node.split.threshold) @@ -253,14 +262,149 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([expected_right_weight], right_child.value) self.assertEqual(0, split_node.split.feature_column) + # Sparse is one dimensional. 
+ self.assertEqual(0, split_node.split.dimension_id) self.assertAllClose(0.52, split_node.split.threshold) + def testMakeSparseSplitAllEmptyDimensions(self): + """Tests split handler op when all dimensions have only bias bucket id.""" + with self.test_session() as sess: + # The data looks like the following after dividing by number of steps (2). + # Gradients | Partition | Dimension | bucket ID | + # (0.9, 0.39) | 0 | 0 | -1 | + # (4.0, 0.13) | 1 | 0 | -1 | + partition_ids = array_ops.constant([0, 1], dtype=dtypes.int32) + # We have only 1 dimension in our sparse feature column. + bucket_ids = array_ops.constant([[-1, 0], [-1, 0]], dtype=dtypes.int64) + gradients = array_ops.constant([1.8, 8.0]) + hessians = array_ops.constant([0.78, 0.26]) + bucket_boundaries = array_ops.constant([0.3, 0.52]) + partitions, gains, splits = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=2, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + bucket_boundaries=bucket_boundaries, + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + feature_column_group_id=0, + bias_feature_id=-1, + class_id=-1, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + partitions, gains, splits = (sess.run([partitions, gains, splits])) + self.assertEqual(0, len(partitions)) + self.assertEqual(0, len(splits)) + + def testMakeSparseMultidimensionalSplit(self): + """Tests split handler op.""" + with self.test_session() as sess: + # Num of steps is 2. + # The feature column is three dimensional. + # First dimension has bias bucket only, the second has bias bucket and + # two valid buckets, the third has just one bias bucket and one valid + # bucket. 
+ # Gradients | Partition | Dimension | bucket ID | + # (0.9, 0.39) | 0 | 0 | -1 | + # (1.2, 0.2) | 0 | 1 | 0 | + # (0.2, 0.12) | 0 | 1 | 2 | + # (0.1, 0.1) | 0 | 2 | 3 | + # Now second node - nothing interesting there, just one dimension. + # Second node has the same bucket ids for all dimensions. + # (4.0, 0.13) | 1 | 0 | -1 | + # (4.0, 0.13) | 1 | 2 | 3 | + + # Tree node ids. + partition_ids = array_ops.constant([0, 0, 0, 0, 1, 1], dtype=dtypes.int32) + + dimension_ids = array_ops.constant([0, 1, 1, 2, 0, 2], dtype=dtypes.int64) + bucket_ids = array_ops.constant([-1, 0, 2, 3, -1, 3], dtype=dtypes.int64) + bucket_ids = array_ops.stack([bucket_ids, dimension_ids], axis=1) + + gradients = array_ops.constant([1.8, 2.4, 0.4, 0.2, 8.0, 8.0]) + hessians = array_ops.constant([0.78, 0.4, 0.24, 0.2, 0.26, 0.26]) + bucket_boundaries = array_ops.constant([0.3, 0.52, 0.58, 0.6]) + partitions, gains, splits = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=2, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + bucket_boundaries=bucket_boundaries, + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + feature_column_group_id=0, + bias_feature_id=-1, + class_id=-1, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + partitions, gains, splits = (sess.run([partitions, gains, splits])) + self.assertAllEqual([0, 1], partitions) + self.assertEqual(2, len(splits)) + # Check the split on node 0 - it should split on second dimension + # -(0.2 + 1.2) / (0.12 + 0.2 + 2) + expected_left_weight = -0.603448275862069 + # (0.2 + 1.2) ** 2 / (0.12 + 0.2 + 2) + expected_left_gain = 0.8448275862068965 + # 0.5 / (0.07 + 2) + expected_right_weight = 0.24154589371980678 + # 0.5 ** 2 / (0.07 + 2) + expected_right_gain = 0.12077294685990339 + # (0.2 + 1.2 - 0.5) ** 2 / (0.12 + 0.2 + 0.07 + 2) + expected_bias_gain = 0.3389121338912133 + + split_info = 
split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.sparse_float_binary_split_default_right + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0]) + + self.assertAllClose([expected_left_weight], left_child.value) + + self.assertAllClose([expected_right_weight], right_child.value) + + self.assertEqual(0, split_node.split.feature_column) + # Split happened on second dimension. + self.assertEqual(1, split_node.split.dimension_id) + + self.assertAllClose(0.58, split_node.split.threshold) + + # Check the split on partition 1. + expected_left_weight = -1.8779342723004695 + expected_right_weight = 0 + + # Verify candidate for partition 1, there's only one active bucket here + # so zero gain is expected. + split_info.ParseFromString(splits[1]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.sparse_float_binary_split_default_left + + self.assertAllClose(0.0, gains[1]) + + self.assertAllClose([expected_left_weight], left_child.value) + + self.assertAllClose([expected_right_weight], right_child.value) + + self.assertEqual(0, split_node.split.feature_column) + self.assertEqual(2, split_node.split.dimension_id) + + self.assertAllClose(0.6, split_node.split.threshold) + def testMakeMulticlassSparseSplit(self): """Tests split handler op.""" with self.test_session() as sess: partition_ids = array_ops.constant([0, 0, 0, 1, 1], dtype=dtypes.int32) - bucket_ids = array_ops.constant([-1, 0, 1, -1, 1], dtype=dtypes.int64) + bucket_ids = array_ops.constant( + [[-1, 0], [0, 0], [1, 0], [-1, 0], [1, 0]], dtype=dtypes.int64) gradients = array_ops.constant([[1.8, 3.5], [2.4, 1.0], [0.4, 4.0], [8.0, 3.1], [8.0, 0.8]]) @@ -317,7 +461,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): gradients = [1.8, 0.4, 2.8, 8.0, 8.0] 
hessians = [0.78, 0.24, 0.64, 0.26, 0.26] partition_ids = [0, 0, 0, 1, 1] - feature_ids = array_ops.constant([-1, 1, 2, -1, 1], dtype=dtypes.int64) + feature_ids = array_ops.constant( + [[-1, 0], [1, 0], [2, 0], [-1, 0], [1, 0]], dtype=dtypes.int64) partitions, gains, splits = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=2, @@ -412,7 +557,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): hessians = array_ops.constant( [hessian_0, hessian_1, hessian_2, hessian_3, hessian_4]) partition_ids = [0, 0, 0, 1, 1] - feature_ids = array_ops.constant([-1, 1, 2, -1, 1], dtype=dtypes.int64) + feature_ids = array_ops.constant( + [[-1, 0], [1, 0], [2, 0], [-1, 0], [1, 0]], dtype=dtypes.int64) partitions, gains, splits = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=2, @@ -449,7 +595,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): gradients = [] hessians = [] partition_ids = [] - feature_ids = [] + feature_ids = [[]] partitions, gains, splits = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=0, diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py index 0022d4ad52b0699e6706ad04435f09d0d1cd57c3..978bf530cd99ec6af74a49cb96ff98023d7a15cb 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py @@ -38,22 +38,52 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], gradients=[0.1, 0.3], hessians=[0.2, 0.4]) - op2 = accumulator.add(0, [1], [2], [0.1], [0.2]) + op2 = accumulator.add(0, [1], [[2, 0]], [0.1], [0.2]) with ops.control_dependencies([op1, op2]): - num_updates, partition, feature, grads, hessians = 
accumulator.flush( + num_updates, partition, bucket_ids, grads, hessians = accumulator.flush( stamp_token=0, next_stamp_token=1) - num_updates, partition, feature, grads, hessians = sess.run( - [num_updates, partition, feature, grads, hessians]) + num_updates, partition, bucket_ids, grads, hessians = sess.run( + [num_updates, partition, bucket_ids, grads, hessians]) - result = _AccumulatorResultToDict(partition, feature, grads, hessians) + result = _AccumulatorResultToDict(partition, bucket_ids, grads, hessians) self.assertEqual(num_updates, 2) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)], [0.2, 0.4]) - self.assertAllClose(result[(2, 3)], [0.3, 0.4]) + # Key is partion, bucket, dimension + self.assertAllClose(result[(1, 2, 0)], [0.2, 0.4]) + self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4]) + + def testMultidimensionalAcculumator(self): + with self.test_session() as sess: + accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=tensor_shape.scalar(), + hessian_shape=tensor_shape.scalar()) + with ops.control_dependencies([accumulator._create_op]): + op1 = accumulator.add( + stamp_token=0, + partition_ids=[1, 2, 1], + feature_ids=[[2, 2], [3, 0], [2, 2]], + gradients=[0.1, 0.3, 0.8], + hessians=[0.2, 0.4, -9]) + op2 = accumulator.add(0, [2, 1], [[3, 1], [2, 2]], [0.1, 1], [0.2, -1]) + + with ops.control_dependencies([op1, op2]): + num_updates, partition, bucket_ids, grads, hessians = accumulator.flush( + stamp_token=0, next_stamp_token=1) + num_updates, partition, bucket_ids, grads, hessians = sess.run( + [num_updates, partition, bucket_ids, grads, hessians]) + + result = _AccumulatorResultToDict(partition, bucket_ids, grads, hessians) + self.assertEqual(num_updates, 2) + self.assertEqual(len(result), 3) + # Key is partion, bucket, dimension. 
+ self.assertAllClose(result[(1, 2, 2)], [1.9, -9.8]) + self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4]) + self.assertAllClose(result[(2, 3, 1)], [0.1, 0.2]) def testDropStaleUpdate(self): with self.test_session() as sess: @@ -65,13 +95,13 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], gradients=[0.1, 0.3], hessians=[0.2, 0.4]) op2 = accumulator.add( stamp_token=-1, partition_ids=[1], - feature_ids=[2], + feature_ids=[[2, 0]], gradients=[0.1], hessians=[0.2]) @@ -84,8 +114,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(num_updates, 1) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)], [0.1, 0.2]) - self.assertAllClose(result[(2, 3)], [0.3, 0.4]) + self.assertAllClose(result[(1, 2, 0)], [0.1, 0.2]) + self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4]) def testSerialize(self): with self.test_session() as sess: @@ -97,7 +127,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], gradients=[0.1, 0.3], hessians=[0.2, 0.4]) @@ -123,8 +153,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): self.assertEqual(num_updates, 1) self.assertEqual(num_updates_2, 1) self.assertEqual(len(result_1), 2) - self.assertAllClose(result_1[(1, 2)], [0.1, 0.2]) - self.assertAllClose(result_1[(2, 3)], [0.3, 0.4]) + self.assertAllClose(result_1[(1, 2, 0)], [0.1, 0.2]) + self.assertAllClose(result_1[(2, 3, 0)], [0.3, 0.4]) self.assertAllEqual(result_1, result_2) self.assertEqual(0, stamp_token) @@ -139,18 +169,19 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 1]], 
gradients=[0.1, 0.3], hessians=[0.2, 0.4]) with ops.control_dependencies([op1]): - deserialize = (accumulator.deserialize( - stamp_token=2, - num_updates=3, - partition_ids=[3, 4], - feature_ids=[5, 6], - gradients=[0.4, 0.5], - hessians=[0.6, 0.7])) + deserialize = ( + accumulator.deserialize( + stamp_token=2, + num_updates=3, + partition_ids=[3, 4], + feature_ids=[[5, 0], [6, 2]], + gradients=[0.4, 0.5], + hessians=[0.6, 0.7])) with ops.control_dependencies([deserialize]): num_updates, partition, feature, grads, hessians = accumulator.flush( stamp_token=2, next_stamp_token=3) @@ -161,8 +192,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): hessians) self.assertEqual(num_updates, 3) self.assertEqual(len(result), 2) - self.assertAllClose(result[(3, 5)], [0.4, 0.6]) - self.assertAllClose(result[(4, 6)], [0.5, 0.7]) + self.assertAllClose(result[(3, 5, 0)], [0.4, 0.6]) + self.assertAllClose(result[(4, 6, 2)], [0.5, 0.7]) def testMakeSummary(self): with self.test_session() as sess: @@ -172,15 +203,15 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): hessian_shape=tensor_shape.scalar()) partition, feature, grads, hessians = accumulator._make_summary( partition_ids=[1, 2, 1], - feature_ids=[2, 3, 2], + feature_ids=[[2, 0], [3, 1], [2, 0]], gradients=[0.1, 0.3, 0.1], hessians=[0.2, 0.4, 0.2]) partition, feature, grads, hessians = sess.run( [partition, feature, grads, hessians]) result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)], [0.2, 0.4]) - self.assertAllClose(result[(2, 3)], [0.3, 0.4]) + self.assertAllClose(result[(1, 2, 0)], [0.2, 0.4]) + self.assertAllClose(result[(2, 3, 1)], [0.3, 0.4]) class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): @@ -196,16 +227,54 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 
0]], + # Two values for gradients, + gradients=[[0.1, 0.1], [0.2, 0.2]], + # A 2x2 matrix for each hessian. + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) + op2 = accumulator.add( + stamp_token=0, + partition_ids=[1], + feature_ids=[[2, 0]], + gradients=[[0.10, 0.11]], + hessians=[[[0.011, 0.022], [0.033, 0.044]]]) + + with ops.control_dependencies([op1, op2]): + num_updates, partition, feature, grads, hessians = accumulator.flush( + stamp_token=0, next_stamp_token=1) + num_updates, partition, feature, grads, hessians = sess.run( + [num_updates, partition, feature, grads, hessians]) + + result = _AccumulatorResultToDict(partition, feature, grads, hessians) + self.assertEqual(num_updates, 2) + self.assertEqual(len(result), 2) + self.assertAllClose(result[(1, 2, 0)][0], [0.20, 0.21]) + self.assertAllClose(result[(1, 2, 0)][1], + [[0.021, 0.042], [0.063, 0.084]]) + self.assertAllClose(result[(2, 3, 0)][0], [0.2, 0.2]) + self.assertAllClose(result[(2, 3, 0)][1], [[0.05, 0.06], [0.07, 0.08]]) + + def testMultidimensionalAcculumator(self): + with self.test_session() as sess: + accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=tensor_shape.TensorShape([2]), + hessian_shape=tensor_shape.TensorShape([2, 2])) + with ops.control_dependencies([accumulator._create_op]): + op1 = accumulator.add( + stamp_token=0, + partition_ids=[1, 2], + feature_ids=[[2, 4], [3, 1]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2]], # A 2x2 matrix for each hessian. 
- hessians=[[[0.01, 0.02], [0.03, 0.04]], - [[0.05, 0.06], [0.07, 0.08]]]) + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) op2 = accumulator.add( stamp_token=0, partition_ids=[1], - feature_ids=[2], + feature_ids=[[2, 4]], gradients=[[0.10, 0.11]], hessians=[[[0.011, 0.022], [0.033, 0.044]]]) @@ -218,10 +287,11 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(num_updates, 2) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)][0], [0.20, 0.21]) - self.assertAllClose(result[(1, 2)][1], [[0.021, 0.042], [0.063, 0.084]]) - self.assertAllClose(result[(2, 3)][0], [0.2, 0.2]) - self.assertAllClose(result[(2, 3)][1], [[0.05, 0.06], [0.07, 0.08]]) + self.assertAllClose(result[(1, 2, 4)][0], [0.20, 0.21]) + self.assertAllClose(result[(1, 2, 4)][1], + [[0.021, 0.042], [0.063, 0.084]]) + self.assertAllClose(result[(2, 3, 1)][0], [0.2, 0.2]) + self.assertAllClose(result[(2, 3, 1)][1], [[0.05, 0.06], [0.07, 0.08]]) def testDropStaleUpdate(self): with self.test_session() as sess: @@ -233,16 +303,16 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 5], [3, 0]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2]], # A 2x2 matrix for each hessian. 
- hessians=[[[0.01, 0.02], [0.03, 0.04]], - [[0.05, 0.06], [0.07, 0.08]]]) + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) op2 = accumulator.add( stamp_token=-1, partition_ids=[1], - feature_ids=[2], + feature_ids=[[2, 5]], gradients=[[0.10, 0.11]], hessians=[[[0.011, 0.022], [0.033, 0.044]]]) @@ -255,10 +325,10 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(num_updates, 1) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)][0], [0.1, 0.1]) - self.assertAllClose(result[(1, 2)][1], [[0.01, 0.02], [0.03, 0.04]]) - self.assertAllClose(result[(2, 3)][0], [0.2, 0.2]) - self.assertAllClose(result[(2, 3)][1], [[0.05, 0.06], [0.07, 0.08]]) + self.assertAllClose(result[(1, 2, 5)][0], [0.1, 0.1]) + self.assertAllClose(result[(1, 2, 5)][1], [[0.01, 0.02], [0.03, 0.04]]) + self.assertAllClose(result[(2, 3, 0)][0], [0.2, 0.2]) + self.assertAllClose(result[(2, 3, 0)][1], [[0.05, 0.06], [0.07, 0.08]]) def testSerialize(self): with self.test_session() as sess: @@ -270,12 +340,12 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2]], # A 2x2 matrix for each hessian. 
- hessians=[[[0.01, 0.02], [0.03, 0.04]], - [[0.05, 0.06], [0.07, 0.08]]]) + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) with ops.control_dependencies([op1]): (stamp_token, num_updates_1, partition_1, feature_1, grads_1, @@ -300,15 +370,15 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): self.assertEqual(num_updates_1, 1) self.assertEqual(num_updates_2, 1) self.assertEqual(len(result_1), 2) - self.assertAllClose(result_1[(1, 2)][0], [0.1, 0.1]) - self.assertAllClose(result_1[(1, 2)][1], [[0.01, 0.02], [0.03, 0.04]]) - self.assertAllClose(result_1[(2, 3)][0], [0.2, 0.2]) - self.assertAllClose(result_1[(2, 3)][1], [[0.05, 0.06], [0.07, 0.08]]) + self.assertAllClose(result_1[(1, 2, 0)][0], [0.1, 0.1]) + self.assertAllClose(result_1[(1, 2, 0)][1], [[0.01, 0.02], [0.03, 0.04]]) + self.assertAllClose(result_1[(2, 3, 0)][0], [0.2, 0.2]) + self.assertAllClose(result_1[(2, 3, 0)][1], [[0.05, 0.06], [0.07, 0.08]]) - self.assertAllEqual(result_1[1, 2][0], result_2[1, 2][0]) - self.assertAllEqual(result_1[1, 2][1], result_2[1, 2][1]) - self.assertAllEqual(result_1[2, 3][0], result_2[2, 3][0]) - self.assertAllEqual(result_1[2, 3][1], result_2[2, 3][1]) + self.assertAllEqual(result_1[1, 2, 0][0], result_2[1, 2, 0][0]) + self.assertAllEqual(result_1[1, 2, 0][1], result_2[1, 2, 0][1]) + self.assertAllEqual(result_1[2, 3, 0][0], result_2[2, 3, 0][0]) + self.assertAllEqual(result_1[2, 3, 0][1], result_2[2, 3, 0][1]) def testDeserialize(self): with self.test_session() as sess: @@ -321,19 +391,19 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2]], # A 2x2 matrix for each hessian. 
- hessians=[[[0.01, 0.02], [0.03, 0.04]], - [[0.05, 0.06], [0.07, 0.08]]]) + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) with ops.control_dependencies([op1]): deserialize = accumulator.deserialize( stamp_token=2, num_updates=3, partition_ids=[3, 4], - feature_ids=[4, 5], + feature_ids=[[4, 0], [5, 0]], # Two values for gradients, gradients=[[0.3, 0.3], [0.5, 0.5]], # A 2x2 matrix for each hessian. @@ -349,10 +419,10 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): hessians) self.assertEqual(num_updates, 3) self.assertEqual(len(result), 2) - self.assertAllClose(result[(3, 4)][0], [0.3, 0.3]) - self.assertAllClose(result[(3, 4)][1], [[0.03, 0.04], [0.05, 0.06]]) - self.assertAllClose(result[(4, 5)][0], [0.5, 0.5]) - self.assertAllClose(result[(4, 5)][1], [[0.07, 0.08], [0.09, 0.10]]) + self.assertAllClose(result[(3, 4, 0)][0], [0.3, 0.3]) + self.assertAllClose(result[(3, 4, 0)][1], [[0.03, 0.04], [0.05, 0.06]]) + self.assertAllClose(result[(4, 5, 0)][0], [0.5, 0.5]) + self.assertAllClose(result[(4, 5, 0)][1], [[0.07, 0.08], [0.09, 0.10]]) def testMakeSummary(self): with self.test_session() as sess: @@ -362,7 +432,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): hessian_shape=tensor_shape.TensorShape([2, 2])) partition, feature, grads, hessians = accumulator._make_summary( partition_ids=[1, 2, 1], - feature_ids=[2, 3, 2], + feature_ids=[[2, 0], [3, 2], [2, 0]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2], [0.10, 0.11]], # A 2x2 matrix for each hessian. 
@@ -373,15 +443,16 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)][0], [0.20, 0.21]) - self.assertAllClose(result[(1, 2)][1], [[0.021, 0.042], [0.063, 0.084]]) - self.assertAllClose(result[(2, 3)][0], [0.2, 0.2]) - self.assertAllClose(result[(2, 3)][1], [[0.05, 0.06], [0.07, 0.08]]) + self.assertAllClose(result[(1, 2, 0)][0], [0.20, 0.21]) + self.assertAllClose(result[(1, 2, 0)][1], + [[0.021, 0.042], [0.063, 0.084]]) + self.assertAllClose(result[(2, 3, 2)][0], [0.2, 0.2]) + self.assertAllClose(result[(2, 3, 2)][1], [[0.05, 0.06], [0.07, 0.08]]) def _AccumulatorResultToDict(partition, feature, grads, hessians): """Converts the inputs to a dictionary since the ordering changes.""" - return {(partition[i], feature[i]): (grads[i], hessians[i]) + return {(partition[i], feature[i, 0], feature[i, 1]): (grads[i], hessians[i]) for i in range(len(partition))} diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index f0413fee5a8249d15f2cdae095dc7fa2c76a22b8..c2e65b643df90e88aadb0bb9acaf692da35b1a16 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -181,7 +181,6 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): tree_weights: 1.0 tree_metadata { num_layers_grown: 1 - is_finalized: true } growing_metadata { num_trees_attempted: 1 @@ -189,7 +188,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): } """ self.assertEqual(new_stamp, 1) - self.assertEqual(stats.num_trees, 1) + self.assertEqual(stats.num_trees, 0) self.assertEqual(stats.num_layers, 1) self.assertEqual(stats.active_tree, 1) self.assertEqual(stats.active_layer, 1) @@ -231,7 +230,6 @@ class 
CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): tree_weights: 1.0 tree_metadata { num_layers_grown: 1 - is_finalized: true } growing_metadata { num_trees_attempted: 1 @@ -239,7 +237,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): } """ self.assertEqual(new_stamp, 2) - self.assertEqual(stats.num_trees, 1) + self.assertEqual(stats.num_trees, 0) self.assertEqual(stats.num_layers, 1) self.assertEqual(stats.active_tree, 1) self.assertEqual(stats.active_layer, 1) diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 7e8e15e7d8c89d1adaa472b1da7e8bb3c73ca17e..294e04002adac62fc123a3242a05a1b36f422433 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -45,6 +45,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): init_stamp_token, epsilon, num_quantiles, + max_elements=None, name=None, container=None): """Creates a QuantileAccumulator object. @@ -53,6 +54,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): init_stamp_token: The initial value for the stamp token. epsilon: Error bound on the quantile computation. num_quantiles: Number of quantiles to produce from the final summary. + max_elements: Maximum number of elements added to the accumulator. name: the name to save the accumulator under. container: An optional `string`. 
Defaults to `""` """ @@ -67,6 +69,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): self._quantile_accumulator_handle, init_stamp_token, epsilon=epsilon, + max_elements=max_elements, num_quantiles=num_quantiles) is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized( self._quantile_accumulator_handle) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index cebe3474ca9251971c23bde9e82564189c1ee624..b95956dae2a62b28643cd31815c5f5650eca337b 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -322,9 +322,11 @@ class GradientBoostedDecisionTreeModel(object): self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() self._attempted_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False) + initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + name="attempted_trees") self._finalized_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False) + initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + name="finalized_trees") if not features: raise ValueError("Features dictionary must be specified.") (fc_names, dense_floats, sparse_float_indices, sparse_float_values, @@ -739,7 +741,7 @@ class GradientBoostedDecisionTreeModel(object): # Accumulate a step after updating stats. batch_size = math_ops.cast(array_ops.shape(labels)[0], dtypes.float32) with ops.control_dependencies(stats_update_ops): - add_step_op = steps_accumulator.add(ensemble_stamp, [0], [0], + add_step_op = steps_accumulator.add(ensemble_stamp, [0], [[0, 0]], [batch_size], [1.0]) # Determine learning rate. 
@@ -892,7 +894,9 @@ class GradientBoostedDecisionTreeModel(object): # Accumulate gradients and hessians. partition_ids = math_ops.range(self._logits_dimension) - feature_ids = array_ops.zeros_like(partition_ids, dtype=dtypes.int64) + feature_ids = array_ops.zeros( + [self._logits_dimension, 2], dtype=dtypes.int64) + add_stats_op = bias_stats_accumulator.add( ensemble_stamp, partition_ids, feature_ids, grads_sum, hess_sum) return control_flow_ops.group(*[add_stats_op], name="update_bias_stats") diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 16e24d97ddee0751e0b808b89080074c1b4baba7..dba51d4f527792d2a8dedc693f74c07119fd231d 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -912,8 +912,10 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertEqual(1, len(output.trees[0].nodes[2].leaf.sparse_vector.index)) self.assertEqual(3, output.trees[0].nodes[2].leaf.sparse_vector.index[0]) - self.assertAlmostEqual( - 0.893284678459, output.trees[0].nodes[2].leaf.sparse_vector.value[0]) + self.assertAllClose( + 0.893284678459, + output.trees[0].nodes[2].leaf.sparse_vector.value[0], + atol=1e-4, rtol=1e-4) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py index dde16426863b60e9df64da1ee6b36caec273bfd6..ccb8509c0347f9c9b6f1e8f4f620230aac9a6c2d 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - import numpy as np from tensorflow.contrib.boosted_trees.python.utils import losses @@ -60,35 
+58,27 @@ class LossesTest(test_util.TensorFlowTestCase): neg_loss = loss_for_negatives.eval() # For positive labels, points <= 0.3 get max loss of e. # For negative labels, these points have minimum loss of 1/e. - for i in range(2): - self.assertAlmostEqual(math.exp(1), pos_loss[i], places=4) - self.assertAlmostEqual(math.exp(-1), neg_loss[i], places=4) + self.assertAllClose(np.exp(np.ones([2, 1])), pos_loss[:2], atol=1e-4) + self.assertAllClose(np.exp(-np.ones([2, 1])), neg_loss[:2], atol=1e-4) # For positive lables, p oints with predictions 0.7 and larger get minimum # loss value of 1/e. For negative labels, these points are wrongly # classified and get loss e. - for i in range(6, 10): - self.assertAlmostEqual(math.exp(-1), pos_loss[i], places=4) - self.assertAlmostEqual(math.exp(1), neg_loss[i], places=4) + self.assertAllClose(np.exp(-np.ones([4, 1])), pos_loss[6:10], atol=1e-4) + self.assertAllClose(np.exp(np.ones([4, 1])), neg_loss[6:10], atol=1e-4) # Points in between 0.5-eps, 0..5+eps get loss exp(-label_m*y), where # y = 1/eps *x -1/(2eps), where x is the probability and label_m is either # 1 or -1 (for label of 0). 
- for i in range(2, 6): - self.assertAlmostEqual( - math.exp(-1.0 * (predictions_probs[i] * 1.0 / eps - 0.5 / eps)), - pos_loss[i], - places=4) - self.assertAlmostEqual( - math.exp(1.0 * (predictions_probs[i] * 1.0 / eps - 0.5 / eps)), - neg_loss[i], - places=4) + self.assertAllClose( + np.exp(-(predictions_probs[2:6] * 1.0 / eps - 0.5 / eps)), + pos_loss[2:6], atol=1e-4) + self.assertAllClose( + np.exp(predictions_probs[2:6] * 1.0 / eps - 0.5 / eps), + neg_loss[2:6], atol=1e-4) def test_per_example_squared_loss(self): - def _squared_loss(p, y): - return np.mean(1.0 * (p - y) * (p - y)) - labels = np.array([[0.123], [224.2], [-3], [2], [.3]], dtype=np.float32) weights = array_ops.ones([5, 1], dtypes.float32) predictions = np.array( @@ -99,9 +89,8 @@ class LossesTest(test_util.TensorFlowTestCase): predictions) loss = loss_tensor.eval() - for i in range(5): - self.assertAlmostEqual( - _squared_loss(labels[i], predictions[i]), loss[i], places=4) + self.assertAllClose( + np.square(labels[:5] - predictions[:5]), loss[:5], atol=1e-4) if __name__ == "__main__": diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index aa8f5ed12bc6f779e3c1a923b9225ec283189747..fe8bd072afd43a64fa62a65bd8900b5a98dbe761 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -60,9 +60,7 @@ tf_py_test( size = "small", srcs = ["python/ops/bigquery_reader_ops_test.py"], additional_deps = [ - ":bigquery_reader_ops_op_lib", ":cloud_py", - "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc index b31b882fa19a7eaad304d6d423961234f9affef4..e9b79a066def566096d6c3f3745974423e3371d1 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc +++ 
b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc @@ -421,7 +421,7 @@ TEST_F(BigQueryTableAccessorTest, MultiplePagesTest) { TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); EXPECT_EQ(3, row_id); EXPECT_TRUE(accessor_->Done()); - + Example expected_example; ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProtoWithNulls, &expected_example)); diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index f0144e9faa26801b6491b242b04fda8905f15306..c74da9cabd6816bc9c7891e32937534cff2d677d 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -80,13 +80,9 @@ class TPUClusterResolver(ClusterResolver): raise ImportError('googleapiclient must be installed before using the ' 'TPU cluster resolver') - # TODO(b/67375680): Remove custom URL once TPU APIs are finalized self._service = discovery.build( - 'tpu', - 'v1', - credentials=self._credentials, - discoveryServiceUrl='https://storage.googleapis.com' - '/tpu-api-definition/v1alpha1.json') + 'tpu', 'v1alpha1', + credentials=self._credentials) else: self._service = service diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 77a3fc0c8322117f50265e56952b68480583de02..89c1c86d68a9c7d9c8513850903b92b64afa6064 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -18,7 +18,6 @@ cmake_policy(SET CMP0022 NEW) # Options option(tensorflow_VERBOSE "Enable for verbose output" OFF) -option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF) option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON) option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF) @@ -34,6 +33,12 @@ option(tensorflow_BUILD_SHARED_LIB 
"Build TensorFlow as a shared library" OFF) option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON) option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions") option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON) + +# GPU, CUDA and cuDNN options +option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) +option(tensorflow_CUDA_VERSION "CUDA version to build against" 9.0) +option(tensorflow_CUDNN_VERSION "cuDNN version to build against" 7) + if(HAIKU) option(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE "Enable PIE support" OFF) else() @@ -53,7 +58,15 @@ if (NOT WIN32) set(tensorflow_CUDNN_INCLUDE /usr/include) endif (NOT tensorflow_CUDNN_INCLUDE) option(tensorflow_PATH_CUDNN_STATIC_LIB "Override PATH_STATIC_LIB for libcudnn_static.a" ${tensorflow_PATH_STATIC_LIB}) + if (NOT tensorflow_PATH_CUDNN_STATIC_LIB) + # option's default value is OFF. Fill it with real default values + set (tensorflow_PATH_CUDNN_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) + endif (NOT tensorflow_PATH_CUDNN_STATIC_LIB) option(tensorflow_PATH_NCCL_STATIC_LIB "Override PATH_STATIC_LIB for libnccl_static.a" ${tensorflow_PATH_STATIC_LIB}) + if (NOT tensorflow_PATH_NCCL_STATIC_LIB) + # option's default value is OFF. Fill it with real default values + set (tensorflow_PATH_NCCL_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) + endif (NOT tensorflow_PATH_NCCL_STATIC_LIB) option(tensorflow_CUDA_LIBRARY_PATH "Designate the default CUDA library paths" /usr/local/cuda/lib64) if (NOT tensorflow_CUDA_LIBRARY_PATH) # option's default value is OFF. Fill it with real default values @@ -262,7 +275,7 @@ if (tensorflow_ENABLE_GPU) list(APPEND CMAKE_LIBRARY_PATH "${tensorflow_CUDA_LIBRARY_PATH}/stubs") endif (NOT WIN32) - find_package(CUDA 8.0 REQUIRED) + find_package(CUDA ${tensorflow_CUDA_VERSION} REQUIRED) # by default we assume compute cabability 3.5 and 5.2. 
If you change this change it in # CUDA_NVCC_FLAGS and cuda_config.h below @@ -316,13 +329,16 @@ if (tensorflow_ENABLE_GPU) ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_STATIC_LIBRARY} ${culibos_STATIC_LIBRARY} ${nccl_STATIC_LIBRARY}) endif (WIN32) + # Remove "." from CUDA version variable. + string(REPLACE "." "" short_CUDA_VER ${tensorflow_CUDA_VERSION}) + # create cuda_config.h FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h "#ifndef CUDA_CUDA_CONFIG_H_\n" "#define CUDA_CUDA_CONFIG_H_\n" "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n" - "#define TF_CUDA_VERSION \"64_80\"\n" - "#define TF_CUDNN_VERSION \"64_6\"\n" + "#define TF_CUDA_VERSION \"64_${short_CUDA_VER}\"\n" + "#define TF_CUDNN_VERSION \"64_${tensorflow_CUDNN_VERSION}\"\n" "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" "#endif // CUDA_CUDA_CONFIG_H_\n" ) @@ -360,15 +376,15 @@ if (tensorflow_ENABLE_GPU) if(WIN32) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value msvcp_dll_name=msvcp140.dll - cudart_dll_name=cudart64_80.dll - cuda_version_number=8.0 + cudart_dll_name=cudart64_${short_CUDA_VER}.dll + cuda_version_number=${tensorflow_CUDA_VERSION} nvcuda_dll_name=nvcuda.dll - cudnn_dll_name=cudnn64_6.dll - cudnn_version_number=6) + cudnn_dll_name=cudnn64_${tensorflow_CUDNN_VERSION}.dll + cudnn_version_number=${tensorflow_CUDNN_VERSION}) else(WIN32) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value - cuda_version_number=8.0 - cudnn_version_number=6) + cuda_version_number=${tensorflow_CUDA_VERSION} + cudnn_version_number=${tensorflow_CUDNN_VERSION}) endif(WIN32) else(tensorflow_ENABLE_GPU) set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value @@ -383,11 +399,7 @@ endif() # Let's get to work! include(tf_core_framework.cmake) -# NOTE: Disabled until issue #3996 is fixed. 
-# include(tf_stream_executor.cmake) -if (tensorflow_ENABLE_GPU) - include(tf_stream_executor.cmake) -endif() +include(tf_stream_executor.cmake) include(tf_core_cpu.cmake) include(tf_core_ops.cmake) diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 4ddfec5960d2b759bacb376202cd8dab6ef2b024..4be733a2809f366a214fa2bb853bccffb10ecaba 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -19,23 +19,6 @@ for instructions on how to install a pre-built TensorFlow package on Windows. ### Current known limitations * It is not possible to load a custom Op library. * GCS file system is not supported. -* The following Ops are not currently implemented: - - Dequantize - - QuantizeAndDequantize - - QuantizedAvgPool - - QuantizedBatchNomWithGlobalNormalization - - QuantizedBiasAdd - - QuantizedConcat - - QuantizedConv2D - - QuantizedMatmul - - QuantizedMaxPoo - - QuantizeDownAndShrinkRange - - QuantizedRelu - - QuantizedRelu6 - - QuantizedReshape - - QuantizeV2 - - RequantizationRange - - Requantize ## Building with CMake diff --git a/tensorflow/contrib/cmake/external/gemmlowp.cmake b/tensorflow/contrib/cmake/external/gemmlowp.cmake index 3b146657bfc9bdd54db14839195af45972e67aff..a235442dc5c0a07e249653381436eeae81575883 100644 --- a/tensorflow/contrib/cmake/external/gemmlowp.cmake +++ b/tensorflow/contrib/cmake/external/gemmlowp.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(gemmlowp_URL https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip) -set(gemmlowp_HASH SHA256=dd2557072bde12141419cb8320a9c25e6ec41a8ae53c2ac78c076a347bb46d9d) +set(gemmlowp_URL https://github.com/google/gemmlowp/archive/6a2a90822e8546fc2bfa7044de0faf1c1cb4862f.zip) +set(gemmlowp_HASH SHA256=3447948d219f3270383766bbe08942888c0eb4e0ca6663c0e0548502ec5bb77d) set(gemmlowp_BUILD 
${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) set(gemmlowp_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 464aad74c6c8623981338695af01b026dcc0e6e3..41ea0b48a4600d7ca2dd2f4a61c14ec0cc5b4734 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 781fd6f6ea03645a520cd5c675da67ab61f87e4b) +set(GRPC_TAG 54e8f37e537794c2d814c1604c1282125f64f093) if(WIN32) set(grpc_STATIC_LIBRARIES @@ -28,10 +28,11 @@ else() set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/libcares.a) + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) endif() +add_definitions(-DGRPC_ARES=0) + ExternalProject_Add(grpc PREFIX grpc DEPENDS protobuf zlib @@ -39,9 +40,6 @@ ExternalProject_Add(grpc GIT_TAG ${GRPC_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 - # TODO(jhseu): Remove this PATCH_COMMAND once grpc removes the dependency - # on "grpc" from the "grpc++_unsecure" rule. - PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/grpc/CMakeLists.txt ${GRPC_BUILD} BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure COMMAND ${CMAKE_COMMAND} --build . 
--config Release --target grpc_cpp_plugin INSTALL_COMMAND "" diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index 2c42377f5078d55e72e37eb5e880624bc09ddef0..05080060479b6240edb8ab9f65160b3dd182feb9 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 394e71f0ebeed6788ae6c84d42c1bedf6e1ee9f7) +set(nsync_TAG 8502189abfa44c249c01c2cad64e6ed660a9a668) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/external/re2.cmake b/tensorflow/contrib/cmake/external/re2.cmake index b56f4b089813247f3ab1c751538ba4b05cacb5b6..d10f5959f71dd350e6e2bcb81be8882b203fb231 100644 --- a/tensorflow/contrib/cmake/external/re2.cmake +++ b/tensorflow/contrib/cmake/external/re2.cmake @@ -45,4 +45,5 @@ ExternalProject_Add(re2 endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX:STRING=${re2_INSTALL} + -DRE2_BUILD_TESTING:BOOL=OFF ) diff --git a/tensorflow/contrib/cmake/patches/grpc/CMakeLists.txt b/tensorflow/contrib/cmake/patches/grpc/CMakeLists.txt deleted file mode 100644 index 84722c5ca2a9f9253c7a76dd610dde615a176c07..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/cmake/patches/grpc/CMakeLists.txt +++ /dev/null @@ -1,14415 +0,0 @@ -# GRPC global cmake file -# This currently builds C and C++ code. -# This file has been automatically generated from a template file. -# Please look at the templates directory instead. -# This file can be regenerated from the template by running -# tools/buildgen/generate_projects.sh -# -# Copyright 2015 gRPC authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - - -cmake_minimum_required(VERSION 2.8) - -set(PACKAGE_NAME "grpc") -set(PACKAGE_VERSION "1.5.0-dev") -set(PACKAGE_STRING "${PACKAGE_NAME} ${PACKAGE_VERSION}") -set(PACKAGE_TARNAME "${PACKAGE_NAME}-${PACKAGE_VERSION}") -set(PACKAGE_BUGREPORT "https://github.com/grpc/grpc/issues/") -project(${PACKAGE_NAME} C CXX) - -set(gRPC_INSTALL_BINDIR "${CMAKE_INSTALL_PREFIX}/bin" CACHE PATH "Installation directory for executables") -set(gRPC_INSTALL_LIBDIR "${CMAKE_INSTALL_PREFIX}/lib" CACHE PATH "Installation directory for libraries") -set(gRPC_INSTALL_INCLUDEDIR "${CMAKE_INSTALL_PREFIX}/include" CACHE PATH "Installation directory for headers") -set(gRPC_INSTALL_CMAKEDIR "${CMAKE_INSTALL_PREFIX}/lib/cmake/${PACKAGE_NAME}" CACHE PATH "Installation directory for cmake config files") - -# Options -option(gRPC_BUILD_TESTS "Build tests" OFF) - -set(gRPC_INSTALL_default ON) -if (NOT CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) - # Disable gRPC_INSTALL by default if building as a submodule - set(gRPC_INSTALL_default OFF) -endif() -set(gRPC_INSTALL ${gRPC_INSTALL_default} CACHE BOOL - "Generate installation target: gRPC_ZLIB_PROVIDER, gRPC_CARES_PROVIDER, gRPC_SSL_PROVIDER and gRPC_PROTOBUF_PROVIDER must all be \"package\"") - -set(gRPC_ZLIB_PROVIDER "module" CACHE STRING "Provider of zlib library") -set_property(CACHE gRPC_ZLIB_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_CARES_PROVIDER "module" CACHE STRING 
"Provider of c-ares library") -set_property(CACHE gRPC_CARES_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_SSL_PROVIDER "module" CACHE STRING "Provider of ssl library") -set_property(CACHE gRPC_SSL_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_PROTOBUF_PROVIDER "module" CACHE STRING "Provider of protobuf library") -set_property(CACHE gRPC_PROTOBUF_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_PROTOBUF_PACKAGE_TYPE "" CACHE STRING "Algorithm for searching protobuf package") -set_property(CACHE gRPC_PROTOBUF_PACKAGE_TYPE PROPERTY STRINGS "CONFIG" "MODULE") - -set(gRPC_GFLAGS_PROVIDER "module" CACHE STRING "Provider of gflags library") -set_property(CACHE gRPC_GFLAGS_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_BENCHMARK_PROVIDER "module" CACHE STRING "Provider of benchmark library") -set_property(CACHE gRPC_BENCHMARK_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_USE_PROTO_LITE OFF CACHE BOOL "Use the protobuf-lite library") - -if(UNIX) - if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") - set(_gRPC_PLATFORM_LINUX ON) - elseif(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - set(_gRPC_PLATFORM_MAC ON) - else() - set(_gRPC_PLATFORM_POSIX ON) - endif() -endif() -if(WIN32) - set(_gRPC_PLATFORM_WINDOWS ON) -endif() - -set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) - -if (MSVC) - include(cmake/msvc_static_runtime.cmake) - add_definitions(-D_WIN32_WINNT=0x600 -D_SCL_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_WARNINGS -D_WINSOCK_DEPRECATED_NO_WARNINGS) - # needed to compile protobuf - add_definitions(/wd4065 /wd4506) - # TODO(jtattermusch): revisit C4267 occurrences throughout the code - add_definitions(/wd4267) -endif() - -if (gRPC_USE_PROTO_LITE) - set(_gRPC_PROTOBUF_LIBRARY_NAME "libprotobuf-lite") - add_definitions("-DGRPC_USE_PROTO_LITE") -else() - set(_gRPC_PROTOBUF_LIBRARY_NAME "libprotobuf") -endif() - -if("${gRPC_ZLIB_PROVIDER}" STREQUAL "module") - if(NOT ZLIB_ROOT_DIR) - set(ZLIB_ROOT_DIR 
${CMAKE_CURRENT_SOURCE_DIR}/third_party/zlib) - endif() - set(ZLIB_INCLUDE_DIR "${ZLIB_ROOT_DIR}") - if(EXISTS "${ZLIB_ROOT_DIR}/CMakeLists.txt") - # TODO(jtattermusch): workaround for https://github.com/madler/zlib/issues/218 - include_directories(${ZLIB_INCLUDE_DIR}) - - add_subdirectory(${ZLIB_ROOT_DIR} third_party/zlib) - if(TARGET zlibstatic) - set(_gRPC_ZLIB_LIBRARIES zlibstatic) - endif() - else() - message(WARNING "gRPC_ZLIB_PROVIDER is \"module\" but ZLIB_ROOT_DIR is wrong") - endif() - if(gRPC_INSTALL) - message(WARNING "gRPC_INSTALL will be forced to FALSE because gRPC_ZLIB_PROVIDER is \"module\"") - set(gRPC_INSTALL FALSE) - endif() -elseif("${gRPC_ZLIB_PROVIDER}" STREQUAL "package") - find_package(ZLIB) - if(TARGET ZLIB::ZLIB) - set(_gRPC_ZLIB_LIBRARIES ZLIB::ZLIB) - endif() - set(_gRPC_FIND_ZLIB "if(NOT ZLIB_FOUND)\n find_package(ZLIB)\nendif()") -endif() - -if("${gRPC_CARES_PROVIDER}" STREQUAL "module") - if(NOT CARES_ROOT_DIR) - set(CARES_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src/c-ares) - endif() - string(TOLOWER ${CMAKE_SYSTEM_NAME} CARES_SYSTEM_NAME) - set(CARES_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cares/cares") - set(CARES_BUILD_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cares") - set(CARES_PLATFORM_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cares/config_${CARES_SYSTEM_NAME}") - if(EXISTS "${CARES_ROOT_DIR}/CMakeLists.txt") - if("${CARES_SYSTEM_NAME}" MATCHES "windows") - add_definitions(-DCARES_STATICLIB=1) - add_definitions(-DWIN32_LEAN_AND_MEAN=1) - else() - add_definitions(-DHAVE_CONFIG_H=1) - add_definitions(-D_GNU_SOURCE=1) - endif() - add_subdirectory(src/c-ares third_party/cares) - if(TARGET cares) - set(_gRPC_CARES_LIBRARIES cares) - endif() - else() - message(WARNING "gRPC_CARES_PROVIDER is \"module\" but CARES_ROOT_DIR is wrong") - endif() - if(gRPC_INSTALL) - message(WARNING "gRPC_INSTALL will be forced to FALSE because gRPC_CARES_PROVIDER is \"module\"") - set(gRPC_INSTALL FALSE) - endif() 
-elseif("${gRPC_CARES_PROVIDER}" STREQUAL "package") - find_package(c-ares CONFIG) - if(TARGET c-ares::cares) - set(_gRPC_CARES_LIBRARIES c-ares::cares) - endif() - set(_gRPC_FIND_CARES "if(NOT c-ares_FOUND)\n find_package(c-ares CONFIG)\nendif()") -endif() - -if("${gRPC_PROTOBUF_PROVIDER}" STREQUAL "module") - # Building the protobuf tests require gmock what is not part of a standard protobuf checkout. - # Disable them unless they are explicitly requested from the cmake command line (when we assume - # gmock is downloaded to the right location inside protobuf). - if(NOT protobuf_BUILD_TESTS) - set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests") - endif() - # Disable building protobuf with zlib. Building protobuf with zlib breaks - # the build if zlib is not installed on the system. - if(NOT protobuf_WITH_ZLIB) - set(protobuf_WITH_ZLIB OFF CACHE BOOL "Build protobuf with zlib.") - endif() - if(NOT PROTOBUF_ROOT_DIR) - set(PROTOBUF_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/protobuf) - endif() - set(PROTOBUF_WELLKNOWN_IMPORT_DIR ${PROTOBUF_ROOT_DIR}/src) - if(EXISTS "${PROTOBUF_ROOT_DIR}/cmake/CMakeLists.txt") - set(protobuf_MSVC_STATIC_RUNTIME OFF CACHE BOOL "Link static runtime libraries") - add_subdirectory(${PROTOBUF_ROOT_DIR}/cmake third_party/protobuf) - if(TARGET ${_gRPC_PROTOBUF_LIBRARY_NAME}) - set(_gRPC_PROTOBUF_LIBRARIES ${_gRPC_PROTOBUF_LIBRARY_NAME}) - endif() - if(TARGET libprotoc) - set(_gRPC_PROTOBUF_PROTOC_LIBRARIES libprotoc) - endif() - if(TARGET protoc) - set(_gRPC_PROTOBUF_PROTOC protoc) - endif() - else() - message(WARNING "gRPC_PROTOBUF_PROVIDER is \"module\" but PROTOBUF_ROOT_DIR is wrong") - endif() - if(gRPC_INSTALL) - message(WARNING "gRPC_INSTALL will be forced to FALSE because gRPC_PROTOBUF_PROVIDER is \"module\"") - set(gRPC_INSTALL FALSE) - endif() -elseif("${gRPC_PROTOBUF_PROVIDER}" STREQUAL "package") - find_package(Protobuf ${gRPC_PROTOBUF_PACKAGE_TYPE}) - if(Protobuf_FOUND OR PROTOBUF_FOUND) - if(TARGET 
protobuf::${_gRPC_PROTOBUF_LIBRARY_NAME}) - set(_gRPC_PROTOBUF_LIBRARIES protobuf::${_gRPC_PROTOBUF_LIBRARY_NAME}) - else() - set(_gRPC_PROTOBUF_LIBRARIES ${PROTOBUF_LIBRARIES}) - endif() - if(TARGET protobuf::libprotoc) - set(_gRPC_PROTOBUF_PROTOC_LIBRARIES protobuf::libprotoc) - else() - set(_gRPC_PROTOBUF_PROTOC_LIBRARIES ${PROTOBUF_PROTOC_LIBRARIES}) - endif() - if(TARGET protobuf::protoc) - set(_gRPC_PROTOBUF_PROTOC protobuf::protoc) - else() - set(_gRPC_PROTOBUF_PROTOC ${PROTOBUF_PROTOC_EXECUTABLE}) - endif() - set(_gRPC_FIND_PROTOBUF "if(NOT Protobuf_FOUND AND NOT PROTOBUF_FOUND)\n find_package(Protobuf ${gRPC_PROTOBUF_PACKAGE_TYPE})\nendif()") - endif() - if(PROTOBUF_FOUND) - include_directories(${PROTOBUF_INCLUDE_DIRS}) - endif() - set(PROTOBUF_WELLKNOWN_IMPORT_DIR /usr/local/include) -endif() - -if("${gRPC_SSL_PROVIDER}" STREQUAL "module") - if(NOT BORINGSSL_ROOT_DIR) - set(BORINGSSL_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/boringssl) - endif() - if(EXISTS "${BORINGSSL_ROOT_DIR}/CMakeLists.txt") - set(OPENSSL_NO_ASM ON) # make boringssl buildable with Visual Studio - add_subdirectory(${BORINGSSL_ROOT_DIR} third_party/boringssl) - if(TARGET ssl) - set(_gRPC_SSL_LIBRARIES ssl) - endif() - else() - message(WARNING "gRPC_SSL_PROVIDER is \"module\" but BORINGSSL_ROOT_DIR is wrong") - endif() - if(gRPC_INSTALL) - message(WARNING "gRPC_INSTALL will be forced to FALSE because gRPC_SSL_PROVIDER is \"module\"") - set(gRPC_INSTALL FALSE) - endif() -elseif("${gRPC_SSL_PROVIDER}" STREQUAL "package") - find_package(OpenSSL) - if(TARGET OpenSSL::SSL) - set(_gRPC_SSL_LIBRARIES OpenSSL::SSL) - endif() - set(_gRPC_FIND_SSL "if(NOT OpenSSL_FOUND)\n find_package(OpenSSL)\nendif()") -endif() - -if("${gRPC_GFLAGS_PROVIDER}" STREQUAL "module") - if(NOT GFLAGS_ROOT_DIR) - set(GFLAGS_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/gflags) - endif() - if(EXISTS "${GFLAGS_ROOT_DIR}/CMakeLists.txt") - add_subdirectory(${GFLAGS_ROOT_DIR} third_party/gflags) - if(TARGET 
gflags_static) - set(_gRPC_GFLAGS_LIBRARIES gflags_static) - endif() - else() - message(WARNING "gRPC_GFLAGS_PROVIDER is \"module\" but GFLAGS_ROOT_DIR is wrong") - endif() -elseif("${gRPC_GFLAGS_PROVIDER}" STREQUAL "package") - find_package(gflags) - if(TARGET gflags::gflags) - set(_gRPC_GFLAGS_LIBRARIES gflags::gflags) - endif() - set(_gRPC_FIND_GFLAGS "if(NOT gflags_FOUND)\n find_package(gflags)\nendif()") -endif() - -if("${gRPC_BENCHMARK_PROVIDER}" STREQUAL "module") - if(NOT BENCHMARK_ROOT_DIR) - set(BENCHMARK_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/benchmark) - endif() - if(EXISTS "${BENCHMARK_ROOT_DIR}/CMakeLists.txt") - add_subdirectory(${BENCHMARK_ROOT_DIR} third_party/benchmark) - if(TARGET benchmark) - set(_gRPC_BENCHMARK_LIBRARIES benchmark) - endif() - else() - message(WARNING "gRPC_BENCHMARK_PROVIDER is \"module\" but BENCHMARK_ROOT_DIR is wrong") - endif() -elseif("${gRPC_BENCHMARK_PROVIDER}" STREQUAL "package") - find_package(benchmark) - if(TARGET benchmark::benchmark) - set(_gRPC_BENCHMARK_LIBRARIES benchmark::benchmark) - endif() - set(_gRPC_FIND_BENCHMARK "if(NOT benchmark_FOUND)\n find_package(benchmark)\nendif()") -endif() - -if(NOT MSVC) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") -endif() - -if(_gRPC_PLATFORM_MAC) - set(_gRPC_ALLTARGETS_LIBRARIES ${CMAKE_DL_LIBS} m pthread) -elseif(UNIX) - set(_gRPC_ALLTARGETS_LIBRARIES ${CMAKE_DL_LIBS} rt m pthread) -endif() - -if(WIN32 AND MSVC) - set(_gRPC_BASELIB_LIBRARIES wsock32 ws2_32) -endif() - -# Create directory for generated .proto files -set(_gRPC_PROTO_GENS_DIR ${CMAKE_BINARY_DIR}/gens) -file(MAKE_DIRECTORY ${_gRPC_PROTO_GENS_DIR}) - -# protobuf_generate_grpc_cpp -# -------------------------- -# -# Add custom commands to process ``.proto`` files to C++ using protoc and -# GRPC plugin:: -# -# protobuf_generate_grpc_cpp [...] 
-# -# ``ARGN`` -# ``.proto`` files -# -function(protobuf_generate_grpc_cpp) - if(NOT ARGN) - message(SEND_ERROR "Error: PROTOBUF_GENERATE_GRPC_CPP() called without any proto files") - return() - endif() - - set(_protobuf_include_path -I . -I ${PROTOBUF_WELLKNOWN_IMPORT_DIR}) - foreach(FIL ${ARGN}) - get_filename_component(ABS_FIL ${FIL} ABSOLUTE) - get_filename_component(FIL_WE ${FIL} NAME_WE) - file(RELATIVE_PATH REL_FIL ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL}) - get_filename_component(REL_DIR ${REL_FIL} DIRECTORY) - set(RELFIL_WE "${REL_DIR}/${FIL_WE}") - - add_custom_command( - OUTPUT "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.grpc.pb.cc" - "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.grpc.pb.h" - "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}_mock.grpc.pb.h" - "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.pb.cc" - "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.pb.h" - COMMAND $ - ARGS --grpc_out=generate_mock_code=true:${_gRPC_PROTO_GENS_DIR} - --cpp_out=${_gRPC_PROTO_GENS_DIR} - --plugin=protoc-gen-grpc=$ - ${_protobuf_include_path} - ${REL_FIL} - DEPENDS ${ABS_FIL} ${_gRPC_PROTOBUF_PROTOC} grpc_cpp_plugin - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - COMMENT "Running gRPC C++ protocol buffer compiler on ${FIL}" - VERBATIM) - - set_source_files_properties("${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.grpc.pb.cc" "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.grpc.pb.h" "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}_mock.grpc.pb.h" "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.pb.cc" "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.pb.h" PROPERTIES GENERATED TRUE) - endforeach() -endfunction() - -add_custom_target(plugins - DEPENDS - grpc_cpp_plugin - grpc_csharp_plugin - grpc_node_plugin - grpc_objective_c_plugin - grpc_php_plugin - grpc_python_plugin - grpc_ruby_plugin -) - -add_custom_target(tools_c - DEPENDS - check_epollexclusive - gen_hpack_tables - gen_legal_metadata_characters - gen_percent_encoding_tables - grpc_create_jwt - grpc_print_google_default_creds_token - grpc_verify_jwt -) - -add_custom_target(tools_cxx - DEPENDS -) - 
-add_custom_target(tools - DEPENDS tools_c tools_cxx) - -if (gRPC_BUILD_TESTS) -add_custom_target(buildtests_c) -add_dependencies(buildtests_c alarm_test) -add_dependencies(buildtests_c algorithm_test) -add_dependencies(buildtests_c alloc_test) -add_dependencies(buildtests_c alpn_test) -add_dependencies(buildtests_c arena_test) -add_dependencies(buildtests_c bad_server_response_test) -add_dependencies(buildtests_c bdp_estimator_test) -add_dependencies(buildtests_c bin_decoder_test) -add_dependencies(buildtests_c bin_encoder_test) -add_dependencies(buildtests_c census_context_test) -add_dependencies(buildtests_c census_intrusive_hash_map_test) -add_dependencies(buildtests_c census_resource_test) -add_dependencies(buildtests_c census_trace_context_test) -add_dependencies(buildtests_c channel_create_test) -add_dependencies(buildtests_c chttp2_hpack_encoder_test) -add_dependencies(buildtests_c chttp2_stream_map_test) -add_dependencies(buildtests_c chttp2_varint_test) -add_dependencies(buildtests_c combiner_test) -add_dependencies(buildtests_c compression_test) -add_dependencies(buildtests_c concurrent_connectivity_test) -add_dependencies(buildtests_c connection_refused_test) -add_dependencies(buildtests_c dns_resolver_connectivity_test) -add_dependencies(buildtests_c dns_resolver_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c dualstack_socket_test) -endif() -add_dependencies(buildtests_c endpoint_pair_test) -add_dependencies(buildtests_c error_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c ev_epollsig_linux_test) -endif() -add_dependencies(buildtests_c fake_resolver_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c fd_conservation_posix_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c fd_posix_test) -endif() -add_dependencies(buildtests_c fling_client) 
-add_dependencies(buildtests_c fling_server) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c fling_stream_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c fling_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c goaway_server_test) -endif() -add_dependencies(buildtests_c gpr_avl_test) -add_dependencies(buildtests_c gpr_backoff_test) -add_dependencies(buildtests_c gpr_cmdline_test) -add_dependencies(buildtests_c gpr_cpu_test) -add_dependencies(buildtests_c gpr_env_test) -add_dependencies(buildtests_c gpr_histogram_test) -add_dependencies(buildtests_c gpr_host_port_test) -add_dependencies(buildtests_c gpr_log_test) -add_dependencies(buildtests_c gpr_mpscq_test) -add_dependencies(buildtests_c gpr_spinlock_test) -add_dependencies(buildtests_c gpr_stack_lockfree_test) -add_dependencies(buildtests_c gpr_string_test) -add_dependencies(buildtests_c gpr_sync_test) -add_dependencies(buildtests_c gpr_thd_test) -add_dependencies(buildtests_c gpr_time_test) -add_dependencies(buildtests_c gpr_tls_test) -add_dependencies(buildtests_c gpr_useful_test) -add_dependencies(buildtests_c grpc_auth_context_test) -add_dependencies(buildtests_c grpc_b64_test) -add_dependencies(buildtests_c grpc_byte_buffer_reader_test) -add_dependencies(buildtests_c grpc_channel_args_test) -add_dependencies(buildtests_c grpc_channel_stack_test) -add_dependencies(buildtests_c grpc_completion_queue_test) -add_dependencies(buildtests_c grpc_completion_queue_threading_test) -add_dependencies(buildtests_c grpc_credentials_test) -add_dependencies(buildtests_c grpc_fetch_oauth2) -add_dependencies(buildtests_c grpc_invalid_channel_args_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c grpc_json_token_test) -endif() -add_dependencies(buildtests_c grpc_jwt_verifier_test) 
-add_dependencies(buildtests_c grpc_security_connector_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c handshake_client) -endif() -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c handshake_server) -endif() -add_dependencies(buildtests_c hpack_parser_test) -add_dependencies(buildtests_c hpack_table_test) -add_dependencies(buildtests_c http_parser_test) -add_dependencies(buildtests_c httpcli_format_request_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c httpcli_test) -endif() -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c httpscli_test) -endif() -add_dependencies(buildtests_c init_test) -add_dependencies(buildtests_c invalid_call_argument_test) -add_dependencies(buildtests_c json_rewrite) -add_dependencies(buildtests_c json_rewrite_test) -add_dependencies(buildtests_c json_stream_error_test) -add_dependencies(buildtests_c json_test) -add_dependencies(buildtests_c lame_client_test) -add_dependencies(buildtests_c lb_policies_test) -add_dependencies(buildtests_c load_file_test) -add_dependencies(buildtests_c memory_profile_client) -add_dependencies(buildtests_c memory_profile_server) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c memory_profile_test) -endif() -add_dependencies(buildtests_c message_compress_test) -add_dependencies(buildtests_c minimal_stack_is_minimal_test) -add_dependencies(buildtests_c mlog_test) -add_dependencies(buildtests_c multiple_server_queues_test) -add_dependencies(buildtests_c murmur_hash_test) -add_dependencies(buildtests_c no_server_test) -add_dependencies(buildtests_c num_external_connectivity_watchers_test) -add_dependencies(buildtests_c parse_address_test) -add_dependencies(buildtests_c percent_encoding_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c pollset_set_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) 
-add_dependencies(buildtests_c resolve_address_posix_test) -endif() -add_dependencies(buildtests_c resolve_address_test) -add_dependencies(buildtests_c resource_quota_test) -add_dependencies(buildtests_c secure_channel_create_test) -add_dependencies(buildtests_c secure_endpoint_test) -add_dependencies(buildtests_c sequential_connectivity_test) -add_dependencies(buildtests_c server_chttp2_test) -add_dependencies(buildtests_c server_test) -add_dependencies(buildtests_c slice_buffer_test) -add_dependencies(buildtests_c slice_hash_table_test) -add_dependencies(buildtests_c slice_string_helpers_test) -add_dependencies(buildtests_c slice_test) -add_dependencies(buildtests_c sockaddr_resolver_test) -add_dependencies(buildtests_c sockaddr_utils_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c socket_utils_test) -endif() -add_dependencies(buildtests_c status_conversion_test) -add_dependencies(buildtests_c stream_compression_test) -add_dependencies(buildtests_c stream_owned_slice_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c tcp_client_posix_test) -endif() -add_dependencies(buildtests_c tcp_client_uv_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c tcp_posix_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c tcp_server_posix_test) -endif() -add_dependencies(buildtests_c tcp_server_uv_test) -add_dependencies(buildtests_c time_averaged_stats_test) -add_dependencies(buildtests_c timeout_encoding_test) -add_dependencies(buildtests_c timer_heap_test) -add_dependencies(buildtests_c timer_list_test) -add_dependencies(buildtests_c transport_connectivity_state_test) -add_dependencies(buildtests_c transport_metadata_test) -add_dependencies(buildtests_c transport_pid_controller_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR 
_gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c transport_security_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c udp_server_test) -endif() -add_dependencies(buildtests_c uri_parser_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c wakeup_fd_cv_test) -endif() -add_dependencies(buildtests_c public_headers_must_be_c89) -add_dependencies(buildtests_c badreq_bad_client_test) -add_dependencies(buildtests_c connection_prefix_bad_client_test) -add_dependencies(buildtests_c head_of_line_blocking_bad_client_test) -add_dependencies(buildtests_c headers_bad_client_test) -add_dependencies(buildtests_c initial_settings_frame_bad_client_test) -add_dependencies(buildtests_c large_metadata_bad_client_test) -add_dependencies(buildtests_c server_registered_method_bad_client_test) -add_dependencies(buildtests_c simple_request_bad_client_test) -add_dependencies(buildtests_c unknown_frame_bad_client_test) -add_dependencies(buildtests_c window_overflow_bad_client_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c bad_ssl_cert_server) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c bad_ssl_cert_test) -endif() -add_dependencies(buildtests_c h2_census_test) -add_dependencies(buildtests_c h2_compress_test) -add_dependencies(buildtests_c h2_fakesec_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c h2_fd_test) -endif() -add_dependencies(buildtests_c h2_full_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c h2_full+pipe_test) -endif() -add_dependencies(buildtests_c h2_full+trace_test) -add_dependencies(buildtests_c h2_full+workarounds_test) -add_dependencies(buildtests_c h2_http_proxy_test) -add_dependencies(buildtests_c h2_load_reporting_test) 
-add_dependencies(buildtests_c h2_oauth2_test) -add_dependencies(buildtests_c h2_proxy_test) -add_dependencies(buildtests_c h2_sockpair_test) -add_dependencies(buildtests_c h2_sockpair+trace_test) -add_dependencies(buildtests_c h2_sockpair_1byte_test) -add_dependencies(buildtests_c h2_ssl_test) -add_dependencies(buildtests_c h2_ssl_cert_test) -add_dependencies(buildtests_c h2_ssl_proxy_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c h2_uds_test) -endif() -add_dependencies(buildtests_c inproc_test) -add_dependencies(buildtests_c h2_census_nosec_test) -add_dependencies(buildtests_c h2_compress_nosec_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c h2_fd_nosec_test) -endif() -add_dependencies(buildtests_c h2_full_nosec_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c h2_full+pipe_nosec_test) -endif() -add_dependencies(buildtests_c h2_full+trace_nosec_test) -add_dependencies(buildtests_c h2_full+workarounds_nosec_test) -add_dependencies(buildtests_c h2_http_proxy_nosec_test) -add_dependencies(buildtests_c h2_load_reporting_nosec_test) -add_dependencies(buildtests_c h2_proxy_nosec_test) -add_dependencies(buildtests_c h2_sockpair_nosec_test) -add_dependencies(buildtests_c h2_sockpair+trace_nosec_test) -add_dependencies(buildtests_c h2_sockpair_1byte_nosec_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c h2_uds_nosec_test) -endif() -add_dependencies(buildtests_c inproc_nosec_test) -add_dependencies(buildtests_c api_fuzzer_one_entry) -add_dependencies(buildtests_c client_fuzzer_one_entry) -add_dependencies(buildtests_c hpack_parser_fuzzer_test_one_entry) -add_dependencies(buildtests_c http_request_fuzzer_test_one_entry) -add_dependencies(buildtests_c http_response_fuzzer_test_one_entry) -add_dependencies(buildtests_c json_fuzzer_test_one_entry) -add_dependencies(buildtests_c 
nanopb_fuzzer_response_test_one_entry) -add_dependencies(buildtests_c nanopb_fuzzer_serverlist_test_one_entry) -add_dependencies(buildtests_c percent_decode_fuzzer_one_entry) -add_dependencies(buildtests_c percent_encode_fuzzer_one_entry) -add_dependencies(buildtests_c server_fuzzer_one_entry) -add_dependencies(buildtests_c ssl_server_fuzzer_one_entry) -add_dependencies(buildtests_c uri_fuzzer_test_one_entry) - -add_custom_target(buildtests_cxx) -add_dependencies(buildtests_cxx alarm_cpp_test) -add_dependencies(buildtests_cxx async_end2end_test) -add_dependencies(buildtests_cxx auth_property_iterator_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_arena) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_call_create) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_chttp2_hpack) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_chttp2_transport) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_closure) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_cq) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_cq_multiple_threads) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_error) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_fullstack_streaming_ping_pong) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_fullstack_streaming_pump) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR 
_gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_fullstack_trickle) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_fullstack_unary_ping_pong) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_metadata) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_pollset) -endif() -add_dependencies(buildtests_cxx channel_arguments_test) -add_dependencies(buildtests_cxx channel_filter_test) -add_dependencies(buildtests_cxx cli_call_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx client_crash_test) -endif() -add_dependencies(buildtests_cxx client_crash_test_server) -add_dependencies(buildtests_cxx client_lb_end2end_test) -add_dependencies(buildtests_cxx codegen_test_full) -add_dependencies(buildtests_cxx codegen_test_minimal) -add_dependencies(buildtests_cxx credentials_test) -add_dependencies(buildtests_cxx cxx_byte_buffer_test) -add_dependencies(buildtests_cxx cxx_slice_test) -add_dependencies(buildtests_cxx cxx_string_ref_test) -add_dependencies(buildtests_cxx cxx_time_test) -add_dependencies(buildtests_cxx end2end_test) -add_dependencies(buildtests_cxx error_details_test) -add_dependencies(buildtests_cxx filter_end2end_test) -add_dependencies(buildtests_cxx generic_end2end_test) -add_dependencies(buildtests_cxx golden_file_test) -add_dependencies(buildtests_cxx grpc_cli) -add_dependencies(buildtests_cxx grpc_tool_test) -add_dependencies(buildtests_cxx grpclb_api_test) -add_dependencies(buildtests_cxx grpclb_end2end_test) -add_dependencies(buildtests_cxx grpclb_test) -add_dependencies(buildtests_cxx health_service_end2end_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx http2_client) -endif() -add_dependencies(buildtests_cxx 
hybrid_end2end_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx interop_client) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx interop_server) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx interop_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx json_run_localhost) -endif() -add_dependencies(buildtests_cxx memory_test) -add_dependencies(buildtests_cxx metrics_client) -add_dependencies(buildtests_cxx mock_test) -add_dependencies(buildtests_cxx noop-benchmark) -add_dependencies(buildtests_cxx proto_server_reflection_test) -add_dependencies(buildtests_cxx proto_utils_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx qps_interarrival_test) -endif() -add_dependencies(buildtests_cxx qps_json_driver) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx qps_openloop_test) -endif() -add_dependencies(buildtests_cxx qps_worker) -add_dependencies(buildtests_cxx reconnect_interop_client) -add_dependencies(buildtests_cxx reconnect_interop_server) -add_dependencies(buildtests_cxx secure_auth_context_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx secure_sync_unary_ping_pong_test) -endif() -add_dependencies(buildtests_cxx server_builder_plugin_test) -add_dependencies(buildtests_cxx server_builder_test) -add_dependencies(buildtests_cxx server_context_test_spouse_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx server_crash_test) -endif() -add_dependencies(buildtests_cxx server_crash_test_client) -add_dependencies(buildtests_cxx server_request_call_test) -add_dependencies(buildtests_cxx 
shutdown_test) -add_dependencies(buildtests_cxx status_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx streaming_throughput_test) -endif() -add_dependencies(buildtests_cxx stress_test) -add_dependencies(buildtests_cxx thread_manager_test) -add_dependencies(buildtests_cxx thread_stress_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx writes_per_rpc_test) -endif() - -add_custom_target(buildtests - DEPENDS buildtests_c buildtests_cxx) -endif (gRPC_BUILD_TESTS) - - -add_library(gpr - src/core/lib/profiling/basic_timers.c - src/core/lib/profiling/stap_timers.c - src/core/lib/support/alloc.c - src/core/lib/support/arena.c - src/core/lib/support/atm.c - src/core/lib/support/avl.c - src/core/lib/support/backoff.c - src/core/lib/support/cmdline.c - src/core/lib/support/cpu_iphone.c - src/core/lib/support/cpu_linux.c - src/core/lib/support/cpu_posix.c - src/core/lib/support/cpu_windows.c - src/core/lib/support/env_linux.c - src/core/lib/support/env_posix.c - src/core/lib/support/env_windows.c - src/core/lib/support/histogram.c - src/core/lib/support/host_port.c - src/core/lib/support/log.c - src/core/lib/support/log_android.c - src/core/lib/support/log_linux.c - src/core/lib/support/log_posix.c - src/core/lib/support/log_windows.c - src/core/lib/support/mpscq.c - src/core/lib/support/murmur_hash.c - src/core/lib/support/stack_lockfree.c - src/core/lib/support/string.c - src/core/lib/support/string_posix.c - src/core/lib/support/string_util_windows.c - src/core/lib/support/string_windows.c - src/core/lib/support/subprocess_posix.c - src/core/lib/support/subprocess_windows.c - src/core/lib/support/sync.c - src/core/lib/support/sync_posix.c - src/core/lib/support/sync_windows.c - src/core/lib/support/thd.c - src/core/lib/support/thd_posix.c - src/core/lib/support/thd_windows.c - src/core/lib/support/time.c - src/core/lib/support/time_posix.c - 
src/core/lib/support/time_precise.c - src/core/lib/support/time_windows.c - src/core/lib/support/tls_pthread.c - src/core/lib/support/tmpfile_msys.c - src/core/lib/support/tmpfile_posix.c - src/core/lib/support/tmpfile_windows.c - src/core/lib/support/wrap_memcpy.c -) - -if(WIN32 AND MSVC) - set_target_properties(gpr PROPERTIES COMPILE_PDB_NAME "gpr" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/gpr.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(gpr - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr - ${_gRPC_ALLTARGETS_LIBRARIES} -) - -foreach(_hdr - include/grpc/support/alloc.h - include/grpc/support/atm.h - include/grpc/support/atm_gcc_atomic.h - include/grpc/support/atm_gcc_sync.h - include/grpc/support/atm_windows.h - include/grpc/support/avl.h - include/grpc/support/cmdline.h - include/grpc/support/cpu.h - include/grpc/support/histogram.h - include/grpc/support/host_port.h - include/grpc/support/log.h - include/grpc/support/log_windows.h - include/grpc/support/port_platform.h - include/grpc/support/string_util.h - include/grpc/support/subprocess.h - include/grpc/support/sync.h - include/grpc/support/sync_generic.h - include/grpc/support/sync_posix.h - include/grpc/support/sync_windows.h - include/grpc/support/thd.h - include/grpc/support/time.h - include/grpc/support/tls.h - include/grpc/support/tls_gcc.h - include/grpc/support/tls_msvc.h - include/grpc/support/tls_pthread.h - 
include/grpc/support/useful.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS gpr EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(gpr_test_util - test/core/util/test_config.c -) - -if(WIN32 AND MSVC) - set_target_properties(gpr_test_util PROPERTIES COMPILE_PDB_NAME "gpr_test_util" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/gpr_test_util.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(gpr_test_util - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_test_util - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr -) - - -endif (gRPC_BUILD_TESTS) - -add_library(grpc - src/core/lib/surface/init.c - 
src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - 
src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c 
- src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c - src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.c - src/core/ext/transport/chttp2/transport/bin_decoder.c - src/core/ext/transport/chttp2/transport/bin_encoder.c - src/core/ext/transport/chttp2/transport/chttp2_plugin.c - src/core/ext/transport/chttp2/transport/chttp2_transport.c - src/core/ext/transport/chttp2/transport/frame_data.c - src/core/ext/transport/chttp2/transport/frame_goaway.c - src/core/ext/transport/chttp2/transport/frame_ping.c - src/core/ext/transport/chttp2/transport/frame_rst_stream.c - src/core/ext/transport/chttp2/transport/frame_settings.c - src/core/ext/transport/chttp2/transport/frame_window_update.c - src/core/ext/transport/chttp2/transport/hpack_encoder.c - src/core/ext/transport/chttp2/transport/hpack_parser.c - src/core/ext/transport/chttp2/transport/hpack_table.c - src/core/ext/transport/chttp2/transport/http2_settings.c - src/core/ext/transport/chttp2/transport/huffsyms.c - src/core/ext/transport/chttp2/transport/incoming_metadata.c - src/core/ext/transport/chttp2/transport/parsing.c - 
src/core/ext/transport/chttp2/transport/stream_lists.c - src/core/ext/transport/chttp2/transport/stream_map.c - src/core/ext/transport/chttp2/transport/varint.c - src/core/ext/transport/chttp2/transport/writing.c - src/core/ext/transport/chttp2/alpn/alpn.c - src/core/ext/filters/http/client/http_client_filter.c - src/core/ext/filters/http/http_filters_plugin.c - src/core/ext/filters/http/message_compress/message_compress_filter.c - src/core/ext/filters/http/server/http_server_filter.c - src/core/lib/http/httpcli_security_connector.c - src/core/lib/security/context/security_context.c - src/core/lib/security/credentials/composite/composite_credentials.c - src/core/lib/security/credentials/credentials.c - src/core/lib/security/credentials/credentials_metadata.c - src/core/lib/security/credentials/fake/fake_credentials.c - src/core/lib/security/credentials/google_default/credentials_generic.c - src/core/lib/security/credentials/google_default/google_default_credentials.c - src/core/lib/security/credentials/iam/iam_credentials.c - src/core/lib/security/credentials/jwt/json_token.c - src/core/lib/security/credentials/jwt/jwt_credentials.c - src/core/lib/security/credentials/jwt/jwt_verifier.c - src/core/lib/security/credentials/oauth2/oauth2_credentials.c - src/core/lib/security/credentials/plugin/plugin_credentials.c - src/core/lib/security/credentials/ssl/ssl_credentials.c - src/core/lib/security/transport/client_auth_filter.c - src/core/lib/security/transport/lb_targets_info.c - src/core/lib/security/transport/secure_endpoint.c - src/core/lib/security/transport/security_connector.c - src/core/lib/security/transport/security_handshaker.c - src/core/lib/security/transport/server_auth_filter.c - src/core/lib/security/transport/tsi_error.c - src/core/lib/security/util/json_util.c - src/core/lib/surface/init_secure.c - src/core/tsi/fake_transport_security.c - src/core/tsi/gts_transport_security.c - src/core/tsi/ssl_transport_security.c - src/core/tsi/transport_security.c - 
src/core/tsi/transport_security_adapter.c - src/core/ext/transport/chttp2/server/chttp2_server.c - src/core/ext/transport/chttp2/client/secure/secure_channel_create.c - src/core/ext/filters/client_channel/channel_connectivity.c - src/core/ext/filters/client_channel/client_channel.c - src/core/ext/filters/client_channel/client_channel_factory.c - src/core/ext/filters/client_channel/client_channel_plugin.c - src/core/ext/filters/client_channel/connector.c - src/core/ext/filters/client_channel/http_connect_handshaker.c - src/core/ext/filters/client_channel/http_proxy.c - src/core/ext/filters/client_channel/lb_policy.c - src/core/ext/filters/client_channel/lb_policy_factory.c - src/core/ext/filters/client_channel/lb_policy_registry.c - src/core/ext/filters/client_channel/parse_address.c - src/core/ext/filters/client_channel/proxy_mapper.c - src/core/ext/filters/client_channel/proxy_mapper_registry.c - src/core/ext/filters/client_channel/resolver.c - src/core/ext/filters/client_channel/resolver_factory.c - src/core/ext/filters/client_channel/resolver_registry.c - src/core/ext/filters/client_channel/retry_throttle.c - src/core/ext/filters/client_channel/subchannel.c - src/core/ext/filters/client_channel/subchannel_index.c - src/core/ext/filters/client_channel/uri_parser.c - src/core/ext/filters/deadline/deadline_filter.c - src/core/ext/transport/chttp2/client/chttp2_connector.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c - src/core/ext/transport/chttp2/client/insecure/channel_create.c - src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c - src/core/ext/transport/inproc/inproc_plugin.c - src/core/ext/transport/inproc/inproc_transport.c - src/core/ext/filters/client_channel/lb_policy/grpclb/client_load_reporting_filter.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.c - 
src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.c - src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.c - src/core/ext/filters/client_channel/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.c - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/core/ext/filters/client_channel/resolver/fake/fake_resolver.c - src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.c - src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_posix.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_fallback.c - src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.c - src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.c - src/core/ext/filters/load_reporting/load_reporting.c - src/core/ext/filters/load_reporting/load_reporting_filter.c - src/core/ext/census/base_resources.c - src/core/ext/census/context.c - src/core/ext/census/gen/census.pb.c - src/core/ext/census/gen/trace_context.pb.c - src/core/ext/census/grpc_context.c - src/core/ext/census/grpc_filter.c - src/core/ext/census/grpc_plugin.c - src/core/ext/census/initialize.c - src/core/ext/census/intrusive_hash_map.c - src/core/ext/census/mlog.c - src/core/ext/census/operation.c - src/core/ext/census/placeholders.c - src/core/ext/census/resource.c - src/core/ext/census/trace_context.c - src/core/ext/census/tracing.c - src/core/ext/filters/max_age/max_age_filter.c - src/core/ext/filters/message_size/message_size_filter.c - src/core/ext/filters/workarounds/workaround_cronet_compression_filter.c - src/core/ext/filters/workarounds/workaround_utils.c - src/core/plugin_registry/grpc_plugin_registry.c -) - -if(WIN32 AND 
MSVC) - set_target_properties(grpc PROPERTIES COMPILE_PDB_NAME "grpc" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ZLIB_LIBRARIES} - ${_gRPC_CARES_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr -) - -foreach(_hdr - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - 
include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc/grpc_security.h - include/grpc/census.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_library(grpc_cronet - src/core/lib/surface/init.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - 
src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - 
src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c - src/core/ext/transport/cronet/client/secure/cronet_channel_create.c - src/core/ext/transport/cronet/transport/cronet_api_dummy.c - src/core/ext/transport/cronet/transport/cronet_transport.c - src/core/ext/transport/chttp2/client/secure/secure_channel_create.c - src/core/ext/transport/chttp2/transport/bin_decoder.c - src/core/ext/transport/chttp2/transport/bin_encoder.c - 
src/core/ext/transport/chttp2/transport/chttp2_plugin.c - src/core/ext/transport/chttp2/transport/chttp2_transport.c - src/core/ext/transport/chttp2/transport/frame_data.c - src/core/ext/transport/chttp2/transport/frame_goaway.c - src/core/ext/transport/chttp2/transport/frame_ping.c - src/core/ext/transport/chttp2/transport/frame_rst_stream.c - src/core/ext/transport/chttp2/transport/frame_settings.c - src/core/ext/transport/chttp2/transport/frame_window_update.c - src/core/ext/transport/chttp2/transport/hpack_encoder.c - src/core/ext/transport/chttp2/transport/hpack_parser.c - src/core/ext/transport/chttp2/transport/hpack_table.c - src/core/ext/transport/chttp2/transport/http2_settings.c - src/core/ext/transport/chttp2/transport/huffsyms.c - src/core/ext/transport/chttp2/transport/incoming_metadata.c - src/core/ext/transport/chttp2/transport/parsing.c - src/core/ext/transport/chttp2/transport/stream_lists.c - src/core/ext/transport/chttp2/transport/stream_map.c - src/core/ext/transport/chttp2/transport/varint.c - src/core/ext/transport/chttp2/transport/writing.c - src/core/ext/transport/chttp2/alpn/alpn.c - src/core/ext/filters/http/client/http_client_filter.c - src/core/ext/filters/http/http_filters_plugin.c - src/core/ext/filters/http/message_compress/message_compress_filter.c - src/core/ext/filters/http/server/http_server_filter.c - src/core/ext/filters/client_channel/channel_connectivity.c - src/core/ext/filters/client_channel/client_channel.c - src/core/ext/filters/client_channel/client_channel_factory.c - src/core/ext/filters/client_channel/client_channel_plugin.c - src/core/ext/filters/client_channel/connector.c - src/core/ext/filters/client_channel/http_connect_handshaker.c - src/core/ext/filters/client_channel/http_proxy.c - src/core/ext/filters/client_channel/lb_policy.c - src/core/ext/filters/client_channel/lb_policy_factory.c - src/core/ext/filters/client_channel/lb_policy_registry.c - src/core/ext/filters/client_channel/parse_address.c - 
src/core/ext/filters/client_channel/proxy_mapper.c - src/core/ext/filters/client_channel/proxy_mapper_registry.c - src/core/ext/filters/client_channel/resolver.c - src/core/ext/filters/client_channel/resolver_factory.c - src/core/ext/filters/client_channel/resolver_registry.c - src/core/ext/filters/client_channel/retry_throttle.c - src/core/ext/filters/client_channel/subchannel.c - src/core/ext/filters/client_channel/subchannel_index.c - src/core/ext/filters/client_channel/uri_parser.c - src/core/ext/filters/deadline/deadline_filter.c - src/core/lib/http/httpcli_security_connector.c - src/core/lib/security/context/security_context.c - src/core/lib/security/credentials/composite/composite_credentials.c - src/core/lib/security/credentials/credentials.c - src/core/lib/security/credentials/credentials_metadata.c - src/core/lib/security/credentials/fake/fake_credentials.c - src/core/lib/security/credentials/google_default/credentials_generic.c - src/core/lib/security/credentials/google_default/google_default_credentials.c - src/core/lib/security/credentials/iam/iam_credentials.c - src/core/lib/security/credentials/jwt/json_token.c - src/core/lib/security/credentials/jwt/jwt_credentials.c - src/core/lib/security/credentials/jwt/jwt_verifier.c - src/core/lib/security/credentials/oauth2/oauth2_credentials.c - src/core/lib/security/credentials/plugin/plugin_credentials.c - src/core/lib/security/credentials/ssl/ssl_credentials.c - src/core/lib/security/transport/client_auth_filter.c - src/core/lib/security/transport/lb_targets_info.c - src/core/lib/security/transport/secure_endpoint.c - src/core/lib/security/transport/security_connector.c - src/core/lib/security/transport/security_handshaker.c - src/core/lib/security/transport/server_auth_filter.c - src/core/lib/security/transport/tsi_error.c - src/core/lib/security/util/json_util.c - src/core/lib/surface/init_secure.c - src/core/tsi/fake_transport_security.c - src/core/tsi/gts_transport_security.c - 
src/core/tsi/ssl_transport_security.c - src/core/tsi/transport_security.c - src/core/tsi/transport_security_adapter.c - src/core/ext/transport/chttp2/client/chttp2_connector.c - src/core/ext/filters/load_reporting/load_reporting.c - src/core/ext/filters/load_reporting/load_reporting_filter.c - src/core/plugin_registry/grpc_cronet_plugin_registry.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_cronet PROPERTIES COMPILE_PDB_NAME "grpc_cronet" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_cronet.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_cronet - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_cronet - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ZLIB_LIBRARIES} - ${_gRPC_CARES_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr -) - -foreach(_hdr - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - 
include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc/grpc_cronet.h - include/grpc/grpc_security.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc_cronet EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(grpc_test_util - test/core/end2end/data/client_certs.c - test/core/end2end/data/server1_cert.c - test/core/end2end/data/server1_key.c - test/core/end2end/data/test_root_cert.c - test/core/security/oauth2_utils.c - src/core/ext/filters/client_channel/resolver/fake/fake_resolver.c - test/core/end2end/cq_verifier.c - test/core/end2end/fixtures/http_proxy_fixture.c - test/core/end2end/fixtures/proxy.c - test/core/iomgr/endpoint_tests.c - test/core/util/debugger_macros.c - test/core/util/grpc_profiler.c - test/core/util/memory_counters.c - test/core/util/mock_endpoint.c - test/core/util/parse_hexstring.c - test/core/util/passthru_endpoint.c - test/core/util/port.c - test/core/util/port_server_client.c - test/core/util/slice_splitter.c - test/core/util/trickle_endpoint.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - 
src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - 
src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - 
src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_test_util PROPERTIES COMPILE_PDB_NAME "grpc_test_util" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_test_util.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_test_util - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_test_util - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr - grpc -) - -foreach(_hdr - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h 
- include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(grpc_test_util_unsecure - src/core/ext/filters/client_channel/resolver/fake/fake_resolver.c - test/core/end2end/cq_verifier.c - test/core/end2end/fixtures/http_proxy_fixture.c - test/core/end2end/fixtures/proxy.c - test/core/iomgr/endpoint_tests.c - test/core/util/debugger_macros.c - test/core/util/grpc_profiler.c - test/core/util/memory_counters.c - test/core/util/mock_endpoint.c - test/core/util/parse_hexstring.c - test/core/util/passthru_endpoint.c - test/core/util/port.c - test/core/util/port_server_client.c - test/core/util/slice_splitter.c - test/core/util/trickle_endpoint.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_test_util_unsecure PROPERTIES COMPILE_PDB_NAME "grpc_test_util_unsecure" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES 
${CMAKE_CURRENT_BINARY_DIR}/grpc_test_util_unsecure.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_test_util_unsecure - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_test_util_unsecure - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr - gpr_test_util - grpc_unsecure - grpc -) - - -endif (gRPC_BUILD_TESTS) - -add_library(grpc_unsecure - src/core/lib/surface/init.c - src/core/lib/surface/init_unsecure.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - 
src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - 
src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2.c - 
src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c - src/core/ext/transport/chttp2/transport/bin_decoder.c - src/core/ext/transport/chttp2/transport/bin_encoder.c - src/core/ext/transport/chttp2/transport/chttp2_plugin.c - src/core/ext/transport/chttp2/transport/chttp2_transport.c - src/core/ext/transport/chttp2/transport/frame_data.c - src/core/ext/transport/chttp2/transport/frame_goaway.c - src/core/ext/transport/chttp2/transport/frame_ping.c - src/core/ext/transport/chttp2/transport/frame_rst_stream.c - src/core/ext/transport/chttp2/transport/frame_settings.c - src/core/ext/transport/chttp2/transport/frame_window_update.c - src/core/ext/transport/chttp2/transport/hpack_encoder.c - src/core/ext/transport/chttp2/transport/hpack_parser.c - src/core/ext/transport/chttp2/transport/hpack_table.c - src/core/ext/transport/chttp2/transport/http2_settings.c - src/core/ext/transport/chttp2/transport/huffsyms.c - src/core/ext/transport/chttp2/transport/incoming_metadata.c - src/core/ext/transport/chttp2/transport/parsing.c - src/core/ext/transport/chttp2/transport/stream_lists.c - src/core/ext/transport/chttp2/transport/stream_map.c - src/core/ext/transport/chttp2/transport/varint.c - src/core/ext/transport/chttp2/transport/writing.c - src/core/ext/transport/chttp2/alpn/alpn.c - src/core/ext/filters/http/client/http_client_filter.c - src/core/ext/filters/http/http_filters_plugin.c - src/core/ext/filters/http/message_compress/message_compress_filter.c - src/core/ext/filters/http/server/http_server_filter.c - src/core/ext/transport/chttp2/server/chttp2_server.c - src/core/ext/transport/chttp2/client/insecure/channel_create.c - src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c - src/core/ext/transport/chttp2/client/chttp2_connector.c - src/core/ext/filters/client_channel/channel_connectivity.c - src/core/ext/filters/client_channel/client_channel.c - src/core/ext/filters/client_channel/client_channel_factory.c - 
src/core/ext/filters/client_channel/client_channel_plugin.c - src/core/ext/filters/client_channel/connector.c - src/core/ext/filters/client_channel/http_connect_handshaker.c - src/core/ext/filters/client_channel/http_proxy.c - src/core/ext/filters/client_channel/lb_policy.c - src/core/ext/filters/client_channel/lb_policy_factory.c - src/core/ext/filters/client_channel/lb_policy_registry.c - src/core/ext/filters/client_channel/parse_address.c - src/core/ext/filters/client_channel/proxy_mapper.c - src/core/ext/filters/client_channel/proxy_mapper_registry.c - src/core/ext/filters/client_channel/resolver.c - src/core/ext/filters/client_channel/resolver_factory.c - src/core/ext/filters/client_channel/resolver_registry.c - src/core/ext/filters/client_channel/retry_throttle.c - src/core/ext/filters/client_channel/subchannel.c - src/core/ext/filters/client_channel/subchannel_index.c - src/core/ext/filters/client_channel/uri_parser.c - src/core/ext/filters/deadline/deadline_filter.c - src/core/ext/transport/inproc/inproc_plugin.c - src/core/ext/transport/inproc/inproc_transport.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_posix.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_fallback.c - src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.c - src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.c - src/core/ext/filters/client_channel/resolver/fake/fake_resolver.c - src/core/ext/filters/load_reporting/load_reporting.c - src/core/ext/filters/load_reporting/load_reporting_filter.c - src/core/ext/filters/client_channel/lb_policy/grpclb/client_load_reporting_filter.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel.c - 
src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.c - src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.c - src/core/ext/filters/client_channel/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.c - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.c - src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.c - src/core/ext/census/base_resources.c - src/core/ext/census/context.c - src/core/ext/census/gen/census.pb.c - src/core/ext/census/gen/trace_context.pb.c - src/core/ext/census/grpc_context.c - src/core/ext/census/grpc_filter.c - src/core/ext/census/grpc_plugin.c - src/core/ext/census/initialize.c - src/core/ext/census/intrusive_hash_map.c - src/core/ext/census/mlog.c - src/core/ext/census/operation.c - src/core/ext/census/placeholders.c - src/core/ext/census/resource.c - src/core/ext/census/trace_context.c - src/core/ext/census/tracing.c - src/core/ext/filters/max_age/max_age_filter.c - src/core/ext/filters/message_size/message_size_filter.c - src/core/ext/filters/workarounds/workaround_cronet_compression_filter.c - src/core/ext/filters/workarounds/workaround_utils.c - src/core/plugin_registry/grpc_unsecure_plugin_registry.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_unsecure PROPERTIES COMPILE_PDB_NAME "grpc_unsecure" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_unsecure.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_unsecure - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - 
PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_unsecure - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_ZLIB_LIBRARIES} - ${_gRPC_CARES_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr -) - -foreach(_hdr - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc/census.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc_unsecure EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(reconnect_server - 
test/core/util/reconnect_server.c -) - -if(WIN32 AND MSVC) - set_target_properties(reconnect_server PROPERTIES COMPILE_PDB_NAME "reconnect_server" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/reconnect_server.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(reconnect_server - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(reconnect_server - ${_gRPC_ALLTARGETS_LIBRARIES} - test_tcp_server - grpc_test_util - grpc - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(test_tcp_server - test/core/util/test_tcp_server.c -) - -if(WIN32 AND MSVC) - set_target_properties(test_tcp_server PROPERTIES COMPILE_PDB_NAME "test_tcp_server" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/test_tcp_server.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(test_tcp_server - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - 
-target_link_libraries(test_tcp_server - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) - -add_library(grpc++ - src/cpp/client/insecure_credentials.cc - src/cpp/client/secure_credentials.cc - src/cpp/common/auth_property_iterator.cc - src/cpp/common/secure_auth_context.cc - src/cpp/common/secure_channel_arguments.cc - src/cpp/common/secure_create_auth_context.cc - src/cpp/server/insecure_server_credentials.cc - src/cpp/server/secure_server_credentials.cc - src/cpp/client/channel_cc.cc - src/cpp/client/client_context.cc - src/cpp/client/create_channel.cc - src/cpp/client/create_channel_internal.cc - src/cpp/client/create_channel_posix.cc - src/cpp/client/credentials_cc.cc - src/cpp/client/generic_stub.cc - src/cpp/common/channel_arguments.cc - src/cpp/common/channel_filter.cc - src/cpp/common/completion_queue_cc.cc - src/cpp/common/core_codegen.cc - src/cpp/common/resource_quota_cc.cc - src/cpp/common/rpc_method.cc - src/cpp/common/version_cc.cc - src/cpp/server/async_generic_service.cc - src/cpp/server/channel_argument_option.cc - src/cpp/server/create_default_thread_pool.cc - src/cpp/server/dynamic_thread_pool.cc - src/cpp/server/health/default_health_check_service.cc - src/cpp/server/health/health.pb.c - src/cpp/server/health/health_check_service.cc - src/cpp/server/health/health_check_service_server_builder_option.cc - src/cpp/server/server_builder.cc - src/cpp/server/server_cc.cc - src/cpp/server/server_context.cc - src/cpp/server/server_credentials.cc - src/cpp/server/server_posix.cc - src/cpp/thread_manager/thread_manager.cc - src/cpp/util/byte_buffer_cc.cc - src/cpp/util/slice_cc.cc - src/cpp/util/status.cc - src/cpp/util/string_ref.cc - src/cpp/util/time_cc.cc - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/cpp/codegen/codegen_init.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++ PROPERTIES COMPILE_PDB_NAME "grpc++" - 
COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc++ - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++ - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc -) - -foreach(_hdr - include/grpc++/alarm.h - include/grpc++/channel.h - include/grpc++/client_context.h - include/grpc++/completion_queue.h - include/grpc++/create_channel.h - include/grpc++/create_channel_posix.h - include/grpc++/ext/health_check_service_server_builder_option.h - include/grpc++/generic/async_generic_service.h - include/grpc++/generic/generic_stub.h - include/grpc++/grpc++.h - include/grpc++/health_check_service_interface.h - include/grpc++/impl/call.h - include/grpc++/impl/channel_argument_option.h - include/grpc++/impl/client_unary_call.h - include/grpc++/impl/codegen/core_codegen.h - include/grpc++/impl/grpc_library.h - include/grpc++/impl/method_handler_impl.h - include/grpc++/impl/rpc_method.h - include/grpc++/impl/rpc_service_method.h - include/grpc++/impl/serialization_traits.h - include/grpc++/impl/server_builder_option.h - include/grpc++/impl/server_builder_plugin.h - include/grpc++/impl/server_initializer.h - include/grpc++/impl/service_type.h - include/grpc++/resource_quota.h - include/grpc++/security/auth_context.h - 
include/grpc++/security/auth_metadata_processor.h - include/grpc++/security/credentials.h - include/grpc++/security/server_credentials.h - include/grpc++/server.h - include/grpc++/server_builder.h - include/grpc++/server_context.h - include/grpc++/server_posix.h - include/grpc++/support/async_stream.h - include/grpc++/support/async_unary_call.h - include/grpc++/support/byte_buffer.h - include/grpc++/support/channel_arguments.h - include/grpc++/support/config.h - include/grpc++/support/slice.h - include/grpc++/support/status.h - include/grpc++/support/status_code_enum.h - include/grpc++/support/string_ref.h - include/grpc++/support/stub_options.h - include/grpc++/support/sync_stream.h - include/grpc++/support/time.h - include/grpc++/impl/codegen/async_stream.h - include/grpc++/impl/codegen/async_unary_call.h - include/grpc++/impl/codegen/call.h - include/grpc++/impl/codegen/call_hook.h - include/grpc++/impl/codegen/channel_interface.h - include/grpc++/impl/codegen/client_context.h - include/grpc++/impl/codegen/client_unary_call.h - include/grpc++/impl/codegen/completion_queue.h - include/grpc++/impl/codegen/completion_queue_tag.h - include/grpc++/impl/codegen/config.h - include/grpc++/impl/codegen/core_codegen_interface.h - include/grpc++/impl/codegen/create_auth_context.h - include/grpc++/impl/codegen/grpc_library.h - include/grpc++/impl/codegen/metadata_map.h - include/grpc++/impl/codegen/method_handler_impl.h - include/grpc++/impl/codegen/rpc_method.h - include/grpc++/impl/codegen/rpc_service_method.h - include/grpc++/impl/codegen/security/auth_context.h - include/grpc++/impl/codegen/serialization_traits.h - include/grpc++/impl/codegen/server_context.h - include/grpc++/impl/codegen/server_interface.h - include/grpc++/impl/codegen/service_type.h - include/grpc++/impl/codegen/slice.h - include/grpc++/impl/codegen/status.h - include/grpc++/impl/codegen/status_code_enum.h - include/grpc++/impl/codegen/string_ref.h - include/grpc++/impl/codegen/stub_options.h - 
include/grpc++/impl/codegen/sync_stream.h - include/grpc++/impl/codegen/time.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc++/impl/codegen/proto_utils.h - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++ EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_library(grpc++_cronet - src/cpp/client/cronet_credentials.cc - src/cpp/client/insecure_credentials.cc - src/cpp/common/insecure_create_auth_context.cc - src/cpp/server/insecure_server_credentials.cc - src/cpp/client/channel_cc.cc - src/cpp/client/client_context.cc - src/cpp/client/create_channel.cc - src/cpp/client/create_channel_internal.cc - src/cpp/client/create_channel_posix.cc - src/cpp/client/credentials_cc.cc - src/cpp/client/generic_stub.cc - src/cpp/common/channel_arguments.cc - src/cpp/common/channel_filter.cc - src/cpp/common/completion_queue_cc.cc - src/cpp/common/core_codegen.cc - 
src/cpp/common/resource_quota_cc.cc - src/cpp/common/rpc_method.cc - src/cpp/common/version_cc.cc - src/cpp/server/async_generic_service.cc - src/cpp/server/channel_argument_option.cc - src/cpp/server/create_default_thread_pool.cc - src/cpp/server/dynamic_thread_pool.cc - src/cpp/server/health/default_health_check_service.cc - src/cpp/server/health/health.pb.c - src/cpp/server/health/health_check_service.cc - src/cpp/server/health/health_check_service_server_builder_option.cc - src/cpp/server/server_builder.cc - src/cpp/server/server_cc.cc - src/cpp/server/server_context.cc - src/cpp/server/server_credentials.cc - src/cpp/server/server_posix.cc - src/cpp/thread_manager/thread_manager.cc - src/cpp/util/byte_buffer_cc.cc - src/cpp/util/slice_cc.cc - src/cpp/util/status.cc - src/cpp/util/string_ref.cc - src/cpp/util/time_cc.cc - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/cpp/codegen/codegen_init.cc - src/core/ext/transport/chttp2/client/insecure/channel_create.c - src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c - src/core/ext/transport/chttp2/client/chttp2_connector.c - src/core/ext/transport/chttp2/transport/bin_decoder.c - src/core/ext/transport/chttp2/transport/bin_encoder.c - src/core/ext/transport/chttp2/transport/chttp2_plugin.c - src/core/ext/transport/chttp2/transport/chttp2_transport.c - src/core/ext/transport/chttp2/transport/frame_data.c - src/core/ext/transport/chttp2/transport/frame_goaway.c - src/core/ext/transport/chttp2/transport/frame_ping.c - src/core/ext/transport/chttp2/transport/frame_rst_stream.c - src/core/ext/transport/chttp2/transport/frame_settings.c - src/core/ext/transport/chttp2/transport/frame_window_update.c - src/core/ext/transport/chttp2/transport/hpack_encoder.c - src/core/ext/transport/chttp2/transport/hpack_parser.c - src/core/ext/transport/chttp2/transport/hpack_table.c - src/core/ext/transport/chttp2/transport/http2_settings.c - 
src/core/ext/transport/chttp2/transport/huffsyms.c - src/core/ext/transport/chttp2/transport/incoming_metadata.c - src/core/ext/transport/chttp2/transport/parsing.c - src/core/ext/transport/chttp2/transport/stream_lists.c - src/core/ext/transport/chttp2/transport/stream_map.c - src/core/ext/transport/chttp2/transport/varint.c - src/core/ext/transport/chttp2/transport/writing.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - 
src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c 
- src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c - src/core/ext/transport/chttp2/alpn/alpn.c - src/core/ext/filters/http/client/http_client_filter.c - src/core/ext/filters/http/http_filters_plugin.c - src/core/ext/filters/http/message_compress/message_compress_filter.c - src/core/ext/filters/http/server/http_server_filter.c - src/core/ext/filters/client_channel/channel_connectivity.c - src/core/ext/filters/client_channel/client_channel.c - src/core/ext/filters/client_channel/client_channel_factory.c - src/core/ext/filters/client_channel/client_channel_plugin.c - src/core/ext/filters/client_channel/connector.c - src/core/ext/filters/client_channel/http_connect_handshaker.c - 
src/core/ext/filters/client_channel/http_proxy.c - src/core/ext/filters/client_channel/lb_policy.c - src/core/ext/filters/client_channel/lb_policy_factory.c - src/core/ext/filters/client_channel/lb_policy_registry.c - src/core/ext/filters/client_channel/parse_address.c - src/core/ext/filters/client_channel/proxy_mapper.c - src/core/ext/filters/client_channel/proxy_mapper_registry.c - src/core/ext/filters/client_channel/resolver.c - src/core/ext/filters/client_channel/resolver_factory.c - src/core/ext/filters/client_channel/resolver_registry.c - src/core/ext/filters/client_channel/retry_throttle.c - src/core/ext/filters/client_channel/subchannel.c - src/core/ext/filters/client_channel/subchannel_index.c - src/core/ext/filters/client_channel/uri_parser.c - src/core/ext/filters/deadline/deadline_filter.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c - src/core/ext/transport/chttp2/server/chttp2_server.c - src/core/ext/census/base_resources.c - src/core/ext/census/context.c - src/core/ext/census/gen/census.pb.c - src/core/ext/census/gen/trace_context.pb.c - src/core/ext/census/grpc_context.c - src/core/ext/census/grpc_filter.c - src/core/ext/census/grpc_plugin.c - src/core/ext/census/initialize.c - src/core/ext/census/intrusive_hash_map.c - src/core/ext/census/mlog.c - src/core/ext/census/operation.c - src/core/ext/census/placeholders.c - src/core/ext/census/resource.c - src/core/ext/census/trace_context.c - src/core/ext/census/tracing.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_cronet PROPERTIES COMPILE_PDB_NAME "grpc++_cronet" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_cronet.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc++_cronet - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - 
PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_cronet - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr - grpc_cronet - grpc -) - -foreach(_hdr - include/grpc++/alarm.h - include/grpc++/channel.h - include/grpc++/client_context.h - include/grpc++/completion_queue.h - include/grpc++/create_channel.h - include/grpc++/create_channel_posix.h - include/grpc++/ext/health_check_service_server_builder_option.h - include/grpc++/generic/async_generic_service.h - include/grpc++/generic/generic_stub.h - include/grpc++/grpc++.h - include/grpc++/health_check_service_interface.h - include/grpc++/impl/call.h - include/grpc++/impl/channel_argument_option.h - include/grpc++/impl/client_unary_call.h - include/grpc++/impl/codegen/core_codegen.h - include/grpc++/impl/grpc_library.h - include/grpc++/impl/method_handler_impl.h - include/grpc++/impl/rpc_method.h - include/grpc++/impl/rpc_service_method.h - include/grpc++/impl/serialization_traits.h - include/grpc++/impl/server_builder_option.h - include/grpc++/impl/server_builder_plugin.h - include/grpc++/impl/server_initializer.h - include/grpc++/impl/service_type.h - include/grpc++/resource_quota.h - include/grpc++/security/auth_context.h - include/grpc++/security/auth_metadata_processor.h - include/grpc++/security/credentials.h - include/grpc++/security/server_credentials.h - include/grpc++/server.h - include/grpc++/server_builder.h - include/grpc++/server_context.h - include/grpc++/server_posix.h - include/grpc++/support/async_stream.h - 
include/grpc++/support/async_unary_call.h - include/grpc++/support/byte_buffer.h - include/grpc++/support/channel_arguments.h - include/grpc++/support/config.h - include/grpc++/support/slice.h - include/grpc++/support/status.h - include/grpc++/support/status_code_enum.h - include/grpc++/support/string_ref.h - include/grpc++/support/stub_options.h - include/grpc++/support/sync_stream.h - include/grpc++/support/time.h - include/grpc++/impl/codegen/async_stream.h - include/grpc++/impl/codegen/async_unary_call.h - include/grpc++/impl/codegen/call.h - include/grpc++/impl/codegen/call_hook.h - include/grpc++/impl/codegen/channel_interface.h - include/grpc++/impl/codegen/client_context.h - include/grpc++/impl/codegen/client_unary_call.h - include/grpc++/impl/codegen/completion_queue.h - include/grpc++/impl/codegen/completion_queue_tag.h - include/grpc++/impl/codegen/config.h - include/grpc++/impl/codegen/core_codegen_interface.h - include/grpc++/impl/codegen/create_auth_context.h - include/grpc++/impl/codegen/grpc_library.h - include/grpc++/impl/codegen/metadata_map.h - include/grpc++/impl/codegen/method_handler_impl.h - include/grpc++/impl/codegen/rpc_method.h - include/grpc++/impl/codegen/rpc_service_method.h - include/grpc++/impl/codegen/security/auth_context.h - include/grpc++/impl/codegen/serialization_traits.h - include/grpc++/impl/codegen/server_context.h - include/grpc++/impl/codegen/server_interface.h - include/grpc++/impl/codegen/service_type.h - include/grpc++/impl/codegen/slice.h - include/grpc++/impl/codegen/status.h - include/grpc++/impl/codegen/status_code_enum.h - include/grpc++/impl/codegen/string_ref.h - include/grpc++/impl/codegen/stub_options.h - include/grpc++/impl/codegen/sync_stream.h - include/grpc++/impl/codegen/time.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - 
include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/census.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++_cronet EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_library(grpc++_error_details - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/status/status.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/status/status.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/status/status.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/status/status.grpc.pb.h - src/cpp/util/error_details.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_error_details PROPERTIES COMPILE_PDB_NAME "grpc++_error_details" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_error_details.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - 
-protobuf_generate_grpc_cpp( - src/proto/grpc/status/status.proto -) - -target_include_directories(grpc++_error_details - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_error_details - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ -) - -foreach(_hdr - include/grpc++/support/error_details.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++_error_details EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(grpc++_proto_reflection_desc_db - test/cpp/util/proto_reflection_descriptor_database.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_proto_reflection_desc_db PROPERTIES COMPILE_PDB_NAME "grpc++_proto_reflection_desc_db" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES 
${CMAKE_CURRENT_BINARY_DIR}/grpc++_proto_reflection_desc_db.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/reflection/v1alpha/reflection.proto -) - -target_include_directories(grpc++_proto_reflection_desc_db - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_proto_reflection_desc_db - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc -) - -foreach(_hdr - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - -endif (gRPC_BUILD_TESTS) - -add_library(grpc++_reflection - src/cpp/ext/proto_server_reflection.cc - src/cpp/ext/proto_server_reflection_plugin.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_reflection PROPERTIES COMPILE_PDB_NAME "grpc++_reflection" - COMPILE_PDB_OUTPUT_DIRECTORY 
"${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_reflection.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/reflection/v1alpha/reflection.proto -) - -target_include_directories(grpc++_reflection - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_reflection - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc -) - -foreach(_hdr - include/grpc++/ext/proto_server_reflection_plugin.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++_reflection EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(grpc++_test_config - test/cpp/util/test_config_cc.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_test_config PROPERTIES COMPILE_PDB_NAME "grpc++_test_config" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_test_config.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc++_test_config - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE 
${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_test_config - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(grpc++_test_util - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/health/v1/health.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/health/v1/health.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/health/v1/health.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/health/v1/health.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_mock.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/duplicate/echo_duplicate.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/duplicate/echo_duplicate.pb.h - 
${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h - test/cpp/end2end/test_service_impl.cc - test/cpp/util/byte_buffer_proto_helper.cc - test/cpp/util/create_test_channel.cc - test/cpp/util/string_ref_helper.cc - test/cpp/util/subprocess.cc - test/cpp/util/test_credentials_provider.cc - src/cpp/codegen/codegen_init.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_test_util PROPERTIES COMPILE_PDB_NAME "grpc++_test_util" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_test_util.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/health/v1/health.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/duplicate/echo_duplicate.proto -) - -target_include_directories(grpc++_test_util - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_test_util - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc_test_util - grpc -) - -foreach(_hdr - include/grpc++/impl/codegen/async_stream.h - include/grpc++/impl/codegen/async_unary_call.h - 
include/grpc++/impl/codegen/call.h - include/grpc++/impl/codegen/call_hook.h - include/grpc++/impl/codegen/channel_interface.h - include/grpc++/impl/codegen/client_context.h - include/grpc++/impl/codegen/client_unary_call.h - include/grpc++/impl/codegen/completion_queue.h - include/grpc++/impl/codegen/completion_queue_tag.h - include/grpc++/impl/codegen/config.h - include/grpc++/impl/codegen/core_codegen_interface.h - include/grpc++/impl/codegen/create_auth_context.h - include/grpc++/impl/codegen/grpc_library.h - include/grpc++/impl/codegen/metadata_map.h - include/grpc++/impl/codegen/method_handler_impl.h - include/grpc++/impl/codegen/rpc_method.h - include/grpc++/impl/codegen/rpc_service_method.h - include/grpc++/impl/codegen/security/auth_context.h - include/grpc++/impl/codegen/serialization_traits.h - include/grpc++/impl/codegen/server_context.h - include/grpc++/impl/codegen/server_interface.h - include/grpc++/impl/codegen/service_type.h - include/grpc++/impl/codegen/slice.h - include/grpc++/impl/codegen/status.h - include/grpc++/impl/codegen/status_code_enum.h - include/grpc++/impl/codegen/string_ref.h - include/grpc++/impl/codegen/stub_options.h - include/grpc++/impl/codegen/sync_stream.h - include/grpc++/impl/codegen/time.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - 
include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc++/impl/codegen/proto_utils.h - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - -endif (gRPC_BUILD_TESTS) - -add_library(grpc++_unsecure - src/cpp/client/insecure_credentials.cc - src/cpp/common/insecure_create_auth_context.cc - src/cpp/server/insecure_server_credentials.cc - src/cpp/client/channel_cc.cc - src/cpp/client/client_context.cc - src/cpp/client/create_channel.cc - src/cpp/client/create_channel_internal.cc - src/cpp/client/create_channel_posix.cc - src/cpp/client/credentials_cc.cc - src/cpp/client/generic_stub.cc - src/cpp/common/channel_arguments.cc - src/cpp/common/channel_filter.cc - src/cpp/common/completion_queue_cc.cc - src/cpp/common/core_codegen.cc - src/cpp/common/resource_quota_cc.cc - src/cpp/common/rpc_method.cc - src/cpp/common/version_cc.cc - src/cpp/server/async_generic_service.cc - src/cpp/server/channel_argument_option.cc - src/cpp/server/create_default_thread_pool.cc - src/cpp/server/dynamic_thread_pool.cc - src/cpp/server/health/default_health_check_service.cc - src/cpp/server/health/health.pb.c - src/cpp/server/health/health_check_service.cc - src/cpp/server/health/health_check_service_server_builder_option.cc - src/cpp/server/server_builder.cc - src/cpp/server/server_cc.cc - src/cpp/server/server_context.cc - src/cpp/server/server_credentials.cc - src/cpp/server/server_posix.cc - src/cpp/thread_manager/thread_manager.cc - src/cpp/util/byte_buffer_cc.cc - src/cpp/util/slice_cc.cc - src/cpp/util/status.cc - src/cpp/util/string_ref.cc - src/cpp/util/time_cc.cc - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/cpp/codegen/codegen_init.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_unsecure 
PROPERTIES COMPILE_PDB_NAME "grpc++_unsecure" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_unsecure.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc++_unsecure - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_unsecure - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr - grpc_unsecure -) - -foreach(_hdr - include/grpc++/alarm.h - include/grpc++/channel.h - include/grpc++/client_context.h - include/grpc++/completion_queue.h - include/grpc++/create_channel.h - include/grpc++/create_channel_posix.h - include/grpc++/ext/health_check_service_server_builder_option.h - include/grpc++/generic/async_generic_service.h - include/grpc++/generic/generic_stub.h - include/grpc++/grpc++.h - include/grpc++/health_check_service_interface.h - include/grpc++/impl/call.h - include/grpc++/impl/channel_argument_option.h - include/grpc++/impl/client_unary_call.h - include/grpc++/impl/codegen/core_codegen.h - include/grpc++/impl/grpc_library.h - include/grpc++/impl/method_handler_impl.h - include/grpc++/impl/rpc_method.h - include/grpc++/impl/rpc_service_method.h - include/grpc++/impl/serialization_traits.h - include/grpc++/impl/server_builder_option.h - include/grpc++/impl/server_builder_plugin.h - include/grpc++/impl/server_initializer.h - include/grpc++/impl/service_type.h - include/grpc++/resource_quota.h - 
include/grpc++/security/auth_context.h - include/grpc++/security/auth_metadata_processor.h - include/grpc++/security/credentials.h - include/grpc++/security/server_credentials.h - include/grpc++/server.h - include/grpc++/server_builder.h - include/grpc++/server_context.h - include/grpc++/server_posix.h - include/grpc++/support/async_stream.h - include/grpc++/support/async_unary_call.h - include/grpc++/support/byte_buffer.h - include/grpc++/support/channel_arguments.h - include/grpc++/support/config.h - include/grpc++/support/slice.h - include/grpc++/support/status.h - include/grpc++/support/status_code_enum.h - include/grpc++/support/string_ref.h - include/grpc++/support/stub_options.h - include/grpc++/support/sync_stream.h - include/grpc++/support/time.h - include/grpc++/impl/codegen/async_stream.h - include/grpc++/impl/codegen/async_unary_call.h - include/grpc++/impl/codegen/call.h - include/grpc++/impl/codegen/call_hook.h - include/grpc++/impl/codegen/channel_interface.h - include/grpc++/impl/codegen/client_context.h - include/grpc++/impl/codegen/client_unary_call.h - include/grpc++/impl/codegen/completion_queue.h - include/grpc++/impl/codegen/completion_queue_tag.h - include/grpc++/impl/codegen/config.h - include/grpc++/impl/codegen/core_codegen_interface.h - include/grpc++/impl/codegen/create_auth_context.h - include/grpc++/impl/codegen/grpc_library.h - include/grpc++/impl/codegen/metadata_map.h - include/grpc++/impl/codegen/method_handler_impl.h - include/grpc++/impl/codegen/rpc_method.h - include/grpc++/impl/codegen/rpc_service_method.h - include/grpc++/impl/codegen/security/auth_context.h - include/grpc++/impl/codegen/serialization_traits.h - include/grpc++/impl/codegen/server_context.h - include/grpc++/impl/codegen/server_interface.h - include/grpc++/impl/codegen/service_type.h - include/grpc++/impl/codegen/slice.h - include/grpc++/impl/codegen/status.h - include/grpc++/impl/codegen/status_code_enum.h - include/grpc++/impl/codegen/string_ref.h - 
include/grpc++/impl/codegen/stub_options.h - include/grpc++/impl/codegen/sync_stream.h - include/grpc++/impl/codegen/time.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++_unsecure EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(grpc_benchmark - test/cpp/microbenchmarks/helpers.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_benchmark PROPERTIES COMPILE_PDB_NAME "grpc_benchmark" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_benchmark.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_benchmark - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_benchmark - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - benchmark - grpc++ - grpc_test_util - grpc - ${_gRPC_GFLAGS_LIBRARIES} -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(grpc_cli_libs - test/cpp/util/cli_call.cc - test/cpp/util/cli_credentials.cc - test/cpp/util/grpc_tool.cc - test/cpp/util/proto_file_parser.cc - test/cpp/util/service_describer.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_cli_libs PROPERTIES COMPILE_PDB_NAME "grpc_cli_libs" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_cli_libs.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/reflection/v1alpha/reflection.proto -) - -target_include_directories(grpc_cli_libs - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE 
${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_cli_libs - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_proto_reflection_desc_db - grpc++ - grpc -) - -foreach(_hdr - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - -endif (gRPC_BUILD_TESTS) - -add_library(grpc_plugin_support - src/compiler/cpp_generator.cc - src/compiler/csharp_generator.cc - src/compiler/node_generator.cc - src/compiler/objective_c_generator.cc - src/compiler/php_generator.cc - src/compiler/python_generator.cc - src/compiler/ruby_generator.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_plugin_support PROPERTIES COMPILE_PDB_NAME "grpc_plugin_support" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_plugin_support.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_plugin_support - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE 
${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_plugin_support - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} -) - -foreach(_hdr - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc_plugin_support EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(http2_client_main - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/http2_client.cc -) - -if(WIN32 AND MSVC) - set_target_properties(http2_client_main PROPERTIES COMPILE_PDB_NAME "http2_client_main" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/http2_client_main.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) 
-protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(http2_client_main - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(http2_client_main - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - grpc++_test_config -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_client_helper - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - test/cpp/interop/client_helper.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_client_helper PROPERTIES COMPILE_PDB_NAME "interop_client_helper" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_client_helper.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) - -target_include_directories(interop_client_helper - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE 
${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_client_helper - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_client_main - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/client.cc - test/cpp/interop/interop_client.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_client_main PROPERTIES COMPILE_PDB_NAME "interop_client_main" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_client_main.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - 
-protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(interop_client_main - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_client_main - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_client_helper - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_server_helper - test/cpp/interop/server_helper.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_server_helper PROPERTIES COMPILE_PDB_NAME "interop_server_helper" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_server_helper.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(interop_server_helper - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - 
PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_server_helper - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_server_lib - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/interop_server.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_server_lib PROPERTIES COMPILE_PDB_NAME "interop_server_lib" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_server_lib.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - 
src/proto/grpc/testing/test.proto -) - -target_include_directories(interop_server_lib - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_server_lib - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_server_helper - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_server_main - test/cpp/interop/interop_server_bootstrap.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_server_main PROPERTIES COMPILE_PDB_NAME "interop_server_main" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_server_main.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(interop_server_main - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_server_main - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_server_lib -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(qps - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.h - test/cpp/qps/benchmark_config.cc - test/cpp/qps/client_async.cc - test/cpp/qps/client_sync.cc - test/cpp/qps/driver.cc - test/cpp/qps/parse_json.cc - test/cpp/qps/qps_worker.cc - test/cpp/qps/report.cc - 
test/cpp/qps/server_async.cc - test/cpp/qps/server_sync.cc - test/cpp/qps/usage_timer.cc -) - -if(WIN32 AND MSVC) - set_target_properties(qps PROPERTIES COMPILE_PDB_NAME "qps" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/qps.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/payloads.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/stats.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/control.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/services.proto -) - -target_include_directories(qps - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++_test_util - grpc++ - grpc -) - - -endif (gRPC_BUILD_TESTS) - -add_library(grpc_csharp_ext SHARED - src/csharp/ext/grpc_csharp_ext.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_csharp_ext PROPERTIES COMPILE_PDB_NAME "grpc_csharp_ext" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_csharp_ext.pdb - DESTINATION 
${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_csharp_ext - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_csharp_ext - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - - -if (gRPC_INSTALL) - install(TARGETS grpc_csharp_ext EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(ares - third_party/cares/cares/ares__close_sockets.c - third_party/cares/cares/ares__get_hostent.c - third_party/cares/cares/ares__read_line.c - third_party/cares/cares/ares__timeval.c - third_party/cares/cares/ares_cancel.c - third_party/cares/cares/ares_create_query.c - third_party/cares/cares/ares_data.c - third_party/cares/cares/ares_destroy.c - third_party/cares/cares/ares_expand_name.c - third_party/cares/cares/ares_expand_string.c - third_party/cares/cares/ares_fds.c - third_party/cares/cares/ares_free_hostent.c - third_party/cares/cares/ares_free_string.c - third_party/cares/cares/ares_getenv.c - third_party/cares/cares/ares_gethostbyaddr.c - third_party/cares/cares/ares_gethostbyname.c - third_party/cares/cares/ares_getnameinfo.c - third_party/cares/cares/ares_getopt.c - third_party/cares/cares/ares_getsock.c - third_party/cares/cares/ares_init.c - third_party/cares/cares/ares_library_init.c - third_party/cares/cares/ares_llist.c - third_party/cares/cares/ares_mkquery.c - third_party/cares/cares/ares_nowarn.c - 
third_party/cares/cares/ares_options.c - third_party/cares/cares/ares_parse_a_reply.c - third_party/cares/cares/ares_parse_aaaa_reply.c - third_party/cares/cares/ares_parse_mx_reply.c - third_party/cares/cares/ares_parse_naptr_reply.c - third_party/cares/cares/ares_parse_ns_reply.c - third_party/cares/cares/ares_parse_ptr_reply.c - third_party/cares/cares/ares_parse_soa_reply.c - third_party/cares/cares/ares_parse_srv_reply.c - third_party/cares/cares/ares_parse_txt_reply.c - third_party/cares/cares/ares_platform.c - third_party/cares/cares/ares_process.c - third_party/cares/cares/ares_query.c - third_party/cares/cares/ares_search.c - third_party/cares/cares/ares_send.c - third_party/cares/cares/ares_strcasecmp.c - third_party/cares/cares/ares_strdup.c - third_party/cares/cares/ares_strerror.c - third_party/cares/cares/ares_timeout.c - third_party/cares/cares/ares_version.c - third_party/cares/cares/ares_writev.c - third_party/cares/cares/bitncmp.c - third_party/cares/cares/inet_net_pton.c - third_party/cares/cares/inet_ntop.c - third_party/cares/cares/windows_port.c -) - -if(WIN32 AND MSVC) - set_target_properties(ares PROPERTIES COMPILE_PDB_NAME "ares" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ares.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(ares - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(ares - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} -) - - -endif 
(gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(bad_client_test - test/core/bad_client/bad_client.c -) - -if(WIN32 AND MSVC) - set_target_properties(bad_client_test PROPERTIES COMPILE_PDB_NAME "bad_client_test" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/bad_client_test.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(bad_client_test - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_client_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(bad_ssl_test_server - test/core/bad_ssl/server_common.c -) - -if(WIN32 AND MSVC) - set_target_properties(bad_ssl_test_server PROPERTIES COMPILE_PDB_NAME "bad_ssl_test_server" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/bad_ssl_test_server.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(bad_ssl_test_server - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_ssl_test_server - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(end2end_tests - test/core/end2end/end2end_tests.c - test/core/end2end/end2end_test_utils.c - test/core/end2end/tests/authority_not_supported.c - test/core/end2end/tests/bad_hostname.c - test/core/end2end/tests/bad_ping.c - test/core/end2end/tests/binary_metadata.c - test/core/end2end/tests/call_creds.c - test/core/end2end/tests/cancel_after_accept.c - test/core/end2end/tests/cancel_after_client_done.c - test/core/end2end/tests/cancel_after_invoke.c - test/core/end2end/tests/cancel_after_round_trip.c - test/core/end2end/tests/cancel_before_invoke.c - test/core/end2end/tests/cancel_in_a_vacuum.c - test/core/end2end/tests/cancel_with_status.c - test/core/end2end/tests/compressed_payload.c - test/core/end2end/tests/connectivity.c - test/core/end2end/tests/default_host.c - test/core/end2end/tests/disappearing_server.c - test/core/end2end/tests/empty_batch.c - test/core/end2end/tests/filter_call_init_fails.c - test/core/end2end/tests/filter_causes_close.c - test/core/end2end/tests/filter_latency.c - test/core/end2end/tests/graceful_server_shutdown.c - test/core/end2end/tests/high_initial_seqno.c - test/core/end2end/tests/hpack_size.c - test/core/end2end/tests/idempotent_request.c - test/core/end2end/tests/invoke_large_request.c - test/core/end2end/tests/keepalive_timeout.c - test/core/end2end/tests/large_metadata.c - test/core/end2end/tests/load_reporting_hook.c - test/core/end2end/tests/max_concurrent_streams.c - test/core/end2end/tests/max_connection_age.c - test/core/end2end/tests/max_connection_idle.c - test/core/end2end/tests/max_message_length.c - test/core/end2end/tests/negative_deadline.c - test/core/end2end/tests/network_status_change.c - 
test/core/end2end/tests/no_logging.c - test/core/end2end/tests/no_op.c - test/core/end2end/tests/payload.c - test/core/end2end/tests/ping.c - test/core/end2end/tests/ping_pong_streaming.c - test/core/end2end/tests/proxy_auth.c - test/core/end2end/tests/registered_call.c - test/core/end2end/tests/request_with_flags.c - test/core/end2end/tests/request_with_payload.c - test/core/end2end/tests/resource_quota_server.c - test/core/end2end/tests/server_finishes_request.c - test/core/end2end/tests/shutdown_finishes_calls.c - test/core/end2end/tests/shutdown_finishes_tags.c - test/core/end2end/tests/simple_cacheable_request.c - test/core/end2end/tests/simple_delayed_request.c - test/core/end2end/tests/simple_metadata.c - test/core/end2end/tests/simple_request.c - test/core/end2end/tests/streaming_error_response.c - test/core/end2end/tests/trailing_metadata.c - test/core/end2end/tests/workaround_cronet_compression.c - test/core/end2end/tests/write_buffering.c - test/core/end2end/tests/write_buffering_at_end.c -) - -if(WIN32 AND MSVC) - set_target_properties(end2end_tests PROPERTIES COMPILE_PDB_NAME "end2end_tests" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/end2end_tests.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(end2end_tests - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(end2end_tests - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr 
-) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(end2end_nosec_tests - test/core/end2end/end2end_nosec_tests.c - test/core/end2end/end2end_test_utils.c - test/core/end2end/tests/authority_not_supported.c - test/core/end2end/tests/bad_hostname.c - test/core/end2end/tests/bad_ping.c - test/core/end2end/tests/binary_metadata.c - test/core/end2end/tests/cancel_after_accept.c - test/core/end2end/tests/cancel_after_client_done.c - test/core/end2end/tests/cancel_after_invoke.c - test/core/end2end/tests/cancel_after_round_trip.c - test/core/end2end/tests/cancel_before_invoke.c - test/core/end2end/tests/cancel_in_a_vacuum.c - test/core/end2end/tests/cancel_with_status.c - test/core/end2end/tests/compressed_payload.c - test/core/end2end/tests/connectivity.c - test/core/end2end/tests/default_host.c - test/core/end2end/tests/disappearing_server.c - test/core/end2end/tests/empty_batch.c - test/core/end2end/tests/filter_call_init_fails.c - test/core/end2end/tests/filter_causes_close.c - test/core/end2end/tests/filter_latency.c - test/core/end2end/tests/graceful_server_shutdown.c - test/core/end2end/tests/high_initial_seqno.c - test/core/end2end/tests/hpack_size.c - test/core/end2end/tests/idempotent_request.c - test/core/end2end/tests/invoke_large_request.c - test/core/end2end/tests/keepalive_timeout.c - test/core/end2end/tests/large_metadata.c - test/core/end2end/tests/load_reporting_hook.c - test/core/end2end/tests/max_concurrent_streams.c - test/core/end2end/tests/max_connection_age.c - test/core/end2end/tests/max_connection_idle.c - test/core/end2end/tests/max_message_length.c - test/core/end2end/tests/negative_deadline.c - test/core/end2end/tests/network_status_change.c - test/core/end2end/tests/no_logging.c - test/core/end2end/tests/no_op.c - test/core/end2end/tests/payload.c - test/core/end2end/tests/ping.c - test/core/end2end/tests/ping_pong_streaming.c - test/core/end2end/tests/proxy_auth.c - test/core/end2end/tests/registered_call.c - 
test/core/end2end/tests/request_with_flags.c - test/core/end2end/tests/request_with_payload.c - test/core/end2end/tests/resource_quota_server.c - test/core/end2end/tests/server_finishes_request.c - test/core/end2end/tests/shutdown_finishes_calls.c - test/core/end2end/tests/shutdown_finishes_tags.c - test/core/end2end/tests/simple_cacheable_request.c - test/core/end2end/tests/simple_delayed_request.c - test/core/end2end/tests/simple_metadata.c - test/core/end2end/tests/simple_request.c - test/core/end2end/tests/streaming_error_response.c - test/core/end2end/tests/trailing_metadata.c - test/core/end2end/tests/workaround_cronet_compression.c - test/core/end2end/tests/write_buffering.c - test/core/end2end/tests/write_buffering_at_end.c -) - -if(WIN32 AND MSVC) - set_target_properties(end2end_nosec_tests PROPERTIES COMPILE_PDB_NAME "end2end_nosec_tests" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/end2end_nosec_tests.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(end2end_nosec_tests - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(end2end_nosec_tests - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) - -if (gRPC_BUILD_TESTS) - -add_executable(alarm_test - test/core/surface/alarm_test.c -) - - -target_include_directories(alarm_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - 
PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(alarm_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(algorithm_test - test/core/compression/algorithm_test.c -) - - -target_include_directories(algorithm_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(algorithm_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(alloc_test - test/core/support/alloc_test.c -) - - -target_include_directories(alloc_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(alloc_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(alpn_test - test/core/transport/chttp2/alpn_test.c -) - - -target_include_directories(alpn_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(alpn_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(arena_test - test/core/support/arena_test.c -) - - -target_include_directories(arena_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(arena_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(bad_server_response_test - test/core/end2end/bad_server_response_test.c -) - - -target_include_directories(bad_server_response_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_server_response_test - ${_gRPC_ALLTARGETS_LIBRARIES} - test_tcp_server - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(bdp_estimator_test - test/core/transport/bdp_estimator_test.c -) - - -target_include_directories(bdp_estimator_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bdp_estimator_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(bin_decoder_test - test/core/transport/chttp2/bin_decoder_test.c -) - - -target_include_directories(bin_decoder_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE 
${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bin_decoder_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(bin_encoder_test - test/core/transport/chttp2/bin_encoder_test.c -) - - -target_include_directories(bin_encoder_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bin_encoder_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(census_context_test - test/core/census/context_test.c -) - - -target_include_directories(census_context_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(census_context_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(census_intrusive_hash_map_test - 
test/core/census/intrusive_hash_map_test.c -) - - -target_include_directories(census_intrusive_hash_map_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(census_intrusive_hash_map_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(census_resource_test - test/core/census/resource_test.c -) - - -target_include_directories(census_resource_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(census_resource_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(census_trace_context_test - test/core/census/trace_context_test.c -) - - -target_include_directories(census_trace_context_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE 
${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(census_trace_context_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(channel_create_test - test/core/surface/channel_create_test.c -) - - -target_include_directories(channel_create_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(channel_create_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - -add_executable(check_epollexclusive - test/build/check_epollexclusive.c -) - - -target_include_directories(check_epollexclusive - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(check_epollexclusive - 
${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - -if (gRPC_INSTALL) - install(TARGETS check_epollexclusive EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_executable(chttp2_hpack_encoder_test - test/core/transport/chttp2/hpack_encoder_test.c -) - - -target_include_directories(chttp2_hpack_encoder_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(chttp2_hpack_encoder_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(chttp2_stream_map_test - test/core/transport/chttp2/stream_map_test.c -) - - -target_include_directories(chttp2_stream_map_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(chttp2_stream_map_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if 
(gRPC_BUILD_TESTS) - -add_executable(chttp2_varint_test - test/core/transport/chttp2/varint_test.c -) - - -target_include_directories(chttp2_varint_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(chttp2_varint_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(combiner_test - test/core/iomgr/combiner_test.c -) - - -target_include_directories(combiner_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(combiner_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(compression_test - test/core/compression/compression_test.c -) - - -target_include_directories(compression_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE 
${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(compression_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(concurrent_connectivity_test - test/core/surface/concurrent_connectivity_test.c -) - - -target_include_directories(concurrent_connectivity_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(concurrent_connectivity_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(connection_refused_test - test/core/end2end/connection_refused_test.c -) - - -target_include_directories(connection_refused_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(connection_refused_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(dns_resolver_connectivity_test - test/core/client_channel/resolvers/dns_resolver_connectivity_test.c -) - - -target_include_directories(dns_resolver_connectivity_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(dns_resolver_connectivity_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(dns_resolver_test - test/core/client_channel/resolvers/dns_resolver_test.c -) - - -target_include_directories(dns_resolver_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(dns_resolver_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR 
_gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(dualstack_socket_test - test/core/end2end/dualstack_socket_test.c -) - - -target_include_directories(dualstack_socket_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(dualstack_socket_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(endpoint_pair_test - test/core/iomgr/endpoint_pair_test.c -) - - -target_include_directories(endpoint_pair_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(endpoint_pair_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(error_test - test/core/iomgr/error_test.c -) - - -target_include_directories(error_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE 
${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(error_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(ev_epollsig_linux_test - test/core/iomgr/ev_epollsig_linux_test.c -) - - -target_include_directories(ev_epollsig_linux_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(ev_epollsig_linux_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(fake_resolver_test - test/core/client_channel/resolvers/fake_resolver_test.c -) - - -target_include_directories(fake_resolver_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - 
PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fake_resolver_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(fd_conservation_posix_test - test/core/iomgr/fd_conservation_posix_test.c -) - - -target_include_directories(fd_conservation_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fd_conservation_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(fd_posix_test - test/core/iomgr/fd_posix_test.c -) - - -target_include_directories(fd_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fd_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - 
gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(fling_client - test/core/fling/client.c -) - - -target_include_directories(fling_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fling_client - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(fling_server - test/core/fling/server.c -) - - -target_include_directories(fling_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fling_server - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(fling_stream_test - test/core/fling/fling_stream_test.c -) - - -target_include_directories(fling_stream_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE 
${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fling_stream_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(fling_test - test/core/fling/fling_test.c -) - - -target_include_directories(fling_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fling_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) - -add_executable(gen_hpack_tables - tools/codegen/core/gen_hpack_tables.c -) - - -target_include_directories(gen_hpack_tables - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - 
PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gen_hpack_tables - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr - grpc -) - - -if (gRPC_INSTALL) - install(TARGETS gen_hpack_tables EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(gen_legal_metadata_characters - tools/codegen/core/gen_legal_metadata_characters.c -) - - -target_include_directories(gen_legal_metadata_characters - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gen_legal_metadata_characters - ${_gRPC_ALLTARGETS_LIBRARIES} -) - - -if (gRPC_INSTALL) - install(TARGETS gen_legal_metadata_characters EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(gen_percent_encoding_tables - tools/codegen/core/gen_percent_encoding_tables.c -) - - -target_include_directories(gen_percent_encoding_tables - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gen_percent_encoding_tables - ${_gRPC_ALLTARGETS_LIBRARIES} -) - - -if (gRPC_INSTALL) - install(TARGETS gen_percent_encoding_tables EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(goaway_server_test - test/core/end2end/goaway_server_test.c -) - - -target_include_directories(goaway_server_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(goaway_server_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_avl_test - test/core/support/avl_test.c -) - - -target_include_directories(gpr_avl_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_avl_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_backoff_test - test/core/support/backoff_test.c -) - - -target_include_directories(gpr_backoff_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_backoff_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_cmdline_test - test/core/support/cmdline_test.c -) - - -target_include_directories(gpr_cmdline_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_cmdline_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_cpu_test - test/core/support/cpu_test.c -) - - -target_include_directories(gpr_cpu_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE 
${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_cpu_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_env_test - test/core/support/env_test.c -) - - -target_include_directories(gpr_env_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_env_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_histogram_test - test/core/support/histogram_test.c -) - - -target_include_directories(gpr_histogram_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - 
-target_link_libraries(gpr_histogram_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_host_port_test - test/core/support/host_port_test.c -) - - -target_include_directories(gpr_host_port_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_host_port_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_log_test - test/core/support/log_test.c -) - - -target_include_directories(gpr_log_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_log_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_mpscq_test - test/core/support/mpscq_test.c -) - - -target_include_directories(gpr_mpscq_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src 
- PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_mpscq_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_spinlock_test - test/core/support/spinlock_test.c -) - - -target_include_directories(gpr_spinlock_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_spinlock_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_stack_lockfree_test - test/core/support/stack_lockfree_test.c -) - - -target_include_directories(gpr_stack_lockfree_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - 
-target_link_libraries(gpr_stack_lockfree_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_string_test - test/core/support/string_test.c -) - - -target_include_directories(gpr_string_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_string_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_sync_test - test/core/support/sync_test.c -) - - -target_include_directories(gpr_sync_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_sync_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_thd_test - test/core/support/thd_test.c -) - - -target_include_directories(gpr_thd_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE 
${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_thd_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_time_test - test/core/support/time_test.c -) - - -target_include_directories(gpr_time_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_time_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_tls_test - test/core/support/tls_test.c -) - - -target_include_directories(gpr_tls_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_tls_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) 
- -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_useful_test - test/core/support/useful_test.c -) - - -target_include_directories(gpr_useful_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_useful_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_auth_context_test - test/core/security/auth_context_test.c -) - - -target_include_directories(grpc_auth_context_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_auth_context_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_b64_test - test/core/slice/b64_test.c -) - - -target_include_directories(grpc_b64_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE 
${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_b64_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_byte_buffer_reader_test - test/core/surface/byte_buffer_reader_test.c -) - - -target_include_directories(grpc_byte_buffer_reader_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_byte_buffer_reader_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_channel_args_test - test/core/channel/channel_args_test.c -) - - -target_include_directories(grpc_channel_args_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include 
-) - -target_link_libraries(grpc_channel_args_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_channel_stack_test - test/core/channel/channel_stack_test.c -) - - -target_include_directories(grpc_channel_stack_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_channel_stack_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_completion_queue_test - test/core/surface/completion_queue_test.c -) - - -target_include_directories(grpc_completion_queue_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_completion_queue_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_completion_queue_threading_test - test/core/surface/completion_queue_threading_test.c -) - - 
-target_include_directories(grpc_completion_queue_threading_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_completion_queue_threading_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - -add_executable(grpc_create_jwt - test/core/security/create_jwt.c -) - - -target_include_directories(grpc_create_jwt - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_create_jwt - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_create_jwt EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_executable(grpc_credentials_test - test/core/security/credentials_test.c -) - - -target_include_directories(grpc_credentials_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE 
${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_credentials_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_fetch_oauth2 - test/core/security/fetch_oauth2.c -) - - -target_include_directories(grpc_fetch_oauth2 - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_fetch_oauth2 - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_invalid_channel_args_test - test/core/surface/invalid_channel_args_test.c -) - - -target_include_directories(grpc_invalid_channel_args_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_invalid_channel_args_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(grpc_json_token_test - test/core/security/json_token_test.c -) - - -target_include_directories(grpc_json_token_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_json_token_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_jwt_verifier_test - test/core/security/jwt_verifier_test.c -) - - -target_include_directories(grpc_jwt_verifier_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_jwt_verifier_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util 
- gpr -) - -endif (gRPC_BUILD_TESTS) - -add_executable(grpc_print_google_default_creds_token - test/core/security/print_google_default_creds_token.c -) - - -target_include_directories(grpc_print_google_default_creds_token - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_print_google_default_creds_token - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_print_google_default_creds_token EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_executable(grpc_security_connector_test - test/core/security/security_connector_test.c -) - - -target_include_directories(grpc_security_connector_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_security_connector_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - -add_executable(grpc_verify_jwt - 
test/core/security/verify_jwt.c -) - - -target_include_directories(grpc_verify_jwt - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_verify_jwt - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_verify_jwt EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(handshake_client - test/core/handshake/client_ssl.c -) - - -target_include_directories(handshake_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(handshake_client - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(handshake_server - test/core/handshake/server_ssl.c -) - - -target_include_directories(handshake_server - PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(handshake_server - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(hpack_parser_test - test/core/transport/chttp2/hpack_parser_test.c -) - - -target_include_directories(hpack_parser_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(hpack_parser_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(hpack_table_test - test/core/transport/chttp2/hpack_table_test.c -) - - -target_include_directories(hpack_table_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE 
${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(hpack_table_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(http_parser_test - test/core/http/parser_test.c -) - - -target_include_directories(http_parser_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(http_parser_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(httpcli_format_request_test - test/core/http/format_request_test.c -) - - -target_include_directories(httpcli_format_request_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(httpcli_format_request_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - 
gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(httpcli_test - test/core/http/httpcli_test.c -) - - -target_include_directories(httpcli_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(httpcli_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(httpscli_test - test/core/http/httpscli_test.c -) - - -target_include_directories(httpscli_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(httpscli_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(init_test - test/core/surface/init_test.c -) - - -target_include_directories(init_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE 
${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(init_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(invalid_call_argument_test - test/core/end2end/invalid_call_argument_test.c -) - - -target_include_directories(invalid_call_argument_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(invalid_call_argument_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_rewrite - test/core/json/json_rewrite.c -) - - -target_include_directories(json_rewrite - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_rewrite - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_rewrite_test - test/core/json/json_rewrite_test.c -) - - -target_include_directories(json_rewrite_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_rewrite_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_stream_error_test - test/core/json/json_stream_error_test.c -) - - -target_include_directories(json_stream_error_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_stream_error_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_test - test/core/json/json_test.c -) - - 
-target_include_directories(json_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(lame_client_test - test/core/surface/lame_client_test.c -) - - -target_include_directories(lame_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(lame_client_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(lb_policies_test - test/core/client_channel/lb_policies_test.c -) - - -target_include_directories(lb_policies_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - 
PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(lb_policies_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(load_file_test - test/core/iomgr/load_file_test.c -) - - -target_include_directories(load_file_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(load_file_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(memory_profile_client - test/core/memory_usage/client.c -) - - -target_include_directories(memory_profile_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(memory_profile_client - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if 
(gRPC_BUILD_TESTS) - -add_executable(memory_profile_server - test/core/memory_usage/server.c -) - - -target_include_directories(memory_profile_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(memory_profile_server - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(memory_profile_test - test/core/memory_usage/memory_usage_test.c -) - - -target_include_directories(memory_profile_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(memory_profile_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(message_compress_test - test/core/compression/message_compress_test.c -) - - -target_include_directories(message_compress_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - 
PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(message_compress_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(minimal_stack_is_minimal_test - test/core/channel/minimal_stack_is_minimal_test.c -) - - -target_include_directories(minimal_stack_is_minimal_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(minimal_stack_is_minimal_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(mlog_test - test/core/census/mlog_test.c -) - - -target_include_directories(mlog_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(mlog_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(multiple_server_queues_test - test/core/end2end/multiple_server_queues_test.c -) - - -target_include_directories(multiple_server_queues_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(multiple_server_queues_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(murmur_hash_test - test/core/support/murmur_hash_test.c -) - - -target_include_directories(murmur_hash_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(murmur_hash_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(no_server_test - 
test/core/end2end/no_server_test.c -) - - -target_include_directories(no_server_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(no_server_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(num_external_connectivity_watchers_test - test/core/surface/num_external_connectivity_watchers_test.c -) - - -target_include_directories(num_external_connectivity_watchers_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(num_external_connectivity_watchers_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(parse_address_test - test/core/client_channel/parse_address_test.c -) - - -target_include_directories(parse_address_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE 
${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(parse_address_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(percent_encoding_test - test/core/slice/percent_encoding_test.c -) - - -target_include_directories(percent_encoding_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(percent_encoding_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(pollset_set_test - test/core/iomgr/pollset_set_test.c -) - - -target_include_directories(pollset_set_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(pollset_set_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(resolve_address_posix_test - test/core/iomgr/resolve_address_posix_test.c -) - - -target_include_directories(resolve_address_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(resolve_address_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(resolve_address_test - test/core/iomgr/resolve_address_test.c -) - - -target_include_directories(resolve_address_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(resolve_address_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if 
(gRPC_BUILD_TESTS) - -add_executable(resource_quota_test - test/core/iomgr/resource_quota_test.c -) - - -target_include_directories(resource_quota_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(resource_quota_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(secure_channel_create_test - test/core/surface/secure_channel_create_test.c -) - - -target_include_directories(secure_channel_create_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(secure_channel_create_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(secure_endpoint_test - test/core/security/secure_endpoint_test.c -) - - -target_include_directories(secure_endpoint_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE 
${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(secure_endpoint_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(sequential_connectivity_test - test/core/surface/sequential_connectivity_test.c -) - - -target_include_directories(sequential_connectivity_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(sequential_connectivity_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_chttp2_test - test/core/surface/server_chttp2_test.c -) - - -target_include_directories(server_chttp2_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(server_chttp2_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_test - test/core/surface/server_test.c -) - - -target_include_directories(server_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(server_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(slice_buffer_test - test/core/slice/slice_buffer_test.c -) - - -target_include_directories(slice_buffer_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(slice_buffer_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(slice_hash_table_test - test/core/slice/slice_hash_table_test.c 
-) - - -target_include_directories(slice_hash_table_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(slice_hash_table_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(slice_string_helpers_test - test/core/slice/slice_string_helpers_test.c -) - - -target_include_directories(slice_string_helpers_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(slice_string_helpers_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(slice_test - test/core/slice/slice_test.c -) - - -target_include_directories(slice_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - 
PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(slice_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(sockaddr_resolver_test - test/core/client_channel/resolvers/sockaddr_resolver_test.c -) - - -target_include_directories(sockaddr_resolver_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(sockaddr_resolver_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(sockaddr_utils_test - test/core/iomgr/sockaddr_utils_test.c -) - - -target_include_directories(sockaddr_utils_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(sockaddr_utils_test - ${_gRPC_ALLTARGETS_LIBRARIES} - 
grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(socket_utils_test - test/core/iomgr/socket_utils_test.c -) - - -target_include_directories(socket_utils_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(socket_utils_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(status_conversion_test - test/core/transport/status_conversion_test.c -) - - -target_include_directories(status_conversion_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(status_conversion_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(stream_compression_test - test/core/compression/stream_compression_test.c -) - - -target_include_directories(stream_compression_test - 
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(stream_compression_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(stream_owned_slice_test - test/core/transport/stream_owned_slice_test.c -) - - -target_include_directories(stream_owned_slice_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(stream_owned_slice_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(tcp_client_posix_test - test/core/iomgr/tcp_client_posix_test.c -) - - -target_include_directories(tcp_client_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_client_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(tcp_client_uv_test - test/core/iomgr/tcp_client_uv_test.c -) - - -target_include_directories(tcp_client_uv_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_client_uv_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(tcp_posix_test - test/core/iomgr/tcp_posix_test.c -) - - -target_include_directories(tcp_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(tcp_server_posix_test - test/core/iomgr/tcp_server_posix_test.c -) - - -target_include_directories(tcp_server_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_server_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(tcp_server_uv_test - test/core/iomgr/tcp_server_uv_test.c -) - - -target_include_directories(tcp_server_uv_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_server_uv_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - 
-add_executable(time_averaged_stats_test - test/core/iomgr/time_averaged_stats_test.c -) - - -target_include_directories(time_averaged_stats_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(time_averaged_stats_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(timeout_encoding_test - test/core/transport/timeout_encoding_test.c -) - - -target_include_directories(timeout_encoding_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(timeout_encoding_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(timer_heap_test - test/core/iomgr/timer_heap_test.c -) - - -target_include_directories(timer_heap_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE 
${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(timer_heap_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(timer_list_test - test/core/iomgr/timer_list_test.c -) - - -target_include_directories(timer_list_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(timer_list_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(transport_connectivity_state_test - test/core/transport/connectivity_state_test.c -) - - -target_include_directories(transport_connectivity_state_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(transport_connectivity_state_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(transport_metadata_test - test/core/transport/metadata_test.c -) - - -target_include_directories(transport_metadata_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(transport_metadata_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(transport_pid_controller_test - test/core/transport/pid_controller_test.c -) - - -target_include_directories(transport_pid_controller_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(transport_pid_controller_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC 
OR _gRPC_PLATFORM_POSIX) - -add_executable(transport_security_test - test/core/tsi/transport_security_test.c -) - - -target_include_directories(transport_security_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(transport_security_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(udp_server_test - test/core/iomgr/udp_server_test.c -) - - -target_include_directories(udp_server_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(udp_server_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(uri_parser_test - test/core/client_channel/uri_parser_test.c -) - - -target_include_directories(uri_parser_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE 
${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(uri_parser_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(wakeup_fd_cv_test - test/core/iomgr/wakeup_fd_cv_test.c -) - - -target_include_directories(wakeup_fd_cv_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(wakeup_fd_cv_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(alarm_cpp_test - test/cpp/common/alarm_cpp_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(alarm_cpp_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(alarm_cpp_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(async_end2end_test - test/cpp/end2end/async_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(async_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(async_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if 
(gRPC_BUILD_TESTS) - -add_executable(auth_property_iterator_test - test/cpp/common/auth_property_iterator_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(auth_property_iterator_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(auth_property_iterator_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_arena - test/cpp/microbenchmarks/bm_arena.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_arena - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_arena - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_call_create - test/cpp/microbenchmarks/bm_call_create.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_call_create - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_call_create - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) 
-if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_chttp2_hpack - test/cpp/microbenchmarks/bm_chttp2_hpack.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_chttp2_hpack - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_chttp2_hpack - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_chttp2_transport - test/cpp/microbenchmarks/bm_chttp2_transport.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_chttp2_transport - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE 
${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_chttp2_transport - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_closure - test/cpp/microbenchmarks/bm_closure.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_closure - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_closure - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() 
-endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_cq - test/cpp/microbenchmarks/bm_cq.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_cq - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_cq - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_cq_multiple_threads - test/cpp/microbenchmarks/bm_cq_multiple_threads.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_cq_multiple_threads - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE 
${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_cq_multiple_threads - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_error - test/cpp/microbenchmarks/bm_error.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_error - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_error - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - 
${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_fullstack_streaming_ping_pong - test/cpp/microbenchmarks/bm_fullstack_streaming_ping_pong.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_fullstack_streaming_ping_pong - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_fullstack_streaming_ping_pong - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_fullstack_streaming_pump - test/cpp/microbenchmarks/bm_fullstack_streaming_pump.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_fullstack_streaming_pump - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE 
${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_fullstack_streaming_pump - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_fullstack_trickle - test/cpp/microbenchmarks/bm_fullstack_trickle.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_fullstack_trickle - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - 
-target_link_libraries(bm_fullstack_trickle - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_fullstack_unary_ping_pong - test/cpp/microbenchmarks/bm_fullstack_unary_ping_pong.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_fullstack_unary_ping_pong - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_fullstack_unary_ping_pong - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_metadata - test/cpp/microbenchmarks/bm_metadata.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_metadata 
- PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_metadata - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_pollset - test/cpp/microbenchmarks/bm_pollset.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_pollset - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - 
PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_pollset - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(channel_arguments_test - test/cpp/common/channel_arguments_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(channel_arguments_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(channel_arguments_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(channel_filter_test - test/cpp/common/channel_filter_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(channel_filter_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE 
${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(channel_filter_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cli_call_test - test/cpp/util/cli_call_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cli_call_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cli_call_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_cli_libs - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif 
(gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(client_crash_test - test/cpp/end2end/client_crash_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(client_crash_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(client_crash_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(client_crash_test_server - test/cpp/end2end/client_crash_test_server.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(client_crash_test_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE 
${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(client_crash_test_server - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(client_lb_end2end_test - test/cpp/end2end/client_lb_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(client_lb_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(client_lb_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(codegen_test_full - 
${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.h - test/cpp/codegen/codegen_test_full.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/control.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/payloads.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/services.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/stats.proto -) - -target_include_directories(codegen_test_full - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE 
${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(codegen_test_full - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(codegen_test_minimal - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.cc - 
${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.h - test/cpp/codegen/codegen_test_minimal.cc - src/cpp/codegen/codegen_init.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/control.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/payloads.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/services.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/stats.proto -) - -target_include_directories(codegen_test_minimal - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(codegen_test_minimal - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(credentials_test - test/cpp/client/credentials_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(credentials_test - PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(credentials_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cxx_byte_buffer_test - test/cpp/util/byte_buffer_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cxx_byte_buffer_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cxx_byte_buffer_test - 
${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cxx_slice_test - test/cpp/util/slice_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cxx_slice_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cxx_slice_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cxx_string_ref_test - test/cpp/util/string_ref_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cxx_string_ref_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE 
${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cxx_string_ref_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cxx_time_test - test/cpp/util/time_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cxx_time_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cxx_time_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(end2end_test - test/cpp/end2end/end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - 
-target_include_directories(end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(error_details_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - test/cpp/util/error_details_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) - -target_include_directories(error_details_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE 
${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(error_details_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_error_details - grpc++ - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(filter_end2end_test - test/cpp/end2end/filter_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(filter_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(filter_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(generic_end2end_test - test/cpp/end2end/generic_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - 
third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(generic_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(generic_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(golden_file_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/compiler_test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/compiler_test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/compiler_test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/compiler_test.grpc.pb.h - test/cpp/codegen/golden_file_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/compiler_test.proto -) - -target_include_directories(golden_file_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE 
${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(golden_file_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_cli - test/cpp/util/grpc_cli.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(grpc_cli - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_cli - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_cli_libs - grpc++_proto_reflection_desc_db - grpc++ - grpc - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) - -add_executable(grpc_cpp_plugin - src/compiler/cpp_plugin.cc -) - - -target_include_directories(grpc_cpp_plugin - PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_cpp_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_cpp_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_csharp_plugin - src/compiler/csharp_plugin.cc -) - - -target_include_directories(grpc_csharp_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_csharp_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_csharp_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) 
-endif() - - -add_executable(grpc_node_plugin - src/compiler/node_plugin.cc -) - - -target_include_directories(grpc_node_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_node_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_node_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_objective_c_plugin - src/compiler/objective_c_plugin.cc -) - - -target_include_directories(grpc_objective_c_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_objective_c_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS 
grpc_objective_c_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_php_plugin - src/compiler/php_plugin.cc -) - - -target_include_directories(grpc_php_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_php_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_php_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_python_plugin - src/compiler/python_plugin.cc -) - - -target_include_directories(grpc_python_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_python_plugin - 
${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_python_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_ruby_plugin - src/compiler/ruby_plugin.cc -) - - -target_include_directories(grpc_ruby_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_ruby_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_ruby_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_executable(grpc_tool_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - 
${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - test/cpp/util/grpc_tool_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) - -target_include_directories(grpc_tool_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_tool_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_cli_libs - grpc++_proto_reflection_desc_db - grpc++_reflection - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpclb_api_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.h - test/cpp/grpclb/grpclb_api_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - 
src/proto/grpc/lb/v1/load_balancer.proto -) - -target_include_directories(grpclb_api_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpclb_api_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpclb_end2end_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.h - test/cpp/end2end/grpclb_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/lb/v1/load_balancer.proto -) - -target_include_directories(grpclb_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE 
${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpclb_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpclb_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.h - test/cpp/grpclb/grpclb_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/lb/v1/load_balancer.proto -) - -target_include_directories(grpclb_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - 
-target_link_libraries(grpclb_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(health_service_end2end_test - test/cpp/end2end/health_service_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(health_service_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(health_service_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(http2_client - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(http2_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE 
${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(http2_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - http2_client_main - grpc++_test_util - grpc_test_util - grpc++ - grpc - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(hybrid_end2end_test - test/cpp/end2end/hybrid_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(hybrid_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(hybrid_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - 
${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(interop_client - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(interop_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_client_main - interop_client_helper - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(interop_server - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(interop_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} 
- PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_server - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_server_main - interop_server_helper - interop_server_lib - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(interop_test - test/cpp/interop/interop_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(interop_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr - grpc++_test_config - 
${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(json_run_localhost - test/cpp/qps/json_run_localhost.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(json_run_localhost - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(json_run_localhost - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(memory_test - test/core/support/memory_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(memory_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} 
- PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(memory_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(metrics_client - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.grpc.pb.h - test/cpp/interop/metrics_client.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/metrics.proto -) - -target_include_directories(metrics_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(metrics_client - 
${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(mock_test - test/cpp/end2end/mock_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(mock_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(mock_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(noop-benchmark - test/cpp/microbenchmarks/noop-benchmark.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(noop-benchmark - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE 
${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(noop-benchmark - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - benchmark - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(proto_server_reflection_test - test/cpp/end2end/proto_server_reflection_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(proto_server_reflection_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(proto_server_reflection_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_proto_reflection_desc_db - grpc++_reflection - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(proto_utils_test - test/cpp/codegen/proto_utils_test.cc - 
third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(proto_utils_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(proto_utils_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(qps_interarrival_test - test/cpp/qps/qps_interarrival_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(qps_interarrival_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE 
third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps_interarrival_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(qps_json_driver - test/cpp/qps/qps_json_driver.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(qps_json_driver - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps_json_driver - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(qps_openloop_test - test/cpp/qps/qps_openloop_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - 
-target_include_directories(qps_openloop_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps_openloop_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(qps_worker - test/cpp/qps/worker.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(qps_worker - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE 
third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps_worker - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(reconnect_interop_client - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/reconnect_interop_client.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(reconnect_interop_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(reconnect_interop_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(reconnect_interop_server - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/reconnect_interop_server.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(reconnect_interop_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE 
${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(reconnect_interop_server - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - reconnect_server - test_tcp_server - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(secure_auth_context_test - test/cpp/common/secure_auth_context_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(secure_auth_context_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(secure_auth_context_test - 
${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(secure_sync_unary_ping_pong_test - test/cpp/qps/secure_sync_unary_ping_pong_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(secure_sync_unary_ping_pong_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(secure_sync_unary_ping_pong_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_builder_plugin_test - test/cpp/end2end/server_builder_plugin_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(server_builder_plugin_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE 
${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_builder_plugin_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_builder_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.h - test/cpp/server/server_builder_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo.proto -) - -target_include_directories(server_builder_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - 
PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_builder_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - gpr_test_util - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_context_test_spouse_test - test/cpp/test/server_context_test_spouse_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(server_context_test_spouse_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_context_test_spouse_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - 
grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(server_crash_test - test/cpp/end2end/server_crash_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(server_crash_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_crash_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_crash_test_client - test/cpp/end2end/server_crash_test_client.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(server_crash_test_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - 
PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_crash_test_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_request_call_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.h - test/cpp/server/server_request_call_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo.proto -) - -target_include_directories(server_request_call_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE 
${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_request_call_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - gpr_test_util - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(shutdown_test - test/cpp/end2end/shutdown_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(shutdown_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(shutdown_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(status_test - test/cpp/util/status_test.cc - 
third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(status_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(status_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(streaming_throughput_test - test/cpp/end2end/streaming_throughput_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(streaming_throughput_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE 
third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(streaming_throughput_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(stress_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/interop_client.cc - test/cpp/interop/stress_interop_client.cc - test/cpp/interop/stress_test.cc - test/cpp/util/metrics_server.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - 
src/proto/grpc/testing/metrics.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(stress_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(stress_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(thread_manager_test - test/cpp/thread_manager/thread_manager_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(thread_manager_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE 
third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(thread_manager_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(thread_stress_test - test/cpp/end2end/thread_stress_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(thread_stress_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(thread_stress_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(writes_per_rpc_test - test/cpp/performance/writes_per_rpc_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(writes_per_rpc_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} 
- PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(writes_per_rpc_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(public_headers_must_be_c89 - test/core/surface/public_headers_must_be_c89.c -) - - -target_include_directories(public_headers_must_be_c89 - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(public_headers_must_be_c89 - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(badreq_bad_client_test - test/core/bad_client/tests/badreq.c -) - - -target_include_directories(badreq_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(badreq_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(connection_prefix_bad_client_test - test/core/bad_client/tests/connection_prefix.c -) - - -target_include_directories(connection_prefix_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(connection_prefix_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(head_of_line_blocking_bad_client_test - test/core/bad_client/tests/head_of_line_blocking.c -) - - -target_include_directories(head_of_line_blocking_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - 
PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(head_of_line_blocking_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(headers_bad_client_test - test/core/bad_client/tests/headers.c -) - - -target_include_directories(headers_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(headers_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(initial_settings_frame_bad_client_test - test/core/bad_client/tests/initial_settings_frame.c -) - - -target_include_directories(initial_settings_frame_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE 
${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(initial_settings_frame_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(large_metadata_bad_client_test - test/core/bad_client/tests/large_metadata.c -) - - -target_include_directories(large_metadata_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(large_metadata_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_registered_method_bad_client_test - test/core/bad_client/tests/server_registered_method.c -) - - -target_include_directories(server_registered_method_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - 
PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(server_registered_method_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(simple_request_bad_client_test - test/core/bad_client/tests/simple_request.c -) - - -target_include_directories(simple_request_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(simple_request_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(unknown_frame_bad_client_test - test/core/bad_client/tests/unknown_frame.c -) - - -target_include_directories(unknown_frame_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - 
-target_link_libraries(unknown_frame_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(window_overflow_bad_client_test - test/core/bad_client/tests/window_overflow.c -) - - -target_include_directories(window_overflow_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(window_overflow_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bad_ssl_cert_server - test/core/bad_ssl/servers/cert.c -) - - -target_include_directories(bad_ssl_cert_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_ssl_cert_server - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_ssl_test_server - 
grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bad_ssl_cert_test - test/core/bad_ssl/bad_ssl_test.c -) - - -target_include_directories(bad_ssl_cert_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_ssl_cert_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_census_test - test/core/end2end/fixtures/h2_census.c -) - - -target_include_directories(h2_census_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_census_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_compress_test - test/core/end2end/fixtures/h2_compress.c -) - - -target_include_directories(h2_compress_test - PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_compress_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_fakesec_test - test/core/end2end/fixtures/h2_fakesec.c -) - - -target_include_directories(h2_fakesec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_fakesec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(h2_fd_test - test/core/end2end/fixtures/h2_fd.c -) - - -target_include_directories(h2_fd_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib 
- PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_fd_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full_test - test/core/end2end/fixtures/h2_full.c -) - - -target_include_directories(h2_full_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(h2_full+pipe_test - test/core/end2end/fixtures/h2_full+pipe.c -) - - -target_include_directories(h2_full+pipe_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+pipe_test - 
${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full+trace_test - test/core/end2end/fixtures/h2_full+trace.c -) - - -target_include_directories(h2_full+trace_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+trace_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full+workarounds_test - test/core/end2end/fixtures/h2_full+workarounds.c -) - - -target_include_directories(h2_full+workarounds_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+workarounds_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_http_proxy_test - test/core/end2end/fixtures/h2_http_proxy.c -) - - 
-target_include_directories(h2_http_proxy_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_http_proxy_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_load_reporting_test - test/core/end2end/fixtures/h2_load_reporting.c -) - - -target_include_directories(h2_load_reporting_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_load_reporting_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_oauth2_test - test/core/end2end/fixtures/h2_oauth2.c -) - - -target_include_directories(h2_oauth2_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE 
${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_oauth2_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_proxy_test - test/core/end2end/fixtures/h2_proxy.c -) - - -target_include_directories(h2_proxy_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_proxy_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair_test - test/core/end2end/fixtures/h2_sockpair.c -) - - -target_include_directories(h2_sockpair_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair_test - 
${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair+trace_test - test/core/end2end/fixtures/h2_sockpair+trace.c -) - - -target_include_directories(h2_sockpair+trace_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair+trace_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair_1byte_test - test/core/end2end/fixtures/h2_sockpair_1byte.c -) - - -target_include_directories(h2_sockpair_1byte_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair_1byte_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_ssl_test - test/core/end2end/fixtures/h2_ssl.c -) - - -target_include_directories(h2_ssl_test - PRIVATE 
${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_ssl_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_ssl_cert_test - test/core/end2end/fixtures/h2_ssl_cert.c -) - - -target_include_directories(h2_ssl_cert_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_ssl_cert_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_ssl_proxy_test - test/core/end2end/fixtures/h2_ssl_proxy.c -) - - -target_include_directories(h2_ssl_proxy_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE 
${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_ssl_proxy_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(h2_uds_test - test/core/end2end/fixtures/h2_uds.c -) - - -target_include_directories(h2_uds_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_uds_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(inproc_test - test/core/end2end/fixtures/inproc.c -) - - -target_include_directories(inproc_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(inproc_test - ${_gRPC_ALLTARGETS_LIBRARIES} - 
end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_census_nosec_test - test/core/end2end/fixtures/h2_census.c -) - - -target_include_directories(h2_census_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_census_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_compress_nosec_test - test/core/end2end/fixtures/h2_compress.c -) - - -target_include_directories(h2_compress_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_compress_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(h2_fd_nosec_test - 
test/core/end2end/fixtures/h2_fd.c -) - - -target_include_directories(h2_fd_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_fd_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full_nosec_test - test/core/end2end/fixtures/h2_full.c -) - - -target_include_directories(h2_full_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(h2_full+pipe_nosec_test - test/core/end2end/fixtures/h2_full+pipe.c -) - - -target_include_directories(h2_full+pipe_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE 
${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+pipe_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full+trace_nosec_test - test/core/end2end/fixtures/h2_full+trace.c -) - - -target_include_directories(h2_full+trace_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+trace_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full+workarounds_nosec_test - test/core/end2end/fixtures/h2_full+workarounds.c -) - - -target_include_directories(h2_full+workarounds_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE 
${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+workarounds_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_http_proxy_nosec_test - test/core/end2end/fixtures/h2_http_proxy.c -) - - -target_include_directories(h2_http_proxy_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_http_proxy_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_load_reporting_nosec_test - test/core/end2end/fixtures/h2_load_reporting.c -) - - -target_include_directories(h2_load_reporting_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - 
-target_link_libraries(h2_load_reporting_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_proxy_nosec_test - test/core/end2end/fixtures/h2_proxy.c -) - - -target_include_directories(h2_proxy_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_proxy_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair_nosec_test - test/core/end2end/fixtures/h2_sockpair.c -) - - -target_include_directories(h2_sockpair_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - 
-add_executable(h2_sockpair+trace_nosec_test - test/core/end2end/fixtures/h2_sockpair+trace.c -) - - -target_include_directories(h2_sockpair+trace_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair+trace_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair_1byte_nosec_test - test/core/end2end/fixtures/h2_sockpair_1byte.c -) - - -target_include_directories(h2_sockpair_1byte_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair_1byte_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(h2_uds_nosec_test - test/core/end2end/fixtures/h2_uds.c -) - - 
-target_include_directories(h2_uds_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_uds_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(inproc_nosec_test - test/core/end2end/fixtures/inproc.c -) - - -target_include_directories(inproc_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(inproc_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(api_fuzzer_one_entry - test/core/end2end/fuzzers/api_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(api_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE 
${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(api_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(client_fuzzer_one_entry - test/core/end2end/fuzzers/client_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(client_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(client_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(hpack_parser_fuzzer_test_one_entry - test/core/transport/chttp2/hpack_parser_fuzzer_test.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(hpack_parser_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE 
${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(hpack_parser_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(http_request_fuzzer_test_one_entry - test/core/http/request_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(http_request_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(http_request_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(http_response_fuzzer_test_one_entry - test/core/http/response_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(http_response_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - 
-target_link_libraries(http_response_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_fuzzer_test_one_entry - test/core/json/fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(json_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(nanopb_fuzzer_response_test_one_entry - test/core/nanopb/fuzzer_response.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(nanopb_fuzzer_response_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(nanopb_fuzzer_response_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - 
-add_executable(nanopb_fuzzer_serverlist_test_one_entry - test/core/nanopb/fuzzer_serverlist.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(nanopb_fuzzer_serverlist_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(nanopb_fuzzer_serverlist_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(percent_decode_fuzzer_one_entry - test/core/slice/percent_decode_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(percent_decode_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(percent_decode_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(percent_encode_fuzzer_one_entry - test/core/slice/percent_encode_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - 
-target_include_directories(percent_encode_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(percent_encode_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_fuzzer_one_entry - test/core/end2end/fuzzers/server_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(server_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(server_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(ssl_server_fuzzer_one_entry - test/core/security/ssl_server_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(ssl_server_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE 
${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(ssl_server_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(uri_fuzzer_test_one_entry - test/core/client_channel/uri_fuzzer_test.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(uri_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(uri_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - - - - - - - -if (gRPC_INSTALL) - install(EXPORT gRPCTargets - DESTINATION ${gRPC_INSTALL_CMAKEDIR} - NAMESPACE gRPC:: - ) -endif() - -foreach(_config gRPCConfig gRPCConfigVersion) - configure_file(tools/cmake/${_config}.cmake.in - ${_config}.cmake @ONLY) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${_config}.cmake - DESTINATION ${gRPC_INSTALL_CMAKEDIR} - ) -endforeach() diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt index 
fbd89bad079c5d7f6c2909ca643f4c175428e77f..aaae18a313dd082b428654091c9411600c981ec9 100644 --- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt +++ b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt @@ -61,9 +61,15 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "DarwinX") include_directories ("${PROJECT_SOURCE_DIR}/platform/macos") + include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") + # Some versions of MacOS, such as Sierra, require _DARWIN_C_SOURCE + # when including certain C++ standard header files, such as . + add_definitions ("-D_DARWIN_C_SOURCE") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC ${NSYNC_OS_CPP_SRC} + "platform/posix/src/clock_gettime.c" + "platform/posix/src/nsync_semaphore_mutex.c" ) set (NSYNC_TEST_OS_SRC "platform/posix/src/start_thread.c" @@ -138,6 +144,10 @@ if (NOT "${NSYNC_LANGUAGE}X" STREQUAL "c++11X") elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "DarwinX") include_directories ("${PROJECT_SOURCE_DIR}/platform/macos") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/clock_gettime.c" + "platform/posix/src/nsync_semaphore_mutex.c" + ) include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "LinuxX") include_directories ("${PROJECT_SOURCE_DIR}/platform/linux") @@ -148,12 +158,21 @@ if (NOT "${NSYNC_LANGUAGE}X" STREQUAL "c++11X") elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "NetBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/netbsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "FreeBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/freebsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "OpenBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/openbsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + 
"platform/posix/src/nsync_semaphore_mutex.c" + ) endif () endif () diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt new file mode 100644 index 0000000000000000000000000000000000000000..a0fca690ef6bedc5a872498583dfd0cbb55e2143 --- /dev/null +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -0,0 +1,449 @@ +tensorflow +tensorflow/core +tensorflow/core/example +tensorflow/core/framework +tensorflow/core/lib +tensorflow/core/lib/core +tensorflow/core/protobuf +tensorflow/core/util +tensorflow/examples +tensorflow/examples/tutorials +tensorflow/examples/tutorials/mnist +tensorflow/python +tensorflow/python/client +tensorflow/python/data +tensorflow/python/data/ops +tensorflow/python/data/util +tensorflow/python/debug +tensorflow/python/debug/cli +tensorflow/python/debug/examples +tensorflow/python/debug/lib +tensorflow/python/debug/wrappers +tensorflow/python/eager +tensorflow/python/estimator +tensorflow/python/estimator/canned +tensorflow/python/estimator/export +tensorflow/python/estimator/inputs +tensorflow/python/estimator/inputs/queues +tensorflow/python/feature_column +tensorflow/python/framework +tensorflow/python/grappler +tensorflow/python/keras +tensorflow/python/keras/activations +tensorflow/python/keras/applications +tensorflow/python/keras/applications/inception_resnet_v2 +tensorflow/python/keras/applications/inception_v3 +tensorflow/python/keras/applications/mobilenet +tensorflow/python/keras/applications/resnet50 +tensorflow/python/keras/applications/vgg16 +tensorflow/python/keras/applications/vgg19 +tensorflow/python/keras/applications/xception +tensorflow/python/keras/backend +tensorflow/python/keras/callbacks +tensorflow/python/keras/constraints +tensorflow/python/keras/datasets +tensorflow/python/keras/datasets/boston_housing +tensorflow/python/keras/datasets/cifar10 +tensorflow/python/keras/datasets/cifar100 +tensorflow/python/keras/datasets/fashion_mnist +tensorflow/python/keras/datasets/imdb 
+tensorflow/python/keras/datasets/mnist +tensorflow/python/keras/datasets/reuters +tensorflow/python/keras/estimator +tensorflow/python/keras/initializers +tensorflow/python/keras/layers +tensorflow/python/keras/losses +tensorflow/python/keras/metrics +tensorflow/python/keras/models +tensorflow/python/keras/optimizers +tensorflow/python/keras/preprocessing +tensorflow/python/keras/preprocessing/image +tensorflow/python/keras/preprocessing/sequence +tensorflow/python/keras/preprocessing/text +tensorflow/python/keras/regularizers +tensorflow/python/keras/utils +tensorflow/python/keras/wrappers +tensorflow/python/keras/wrappers/scikit_learn +tensorflow/python/keras/_impl +tensorflow/python/keras/_impl/keras +tensorflow/python/keras/_impl/keras/applications +tensorflow/python/keras/_impl/keras/datasets +tensorflow/python/keras/_impl/keras/engine +tensorflow/python/keras/_impl/keras/layers +tensorflow/python/keras/_impl/keras/preprocessing +tensorflow/python/keras/_impl/keras/utils +tensorflow/python/keras/_impl/keras/wrappers +tensorflow/python/kernel_tests +tensorflow/python/kernel_tests/distributions +tensorflow/python/kernel_tests/linalg +tensorflow/python/kernel_tests/random +tensorflow/python/layers +tensorflow/python/lib +tensorflow/python/lib/core +tensorflow/python/lib/io +tensorflow/python/ops +tensorflow/python/ops/distributions +tensorflow/python/ops/linalg +tensorflow/python/ops/losses +tensorflow/python/platform +tensorflow/python/platform/default +tensorflow/python/platform/summary +tensorflow/python/profiler/ +tensorflow/python/profiler/internal +tensorflow/python/saved_model +tensorflow/python/summary +tensorflow/python/summary/writer +tensorflow/python/tools +tensorflow/python/training +tensorflow/python/user_ops +tensorflow/python/util +tensorflow/python/util/protobuf +tensorflow/tools +tensorflow/tools/graph_transforms +tensorflow/contrib +tensorflow/contrib/all_reduce +tensorflow/contrib/all_reduce/python +tensorflow/contrib/android 
+tensorflow/contrib/android/java +tensorflow/contrib/android/java/org +tensorflow/contrib/android/java/org/tensorflow +tensorflow/contrib/android/java/org/tensorflow/contrib +tensorflow/contrib/android/java/org/tensorflow/contrib/android +tensorflow/contrib/android/jni +tensorflow/contrib/batching +tensorflow/contrib/batching/kernels +tensorflow/contrib/batching/python +tensorflow/contrib/batching/python/ops +tensorflow/contrib/bayesflow +tensorflow/contrib/bayesflow/examples +tensorflow/contrib/bayesflow/examples/reinforce_simple +tensorflow/contrib/bayesflow/python +tensorflow/contrib/bayesflow/python/ops +tensorflow/contrib/boosted_trees +tensorflow/contrib/boosted_trees/estimator_batch +tensorflow/contrib/boosted_trees/kernels +tensorflow/contrib/boosted_trees/ops +tensorflow/contrib/boosted_trees/proto +tensorflow/contrib/boosted_trees/python +tensorflow/contrib/boosted_trees/python/ops +tensorflow/contrib/cloud +tensorflow/contrib/cloud/kernels +tensorflow/contrib/cloud/ops +tensorflow/contrib/cloud/python +tensorflow/contrib/cloud/python/ops +tensorflow/contrib/cluster_resolver +tensorflow/contrib/cluster_resolver/python +tensorflow/contrib/cluster_resolver/python/training +tensorflow/contrib/compiler +tensorflow/contrib/copy_graph +tensorflow/contrib/copy_graph/python +tensorflow/contrib/copy_graph/python/util +tensorflow/contrib/crf +tensorflow/contrib/crf/python +tensorflow/contrib/crf/python/ops +tensorflow/contrib/cudnn_rnn +tensorflow/contrib/cudnn_rnn/kernels +tensorflow/contrib/cudnn_rnn/ops +tensorflow/contrib/cudnn_rnn/python +tensorflow/contrib/cudnn_rnn/python/layers +tensorflow/contrib/cudnn_rnn/python/ops +tensorflow/contrib/data +tensorflow/contrib/data/kernels +tensorflow/contrib/data/python +tensorflow/contrib/data/python/kernel_tests +tensorflow/contrib/data/python/ops +tensorflow/contrib/decision_trees +tensorflow/contrib/decision_trees/proto +tensorflow/contrib/deprecated +tensorflow/contrib/distributions 
+tensorflow/contrib/distributions/python +tensorflow/contrib/distributions/python/ops +tensorflow/contrib/distributions/python/ops/bijectors +tensorflow/contrib/eager +tensorflow/contrib/eager/python +tensorflow/contrib/estimator +tensorflow/contrib/estimator/python +tensorflow/contrib/estimator/python/estimator +tensorflow/contrib/factorization +tensorflow/contrib/factorization/examples +tensorflow/contrib/factorization/kernels +tensorflow/contrib/factorization/ops +tensorflow/contrib/factorization/python +tensorflow/contrib/factorization/python/ops +tensorflow/contrib/ffmpeg +tensorflow/contrib/ffmpeg/default +tensorflow/contrib/framework +tensorflow/contrib/framework/kernels +tensorflow/contrib/framework/ops +tensorflow/contrib/framework/python +tensorflow/contrib/framework/python/framework +tensorflow/contrib/framework/python/ops +tensorflow/contrib/fused_conv +tensorflow/contrib/fused_conv/kernels +tensorflow/contrib/fused_conv/python +tensorflow/contrib/fused_conv/python/ops +tensorflow/contrib/gan +tensorflow/contrib/gan/python +tensorflow/contrib/gan/python/estimator +tensorflow/contrib/gan/python/estimator/python +tensorflow/contrib/gan/python/eval +tensorflow/contrib/gan/python/eval/python +tensorflow/contrib/gan/python/features +tensorflow/contrib/gan/python/features/python +tensorflow/contrib/gan/python/losses +tensorflow/contrib/gan/python/losses/python +tensorflow/contrib/graph_editor +tensorflow/contrib/graph_editor/examples +tensorflow/contrib/grid_rnn +tensorflow/contrib/grid_rnn/python +tensorflow/contrib/grid_rnn/python/ops +tensorflow/contrib/hooks +tensorflow/contrib/hooks/python +tensorflow/contrib/image +tensorflow/contrib/image/kernels +tensorflow/contrib/image/ops +tensorflow/contrib/image/python +tensorflow/contrib/image/python/ops +tensorflow/contrib/input_pipeline +tensorflow/contrib/input_pipeline/kernels +tensorflow/contrib/input_pipeline/ops +tensorflow/contrib/input_pipeline/python +tensorflow/contrib/input_pipeline/python/ops 
+tensorflow/contrib/integrate +tensorflow/contrib/integrate/python +tensorflow/contrib/integrate/python/ops +tensorflow/contrib/ios_examples +tensorflow/contrib/ios_examples/benchmark +tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj +tensorflow/contrib/ios_examples/benchmark/data +tensorflow/contrib/ios_examples/camera +tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj +tensorflow/contrib/ios_examples/camera/en.lproj +tensorflow/contrib/ios_examples/simple +tensorflow/contrib/ios_examples/simple/data +tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj +tensorflow/contrib/keras +tensorflow/contrib/keras/api +tensorflow/contrib/keras/api/keras +tensorflow/contrib/keras/api/keras/activations +tensorflow/contrib/keras/api/keras/applications +tensorflow/contrib/keras/api/keras/applications/inception_v3 +tensorflow/contrib/keras/api/keras/applications/mobilenet +tensorflow/contrib/keras/api/keras/applications/resnet50 +tensorflow/contrib/keras/api/keras/applications/vgg16 +tensorflow/contrib/keras/api/keras/applications/vgg19 +tensorflow/contrib/keras/api/keras/applications/xception +tensorflow/contrib/keras/api/keras/backend +tensorflow/contrib/keras/api/keras/callbacks +tensorflow/contrib/keras/api/keras/constraints +tensorflow/contrib/keras/api/keras/datasets +tensorflow/contrib/keras/api/keras/datasets/boston_housing +tensorflow/contrib/keras/api/keras/datasets/cifar10 +tensorflow/contrib/keras/api/keras/datasets/cifar100 +tensorflow/contrib/keras/api/keras/datasets/imdb +tensorflow/contrib/keras/api/keras/datasets/mnist +tensorflow/contrib/keras/api/keras/datasets/reuters +tensorflow/contrib/keras/api/keras/initializers +tensorflow/contrib/keras/api/keras/layers +tensorflow/contrib/keras/api/keras/losses +tensorflow/contrib/keras/api/keras/metrics +tensorflow/contrib/keras/api/keras/models +tensorflow/contrib/keras/api/keras/optimizers +tensorflow/contrib/keras/api/keras/preprocessing 
+tensorflow/contrib/keras/api/keras/preprocessing/image +tensorflow/contrib/keras/api/keras/preprocessing/sequence +tensorflow/contrib/keras/api/keras/preprocessing/text +tensorflow/contrib/keras/api/keras/regularizers +tensorflow/contrib/keras/api/keras/utils +tensorflow/contrib/keras/api/keras/wrappers +tensorflow/contrib/keras/api/keras/wrappers/scikit_learn +tensorflow/contrib/kernel_methods +tensorflow/contrib/kernel_methods/python +tensorflow/contrib/kernel_methods/python/mappers +tensorflow/contrib/kfac +tensorflow/contrib/kfac/examples +tensorflow/contrib/kfac/python +tensorflow/contrib/kfac/python/ops +tensorflow/contrib/labeled_tensor +tensorflow/contrib/labeled_tensor/python +tensorflow/contrib/labeled_tensor/python/ops +tensorflow/contrib/layers +tensorflow/contrib/layers/kernels +tensorflow/contrib/layers/ops +tensorflow/contrib/layers/python +tensorflow/contrib/layers/python/layers +tensorflow/contrib/layers/python/ops +tensorflow/contrib/learn +tensorflow/contrib/learn/python +tensorflow/contrib/learn/python/learn +tensorflow/contrib/learn/python/learn/dataframe +tensorflow/contrib/learn/python/learn/dataframe/queues +tensorflow/contrib/learn/python/learn/dataframe/transforms +tensorflow/contrib/learn/python/learn/datasets +tensorflow/contrib/learn/python/learn/datasets/data +tensorflow/contrib/learn/python/learn/estimators +tensorflow/contrib/learn/python/learn/learn_io +tensorflow/contrib/learn/python/learn/ops +tensorflow/contrib/learn/python/learn/preprocessing +tensorflow/contrib/learn/python/learn/utils +tensorflow/contrib/legacy_seq2seq +tensorflow/contrib/legacy_seq2seq/python +tensorflow/contrib/legacy_seq2seq/python/ops +tensorflow/contrib/linalg +tensorflow/contrib/linalg/python +tensorflow/contrib/linalg/python/ops +tensorflow/contrib/linear_optimizer +tensorflow/contrib/linear_optimizer/kernels +tensorflow/contrib/linear_optimizer/kernels/g3doc +tensorflow/contrib/linear_optimizer/python +tensorflow/contrib/linear_optimizer/python/ops 
+tensorflow/contrib/lookup +tensorflow/contrib/losses +tensorflow/contrib/losses/python +tensorflow/contrib/losses/python/losses +tensorflow/contrib/losses/python/metric_learning +tensorflow/contrib/makefile +tensorflow/contrib/memory_stats +tensorflow/contrib/memory_stats/kernels +tensorflow/contrib/memory_stats/ops +tensorflow/contrib/memory_stats/python +tensorflow/contrib/memory_stats/python/ops +tensorflow/contrib/meta_graph_transform +tensorflow/contrib/metrics +tensorflow/contrib/metrics/ops +tensorflow/contrib/metrics/python +tensorflow/contrib/metrics/python/metrics +tensorflow/contrib/metrics/python/ops +tensorflow/contrib/model_pruning +tensorflow/contrib/model_pruning/examples +tensorflow/contrib/model_pruning/examples/cifar10 +tensorflow/contrib/model_pruning/python +tensorflow/contrib/model_pruning/python/layers +tensorflow/contrib/nccl +tensorflow/contrib/nccl/kernels +tensorflow/contrib/nccl/ops +tensorflow/contrib/nccl/python +tensorflow/contrib/nccl/python/ops +tensorflow/contrib/ndlstm +tensorflow/contrib/ndlstm/python +tensorflow/contrib/nearest_neighbor/kernels +tensorflow/contrib/nearest_neighbor/ops +tensorflow/contrib/nearest_neighbor/python +tensorflow/contrib/nearest_neighbor/python/ops +tensorflow/contrib/nn +tensorflow/contrib/nn/python +tensorflow/contrib/nn/python/ops +tensorflow/contrib/opt +tensorflow/contrib/opt/python +tensorflow/contrib/opt/python/training +tensorflow/contrib/pi_examples +tensorflow/contrib/pi_examples/camera +tensorflow/contrib/pi_examples/label_image +tensorflow/contrib/pi_examples/label_image/data +tensorflow/contrib/periodic_resample +tensorflow/contrib/periodic_resample/python +tensorflow/contrib/periodic_resample/python/kernels +tensorflow/contrib/periodic_resample/python/ops +tensorflow/contrib/predictor +tensorflow/contrib/quantization +tensorflow/contrib/quantization/python +tensorflow/contrib/quantize +tensorflow/contrib/quantize/python +tensorflow/contrib/receptive_field 
+tensorflow/contrib/receptive_field/python +tensorflow/contrib/reduce_slice_ops +tensorflow/contrib/reduce_slice_ops/kernels +tensorflow/contrib/reduce_slice_ops/ops +tensorflow/contrib/reduce_slice_ops/python +tensorflow/contrib/reduce_slice_ops/python/ops +tensorflow/contrib/remote_fused_graph/pylib +tensorflow/contrib/remote_fused_graph/pylib/python +tensorflow/contrib/remote_fused_graph/pylib/python/ops +tensorflow/contrib/resampler +tensorflow/contrib/resampler/kernels +tensorflow/contrib/resampler/ops +tensorflow/contrib/resampler/python +tensorflow/contrib/resampler/python/ops +tensorflow/contrib/rnn +tensorflow/contrib/rnn/kernels +tensorflow/contrib/rnn/ops +tensorflow/contrib/rnn/python +tensorflow/contrib/rnn/python/kernel_tests +tensorflow/contrib/rnn/python/ops +tensorflow/contrib/saved_model +tensorflow/contrib/saved_model/python +tensorflow/contrib/saved_model/python/saved_model +tensorflow/contrib/seq2seq +tensorflow/contrib/seq2seq/kernels +tensorflow/contrib/seq2seq/ops +tensorflow/contrib/seq2seq/python +tensorflow/contrib/seq2seq/python/ops +tensorflow/contrib/session_bundle +tensorflow/contrib/session_bundle/example +tensorflow/contrib/signal +tensorflow/contrib/signal/python +tensorflow/contrib/signal/python/ops +tensorflow/contrib/slim +tensorflow/contrib/slim/python +tensorflow/contrib/slim/python/slim +tensorflow/contrib/slim/python/slim/data +tensorflow/contrib/slim/python/slim/nets +tensorflow/contrib/solvers +tensorflow/contrib/solvers/python +tensorflow/contrib/solvers/python/ops +tensorflow/contrib/sparsemax +tensorflow/contrib/sparsemax/python +tensorflow/contrib/sparsemax/python/ops +tensorflow/contrib/specs +tensorflow/contrib/specs/python +tensorflow/contrib/staging +tensorflow/contrib/stat_summarizer +tensorflow/contrib/stat_summarizer/python +tensorflow/contrib/stateless +tensorflow/contrib/stateless/python +tensorflow/contrib/summary +tensorflow/contrib/tensorboard +tensorflow/contrib/tensorboard/plugins 
+tensorflow/contrib/tensorboard/plugins/projector +tensorflow/contrib/tensor_forest +tensorflow/contrib/tensor_forest/client +tensorflow/contrib/tensor_forest/core +tensorflow/contrib/tensor_forest/core/ops +tensorflow/contrib/tensor_forest/data +tensorflow/contrib/tensor_forest/hybrid +tensorflow/contrib/tensor_forest/hybrid/core +tensorflow/contrib/tensor_forest/hybrid/core/ops +tensorflow/contrib/tensor_forest/hybrid/ops +tensorflow/contrib/tensor_forest/hybrid/python +tensorflow/contrib/tensor_forest/hybrid/python/layers +tensorflow/contrib/tensor_forest/hybrid/python/models +tensorflow/contrib/tensor_forest/hybrid/python/ops +tensorflow/contrib/tensor_forest/kernels +tensorflow/contrib/tensor_forest/python +tensorflow/contrib/tensor_forest/python/ops +tensorflow/contrib/testing +tensorflow/contrib/testing/python +tensorflow/contrib/testing/python/framework +tensorflow/contrib/text +tensorflow/contrib/text/kernels +tensorflow/contrib/text/ops +tensorflow/contrib/text/python +tensorflow/contrib/text/python/ops +tensorflow/contrib/tfprof +tensorflow/contrib/timeseries +tensorflow/contrib/timeseries/examples +tensorflow/contrib/timeseries/examples/data +tensorflow/contrib/timeseries/python +tensorflow/contrib/timeseries/python/timeseries +tensorflow/contrib/timeseries/python/timeseries/state_space_models +tensorflow/contrib/tpu +tensorflow/contrib/tpu/ops +tensorflow/contrib/tpu/profiler +tensorflow/contrib/tpu/python +tensorflow/contrib/tpu/python/ops +tensorflow/contrib/tpu/python/profiler +tensorflow/contrib/tpu/python/tpu +tensorflow/contrib/training +tensorflow/contrib/training/python +tensorflow/contrib/training/python/training +tensorflow/contrib/util diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt new file mode 100644 index 0000000000000000000000000000000000000000..8a9c406d8b118c10ddcaafb0e4fc242aa79cdb57 --- /dev/null +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -0,0 +1,19 @@ +tensorflow/core 
+tensorflow/core/profiler +tensorflow/python +tensorflow/contrib/boosted_trees/proto +tensorflow/contrib/cloud/kernels +tensorflow/contrib/decision_trees/proto +tensorflow/contrib/gdr +tensorflow/contrib/lite/toco +tensorflow/contrib/mpi +tensorflow/contrib/mpi_collectives +tensorflow/contrib/session_bundle +tensorflow/contrib/tensor_forest/proto +tensorflow/contrib/tensorboard/graph_explorer/proto +tensorflow/contrib/tensorboard/plugins/projector +tensorflow/contrib/tensorboard/plugins/trace +tensorflow/contrib/tpu/proto +tensorflow/contrib/tpu/profiler +tensorflow/contrib/training/python/training +tensorflow/contrib/verbs diff --git a/tensorflow/contrib/cmake/python_protos_cc.txt b/tensorflow/contrib/cmake/python_protos_cc.txt new file mode 100644 index 0000000000000000000000000000000000000000..d4a257b25c814a1464308d0e6ce3ce65d21f6a36 --- /dev/null +++ b/tensorflow/contrib/cmake/python_protos_cc.txt @@ -0,0 +1,5 @@ +tensorflow/core/profiler +tensorflow/python +tensorflow/contrib/session_bundle +tensorflow/contrib/tensorboard +tensorflow/contrib/training diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index f3882e8cf76c6dad31371fc340de959c05411a2f..c6a15f2ca075c8de96786a580c7ddb89541df5bc 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -21,7 +21,6 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" - "${tensorflow_source_dir}/tensorflow/c/eager/tape.cc" "${tensorflow_source_dir}/tensorflow/c/eager/tape.h" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h" @@ -47,4 +46,5 @@ add_dependencies( tf_c_python_api tf_c tf_core_lib + tf_core_framework tf_protos_cc) diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index 
f63aca4a835e213ef6d420845df9bb537514e142..6e2ac203f9a7f96cb14752a91483840a9eb6b451 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -83,7 +83,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names}) ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.cc - COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal} + COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal} ${tensorflow_source_dir}/tensorflow/core/api_def/base_api DEPENDS ${tf_cc_op_lib_name}_gen_cc create_cc_ops_header_dir ) diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index 5c01ca382fb9cc7a01a6f2b60a510c59f0aa7119..e4213ea2a47da2a7381cccd0504235ad62018d4e 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -63,7 +63,7 @@ if (tensorflow_ENABLE_GPU) file(GLOB_RECURSE tf_core_gpu_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/*.cc" "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu/cupti_wrapper.cc" - "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc" + "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu_device_factory.cc" "${tensorflow_source_dir}/tensorflow/core/grappler/devices.h" "${tensorflow_source_dir}/tensorflow/core/grappler/devices.cc" diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 
c3dc8531bb9f0164f06841d9715f227202fdb7c9..5ec1a8d04fa41c6b36400fc0998af77592866150 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -211,7 +211,7 @@ if (NOT tensorflow_ENABLE_GPU) list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs}) else() file(GLOB tf_core_platform_srcs_exclude - "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc") + "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc") list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_srcs_exclude}) endif() @@ -301,6 +301,8 @@ file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_factory.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_options.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.h" "${tensorflow_source_dir}/public/*.h" ) @@ -314,6 +316,7 @@ file(GLOB_RECURSE tf_core_framework_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/util/*test*.h" "${tensorflow_source_dir}/tensorflow/core/util/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/util/*main.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*test*.cc" ) list(REMOVE_ITEM tf_core_framework_srcs ${tf_core_framework_exclude_srcs}) diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index a2ab4b9ae4fc1e491e180840407c0a5238e5623a..eb6bf567aa7dc2e87f3d5ce462a7680fc9850bbf 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -55,10 +55,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc" 
"${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.cc" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.cc" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.cc" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.cc" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/model_ops.cc" @@ -70,7 +66,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" @@ -155,9 +150,6 @@ list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_exclude_srcs}) if(WIN32) file(GLOB_RECURSE tf_core_kernels_windows_exclude_srcs # not working on windows yet - "${tensorflow_source_dir}/tensorflow/core/kernels/meta_support.*" - "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.h" - "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.cc" 
"${tensorflow_source_dir}/tensorflow/core/kernels/neon/*" # not in core - those are loaded dynamically as dll "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 03c168795cc2455327f0b7bbf40fd1fd1eebb34e..e8c2cd347327843d10d13c1d24a800ff776aa8c1 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -81,7 +81,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(data_prefetching "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") @@ -93,6 +92,7 @@ GENERATE_CONTRIB_OP_LIBRARY(image_sirds "${tensorflow_source_dir}/tensorflow/con GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") 
+GENERATE_CONTRIB_OP_LIBRARY(periodic_resample "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/ops/array_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nearest_neighbor "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/ops/nearest_neighbor_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(resampler "${tensorflow_source_dir}/tensorflow/contrib/resampler/ops/resampler_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_grappler.cmake b/tensorflow/contrib/cmake/tf_grappler.cmake index a7841c98e83ec8c3eb91edfd9d639e169cb5f440..410490531a300c091afdd857d7f2d4e789a4c80e 100644 --- a/tensorflow/contrib/cmake/tf_grappler.cmake +++ b/tensorflow/contrib/cmake/tf_grappler.cmake @@ -23,7 +23,7 @@ file(GLOB tf_grappler_srcs "${tensorflow_source_dir}/tensorflow/python/grappler/model_analyzer.cc" "${tensorflow_source_dir}/tensorflow/python/grappler/model_analyzer.h" ) - + add_library(tf_grappler OBJECT ${tf_grappler_srcs}) add_dependencies(tf_grappler tf_core_cpu) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 43b98659e347bc53a76eb2a6138f6636aad974d8..8db6929e31a1a5f5c793721f455a664bd6741b06 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -120,32 +120,34 @@ function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR) set(${HDRS} ${${HDRS}} PARENT_SCOPE) endfunction() -file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/*.proto" - "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/python/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/decision_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - 
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tpu/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" -) +FILE(READ python_protos.txt python_protos) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_protos "${python_protos}") +STRING(REGEX REPLACE "\n" ";" python_protos "${python_protos}") + +foreach(python_proto ${python_protos}) + file(GLOB_RECURSE tf_python_protos_src RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/${python_proto}/*.proto" + ) + list(APPEND tf_python_protos_srcs ${tf_python_protos_src}) +endforeach(python_proto) + RELATIVE_PROTOBUF_GENERATE_PYTHON( - ${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_protos_python_srcs} + ${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_python_protos_srcs} ) -# NOTE(mrry): Avoid regenerating the tensorflow/core protos because this -# can cause benign-but-failing-on-Windows-due-to-file-locking conflicts -# when two rules attempt to generate the same file. 
-file(GLOB_RECURSE tf_python_protos_cc_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/python/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" -) +FILE(READ python_protos_cc.txt python_protos_cc) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_protos_cc "${python_protos_cc}") +STRING(REGEX REPLACE "\n" ";" python_protos_cc "${python_protos_cc}") + +foreach(python_proto_cc ${python_protos_cc}) + file(GLOB_RECURSE tf_python_protos_cc_src RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/${python_proto_cc}/*.proto" + ) + list(APPEND tf_python_protos_cc_srcs ${tf_python_protos_cc_src}) +endforeach(python_proto_cc) + RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS ${tensorflow_source_dir} ${tf_python_protos_cc_srcs} ) @@ -191,465 +193,28 @@ function(add_python_module MODULE_NAME) endif() endfunction() -add_python_module("tensorflow") -add_python_module("tensorflow/core") -add_python_module("tensorflow/core/example") -add_python_module("tensorflow/core/framework") -add_python_module("tensorflow/core/lib") -add_python_module("tensorflow/core/lib/core") -add_python_module("tensorflow/core/protobuf") -add_python_module("tensorflow/core/util") -add_python_module("tensorflow/examples") -add_python_module("tensorflow/examples/tutorials") -add_python_module("tensorflow/examples/tutorials/mnist") -add_python_module("tensorflow/python") -add_python_module("tensorflow/python/client") -add_python_module("tensorflow/python/data") -add_python_module("tensorflow/python/data/ops") -add_python_module("tensorflow/python/data/util") -add_python_module("tensorflow/python/debug") -add_python_module("tensorflow/python/debug/cli") 
-add_python_module("tensorflow/python/debug/examples") -add_python_module("tensorflow/python/debug/lib") -add_python_module("tensorflow/python/debug/wrappers") -add_python_module("tensorflow/python/eager") -add_python_module("tensorflow/python/estimator") -add_python_module("tensorflow/python/estimator/canned") -add_python_module("tensorflow/python/estimator/export") -add_python_module("tensorflow/python/estimator/inputs") -add_python_module("tensorflow/python/estimator/inputs/queues") -add_python_module("tensorflow/python/feature_column") -add_python_module("tensorflow/python/framework") -add_python_module("tensorflow/python/grappler") -add_python_module("tensorflow/python/keras") -add_python_module("tensorflow/python/keras/activations") -add_python_module("tensorflow/python/keras/applications") -add_python_module("tensorflow/python/keras/applications/inception_resnet_v2") -add_python_module("tensorflow/python/keras/applications/inception_v3") -add_python_module("tensorflow/python/keras/applications/mobilenet") -add_python_module("tensorflow/python/keras/applications/resnet50") -add_python_module("tensorflow/python/keras/applications/vgg16") -add_python_module("tensorflow/python/keras/applications/vgg19") -add_python_module("tensorflow/python/keras/applications/xception") -add_python_module("tensorflow/python/keras/backend") -add_python_module("tensorflow/python/keras/callbacks") -add_python_module("tensorflow/python/keras/constraints") -add_python_module("tensorflow/python/keras/datasets") -add_python_module("tensorflow/python/keras/datasets/boston_housing") -add_python_module("tensorflow/python/keras/datasets/cifar10") -add_python_module("tensorflow/python/keras/datasets/cifar100") -add_python_module("tensorflow/python/keras/datasets/imdb") -add_python_module("tensorflow/python/keras/datasets/mnist") -add_python_module("tensorflow/python/keras/datasets/reuters") -add_python_module("tensorflow/python/keras/estimator") 
-add_python_module("tensorflow/python/keras/initializers") -add_python_module("tensorflow/python/keras/layers") -add_python_module("tensorflow/python/keras/losses") -add_python_module("tensorflow/python/keras/metrics") -add_python_module("tensorflow/python/keras/models") -add_python_module("tensorflow/python/keras/optimizers") -add_python_module("tensorflow/python/keras/preprocessing") -add_python_module("tensorflow/python/keras/preprocessing/image") -add_python_module("tensorflow/python/keras/preprocessing/sequence") -add_python_module("tensorflow/python/keras/preprocessing/text") -add_python_module("tensorflow/python/keras/regularizers") -add_python_module("tensorflow/python/keras/utils") -add_python_module("tensorflow/python/keras/wrappers") -add_python_module("tensorflow/python/keras/wrappers/scikit_learn") -add_python_module("tensorflow/python/keras/_impl") -add_python_module("tensorflow/python/keras/_impl/keras") -add_python_module("tensorflow/python/keras/_impl/keras/applications") -add_python_module("tensorflow/python/keras/_impl/keras/datasets") -add_python_module("tensorflow/python/keras/_impl/keras/engine") -add_python_module("tensorflow/python/keras/_impl/keras/layers") -add_python_module("tensorflow/python/keras/_impl/keras/preprocessing") -add_python_module("tensorflow/python/keras/_impl/keras/utils") -add_python_module("tensorflow/python/keras/_impl/keras/wrappers") -add_python_module("tensorflow/python/kernel_tests") -add_python_module("tensorflow/python/kernel_tests/distributions") -add_python_module("tensorflow/python/kernel_tests/linalg") -add_python_module("tensorflow/python/layers") -add_python_module("tensorflow/python/lib") -add_python_module("tensorflow/python/lib/core") -add_python_module("tensorflow/python/lib/io") -add_python_module("tensorflow/python/ops") -add_python_module("tensorflow/python/ops/distributions") -add_python_module("tensorflow/python/ops/linalg") -add_python_module("tensorflow/python/ops/losses") 
-add_python_module("tensorflow/python/platform") -add_python_module("tensorflow/python/platform/default") -add_python_module("tensorflow/python/platform/summary") -add_python_module("tensorflow/python/profiler/") -add_python_module("tensorflow/python/profiler/internal") -add_python_module("tensorflow/python/saved_model") -add_python_module("tensorflow/python/summary") -add_python_module("tensorflow/python/summary/writer") -add_python_module("tensorflow/python/tools") -add_python_module("tensorflow/python/training") -add_python_module("tensorflow/python/user_ops") -add_python_module("tensorflow/python/util") -add_python_module("tensorflow/python/util/protobuf") -add_python_module("tensorflow/tools") -add_python_module("tensorflow/tools/graph_transforms") -add_python_module("tensorflow/contrib") -add_python_module("tensorflow/contrib/all_reduce") -add_python_module("tensorflow/contrib/all_reduce/python") -add_python_module("tensorflow/contrib/android") -add_python_module("tensorflow/contrib/android/java") -add_python_module("tensorflow/contrib/android/java/org") -add_python_module("tensorflow/contrib/android/java/org/tensorflow") -add_python_module("tensorflow/contrib/android/java/org/tensorflow/contrib") -add_python_module("tensorflow/contrib/android/java/org/tensorflow/contrib/android") -add_python_module("tensorflow/contrib/android/jni") -add_python_module("tensorflow/contrib/bayesflow") -add_python_module("tensorflow/contrib/bayesflow/examples") -add_python_module("tensorflow/contrib/bayesflow/examples/reinforce_simple") -add_python_module("tensorflow/contrib/bayesflow/python") -add_python_module("tensorflow/contrib/bayesflow/python/kernel_tests") -add_python_module("tensorflow/contrib/bayesflow/python/ops") -add_python_module("tensorflow/contrib/boosted_trees") -add_python_module("tensorflow/contrib/boosted_trees/estimator_batch") -add_python_module("tensorflow/contrib/boosted_trees/ops") -add_python_module("tensorflow/contrib/boosted_trees/proto") 
-add_python_module("tensorflow/contrib/boosted_trees/python") -add_python_module("tensorflow/contrib/boosted_trees/python/kernel_tests") -add_python_module("tensorflow/contrib/boosted_trees/python/ops") -add_python_module("tensorflow/contrib/cloud") -add_python_module("tensorflow/contrib/cloud/kernels") -add_python_module("tensorflow/contrib/cloud/ops") -add_python_module("tensorflow/contrib/cloud/python") -add_python_module("tensorflow/contrib/cloud/python/ops") -add_python_module("tensorflow/contrib/cluster_resolver") -add_python_module("tensorflow/contrib/cluster_resolver/python") -add_python_module("tensorflow/contrib/cluster_resolver/python/training") -add_python_module("tensorflow/contrib/compiler") -add_python_module("tensorflow/contrib/copy_graph") -add_python_module("tensorflow/contrib/copy_graph/python") -add_python_module("tensorflow/contrib/copy_graph/python/util") -add_python_module("tensorflow/contrib/crf") -add_python_module("tensorflow/contrib/crf/python") -add_python_module("tensorflow/contrib/crf/python/kernel_tests") -add_python_module("tensorflow/contrib/crf/python/ops") -add_python_module("tensorflow/contrib/cudnn_rnn") -add_python_module("tensorflow/contrib/cudnn_rnn/kernels") -add_python_module("tensorflow/contrib/cudnn_rnn/ops") -add_python_module("tensorflow/contrib/cudnn_rnn/python") -add_python_module("tensorflow/contrib/cudnn_rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/cudnn_rnn/python/layers") -add_python_module("tensorflow/contrib/cudnn_rnn/python/ops") -add_python_module("tensorflow/contrib/data") -add_python_module("tensorflow/contrib/data/python") -add_python_module("tensorflow/contrib/data/python/kernel_tests") -add_python_module("tensorflow/contrib/data/python/ops") -add_python_module("tensorflow/contrib/decision_trees") -add_python_module("tensorflow/contrib/decision_trees/proto") -add_python_module("tensorflow/contrib/deprecated") -add_python_module("tensorflow/contrib/distributions") 
-add_python_module("tensorflow/contrib/distributions/python") -add_python_module("tensorflow/contrib/distributions/python/kernel_tests") -add_python_module("tensorflow/contrib/distributions/python/ops") -add_python_module("tensorflow/contrib/distributions/python/ops/bijectors") -add_python_module("tensorflow/contrib/eager") -add_python_module("tensorflow/contrib/eager/python") -add_python_module("tensorflow/contrib/estimator") -add_python_module("tensorflow/contrib/estimator/python") -add_python_module("tensorflow/contrib/estimator/python/estimator") -add_python_module("tensorflow/contrib/factorization") -add_python_module("tensorflow/contrib/factorization/examples") -add_python_module("tensorflow/contrib/factorization/kernels") -add_python_module("tensorflow/contrib/factorization/ops") -add_python_module("tensorflow/contrib/factorization/python") -add_python_module("tensorflow/contrib/factorization/python/kernel_tests") -add_python_module("tensorflow/contrib/factorization/python/ops") -add_python_module("tensorflow/contrib/ffmpeg") -add_python_module("tensorflow/contrib/ffmpeg/default") -add_python_module("tensorflow/contrib/ffmpeg/testdata") -add_python_module("tensorflow/contrib/framework") -add_python_module("tensorflow/contrib/framework/kernels") -add_python_module("tensorflow/contrib/framework/ops") -add_python_module("tensorflow/contrib/framework/python") -add_python_module("tensorflow/contrib/framework/python/framework") -add_python_module("tensorflow/contrib/framework/python/ops") -add_python_module("tensorflow/contrib/gan") -add_python_module("tensorflow/contrib/gan/python") -add_python_module("tensorflow/contrib/gan/python/eval") -add_python_module("tensorflow/contrib/gan/python/eval/python") -add_python_module("tensorflow/contrib/gan/python/features") -add_python_module("tensorflow/contrib/gan/python/features/python") -add_python_module("tensorflow/contrib/gan/python/estimator") -add_python_module("tensorflow/contrib/gan/python/estimator/python") 
-add_python_module("tensorflow/contrib/gan/python/losses") -add_python_module("tensorflow/contrib/gan/python/losses/python") -add_python_module("tensorflow/contrib/graph_editor") -add_python_module("tensorflow/contrib/graph_editor/examples") -add_python_module("tensorflow/contrib/graph_editor/tests") -add_python_module("tensorflow/contrib/grid_rnn") -add_python_module("tensorflow/contrib/grid_rnn/python") -add_python_module("tensorflow/contrib/grid_rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/grid_rnn/python/ops") -add_python_module("tensorflow/contrib/hooks") -add_python_module("tensorflow/contrib/image") -add_python_module("tensorflow/contrib/image/ops") -add_python_module("tensorflow/contrib/image/python") -add_python_module("tensorflow/contrib/image/python/ops") -add_python_module("tensorflow/contrib/input_pipeline") -add_python_module("tensorflow/contrib/input_pipeline/ops") -add_python_module("tensorflow/contrib/input_pipeline/python") -add_python_module("tensorflow/contrib/input_pipeline/python/ops") -add_python_module("tensorflow/contrib/integrate") -add_python_module("tensorflow/contrib/integrate/python") -add_python_module("tensorflow/contrib/integrate/python/ops") -add_python_module("tensorflow/contrib/ios_examples") -add_python_module("tensorflow/contrib/ios_examples/benchmark") -add_python_module("tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj") -add_python_module("tensorflow/contrib/ios_examples/benchmark/data") -add_python_module("tensorflow/contrib/ios_examples/camera") -add_python_module("tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj") -add_python_module("tensorflow/contrib/ios_examples/camera/en.lproj") -add_python_module("tensorflow/contrib/ios_examples/simple") -add_python_module("tensorflow/contrib/ios_examples/simple/data") -add_python_module("tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj") -add_python_module("tensorflow/contrib/keras") 
-add_python_module("tensorflow/contrib/keras/api") -add_python_module("tensorflow/contrib/keras/api/keras") -add_python_module("tensorflow/contrib/keras/api/keras/activations") -add_python_module("tensorflow/contrib/keras/api/keras/applications") -add_python_module("tensorflow/contrib/keras/api/keras/applications/inception_v3") -add_python_module("tensorflow/contrib/keras/api/keras/applications/mobilenet") -add_python_module("tensorflow/contrib/keras/api/keras/applications/resnet50") -add_python_module("tensorflow/contrib/keras/api/keras/applications/vgg16") -add_python_module("tensorflow/contrib/keras/api/keras/applications/vgg19") -add_python_module("tensorflow/contrib/keras/api/keras/applications/xception") -add_python_module("tensorflow/contrib/keras/api/keras/backend") -add_python_module("tensorflow/contrib/keras/api/keras/callbacks") -add_python_module("tensorflow/contrib/keras/api/keras/constraints") -add_python_module("tensorflow/contrib/keras/api/keras/datasets") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/boston_housing") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/cifar10") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/cifar100") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/imdb") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/mnist") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/reuters") -add_python_module("tensorflow/contrib/keras/api/keras/initializers") -add_python_module("tensorflow/contrib/keras/api/keras/layers") -add_python_module("tensorflow/contrib/keras/api/keras/losses") -add_python_module("tensorflow/contrib/keras/api/keras/metrics") -add_python_module("tensorflow/contrib/keras/api/keras/models") -add_python_module("tensorflow/contrib/keras/api/keras/optimizers") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/image") 
-add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/sequence") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/text") -add_python_module("tensorflow/contrib/keras/api/keras/regularizers") -add_python_module("tensorflow/contrib/keras/api/keras/utils") -add_python_module("tensorflow/contrib/keras/api/keras/wrappers") -add_python_module("tensorflow/contrib/keras/api/keras/wrappers/scikit_learn") -add_python_module("tensorflow/contrib/keras/python") -add_python_module("tensorflow/contrib/keras/python/keras") -add_python_module("tensorflow/contrib/keras/python/keras/applications") -add_python_module("tensorflow/contrib/keras/python/keras/datasets") -add_python_module("tensorflow/contrib/keras/python/keras/engine") -add_python_module("tensorflow/contrib/keras/python/keras/layers") -add_python_module("tensorflow/contrib/keras/python/keras/preprocessing") -add_python_module("tensorflow/contrib/keras/python/keras/utils") -add_python_module("tensorflow/contrib/keras/python/keras/wrappers") -add_python_module("tensorflow/contrib/kernel_methods") -add_python_module("tensorflow/contrib/kernel_methods/python") -add_python_module("tensorflow/contrib/kernel_methods/python/mappers") -add_python_module("tensorflow/contrib/kfac") -add_python_module("tensorflow/contrib/kfac/examples") -add_python_module("tensorflow/contrib/kfac/python") -add_python_module("tensorflow/contrib/kfac/python/ops") -add_python_module("tensorflow/contrib/labeled_tensor") -add_python_module("tensorflow/contrib/labeled_tensor/python") -add_python_module("tensorflow/contrib/labeled_tensor/python/ops") -add_python_module("tensorflow/contrib/layers") -add_python_module("tensorflow/contrib/layers/kernels") -add_python_module("tensorflow/contrib/layers/ops") -add_python_module("tensorflow/contrib/layers/python") -add_python_module("tensorflow/contrib/layers/python/kernel_tests") -add_python_module("tensorflow/contrib/layers/python/layers") 
-add_python_module("tensorflow/contrib/layers/python/ops") -add_python_module("tensorflow/contrib/learn") -add_python_module("tensorflow/contrib/learn/python") -add_python_module("tensorflow/contrib/learn/python/learn") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe/queues") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe/transforms") -add_python_module("tensorflow/contrib/learn/python/learn/datasets") -add_python_module("tensorflow/contrib/learn/python/learn/datasets/data") -add_python_module("tensorflow/contrib/learn/python/learn/estimators") -add_python_module("tensorflow/contrib/learn/python/learn/learn_io") -add_python_module("tensorflow/contrib/learn/python/learn/ops") -add_python_module("tensorflow/contrib/learn/python/learn/preprocessing") -add_python_module("tensorflow/contrib/learn/python/learn/preprocessing/tests") -add_python_module("tensorflow/contrib/learn/python/learn/tests") -add_python_module("tensorflow/contrib/learn/python/learn/tests/dataframe") -add_python_module("tensorflow/contrib/learn/python/learn/utils") -add_python_module("tensorflow/contrib/legacy_seq2seq") -add_python_module("tensorflow/contrib/legacy_seq2seq/python") -add_python_module("tensorflow/contrib/legacy_seq2seq/python/ops") -add_python_module("tensorflow/contrib/linalg") -add_python_module("tensorflow/contrib/linalg/python") -add_python_module("tensorflow/contrib/linalg/python/ops") -add_python_module("tensorflow/contrib/linalg/python/kernel_tests") -add_python_module("tensorflow/contrib/linear_optimizer") -add_python_module("tensorflow/contrib/linear_optimizer/kernels") -add_python_module("tensorflow/contrib/linear_optimizer/kernels/g3doc") -add_python_module("tensorflow/contrib/linear_optimizer/python") -add_python_module("tensorflow/contrib/linear_optimizer/python/kernel_tests") -add_python_module("tensorflow/contrib/linear_optimizer/python/ops") 
-add_python_module("tensorflow/contrib/lookup") -add_python_module("tensorflow/contrib/losses") -add_python_module("tensorflow/contrib/losses/python") -add_python_module("tensorflow/contrib/losses/python/losses") -add_python_module("tensorflow/contrib/losses/python/metric_learning") -add_python_module("tensorflow/contrib/makefile") -add_python_module("tensorflow/contrib/makefile/test") -add_python_module("tensorflow/contrib/memory_stats") -add_python_module("tensorflow/contrib/memory_stats/kernels") -add_python_module("tensorflow/contrib/memory_stats/ops") -add_python_module("tensorflow/contrib/memory_stats/python") -add_python_module("tensorflow/contrib/memory_stats/python/kernel_tests") -add_python_module("tensorflow/contrib/memory_stats/python/ops") -add_python_module("tensorflow/contrib/meta_graph_transform") -add_python_module("tensorflow/contrib/metrics") -add_python_module("tensorflow/contrib/metrics/kernels") -add_python_module("tensorflow/contrib/metrics/ops") -add_python_module("tensorflow/contrib/metrics/python") -add_python_module("tensorflow/contrib/metrics/python/kernel_tests") -add_python_module("tensorflow/contrib/metrics/python/metrics") -add_python_module("tensorflow/contrib/metrics/python/ops") -add_python_module("tensorflow/contrib/model_pruning") -add_python_module("tensorflow/contrib/model_pruning/examples") -add_python_module("tensorflow/contrib/model_pruning/examples/cifar10") -add_python_module("tensorflow/contrib/model_pruning/python") -add_python_module("tensorflow/contrib/model_pruning/python/layers") -add_python_module("tensorflow/contrib/ndlstm") -add_python_module("tensorflow/contrib/ndlstm/python") -add_python_module("tensorflow/contrib/nn") -add_python_module("tensorflow/contrib/nn/python") -add_python_module("tensorflow/contrib/nn/python/ops") -add_python_module("tensorflow/contrib/nccl") -add_python_module("tensorflow/contrib/nccl/kernels") -add_python_module("tensorflow/contrib/nccl/ops") 
-add_python_module("tensorflow/contrib/nccl/python") -add_python_module("tensorflow/contrib/nccl/python/ops") -add_python_module("tensorflow/contrib/nearest_neighbor/kernels") -add_python_module("tensorflow/contrib/nearest_neighbor/ops") -add_python_module("tensorflow/contrib/nearest_neighbor/python") -add_python_module("tensorflow/contrib/nearest_neighbor/python/kernel_tests") -add_python_module("tensorflow/contrib/nearest_neighbor/python/ops") -add_python_module("tensorflow/contrib/opt") -add_python_module("tensorflow/contrib/opt/python") -add_python_module("tensorflow/contrib/opt/python/training") -add_python_module("tensorflow/contrib/pi_examples") -add_python_module("tensorflow/contrib/pi_examples/camera") -add_python_module("tensorflow/contrib/pi_examples/label_image") -add_python_module("tensorflow/contrib/pi_examples/label_image/data") -add_python_module("tensorflow/contrib/predictor") -add_python_module("tensorflow/contrib/quantization") -add_python_module("tensorflow/contrib/quantization/python") -add_python_module("tensorflow/contrib/quantize") -add_python_module("tensorflow/contrib/quantize/python") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python/ops") -add_python_module("tensorflow/contrib/resampler") -add_python_module("tensorflow/contrib/resampler/kernels") -add_python_module("tensorflow/contrib/resampler/ops") -add_python_module("tensorflow/contrib/resampler/python") -add_python_module("tensorflow/contrib/resampler/python/ops") -add_python_module("tensorflow/contrib/rnn") -add_python_module("tensorflow/contrib/rnn/kernels") -add_python_module("tensorflow/contrib/rnn/ops") -add_python_module("tensorflow/contrib/rnn/python") -add_python_module("tensorflow/contrib/rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/rnn/python/ops") 
-add_python_module("tensorflow/contrib/saved_model") -add_python_module("tensorflow/contrib/saved_model/python") -add_python_module("tensorflow/contrib/saved_model/python/saved_model") -add_python_module("tensorflow/contrib/seq2seq") -add_python_module("tensorflow/contrib/seq2seq/kernels") -add_python_module("tensorflow/contrib/seq2seq/ops") -add_python_module("tensorflow/contrib/seq2seq/python") -add_python_module("tensorflow/contrib/seq2seq/python/kernel_tests") -add_python_module("tensorflow/contrib/seq2seq/python/ops") -add_python_module("tensorflow/contrib/session_bundle") -add_python_module("tensorflow/contrib/session_bundle/example") -add_python_module("tensorflow/contrib/session_bundle/testdata") -add_python_module("tensorflow/contrib/signal") -add_python_module("tensorflow/contrib/signal/python") -add_python_module("tensorflow/contrib/signal/python/ops") -add_python_module("tensorflow/contrib/slim") -add_python_module("tensorflow/contrib/slim/python") -add_python_module("tensorflow/contrib/slim/python/slim") -add_python_module("tensorflow/contrib/slim/python/slim/data") -add_python_module("tensorflow/contrib/slim/python/slim/nets") -add_python_module("tensorflow/contrib/solvers") -add_python_module("tensorflow/contrib/solvers/python") -add_python_module("tensorflow/contrib/solvers/python/ops") -add_python_module("tensorflow/contrib/sparsemax") -add_python_module("tensorflow/contrib/sparsemax/python") -add_python_module("tensorflow/contrib/sparsemax/python/ops") -add_python_module("tensorflow/contrib/specs") -add_python_module("tensorflow/contrib/specs/python") -add_python_module("tensorflow/contrib/staging") -add_python_module("tensorflow/contrib/stat_summarizer") -add_python_module("tensorflow/contrib/stateless") -add_python_module("tensorflow/contrib/tensorboard") -add_python_module("tensorflow/contrib/tensorboard/plugins") -add_python_module("tensorflow/contrib/tensorboard/plugins/projector") -add_python_module("tensorflow/contrib/tensor_forest") 
-add_python_module("tensorflow/contrib/tensor_forest/client") -add_python_module("tensorflow/contrib/tensor_forest/core") -add_python_module("tensorflow/contrib/tensor_forest/core/ops") -add_python_module("tensorflow/contrib/tensor_forest/data") -add_python_module("tensorflow/contrib/tensor_forest/hybrid") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/core") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/core/ops") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/ops") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/layers") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/models") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/ops") -add_python_module("tensorflow/contrib/tensor_forest/python") -add_python_module("tensorflow/contrib/tensor_forest/python/kernel_tests") -add_python_module("tensorflow/contrib/tensor_forest/python/ops") -add_python_module("tensorflow/contrib/testing") -add_python_module("tensorflow/contrib/testing/python") -add_python_module("tensorflow/contrib/testing/python/framework") -add_python_module("tensorflow/contrib/text") -add_python_module("tensorflow/contrib/text/kernels") -add_python_module("tensorflow/contrib/text/ops") -add_python_module("tensorflow/contrib/text/python") -add_python_module("tensorflow/contrib/text/python/ops") -add_python_module("tensorflow/contrib/tfprof") -add_python_module("tensorflow/contrib/timeseries") -add_python_module("tensorflow/contrib/timeseries/examples") -add_python_module("tensorflow/contrib/timeseries/examples/data") -add_python_module("tensorflow/contrib/timeseries/python") -add_python_module("tensorflow/contrib/timeseries/python/timeseries") -add_python_module("tensorflow/contrib/timeseries/python/timeseries/state_space_models") 
-add_python_module("tensorflow/contrib/tpu") -add_python_module("tensorflow/contrib/tpu/ops") -add_python_module("tensorflow/contrib/tpu/profiler") -add_python_module("tensorflow/contrib/tpu/python") -add_python_module("tensorflow/contrib/tpu/python/ops") -add_python_module("tensorflow/contrib/tpu/python/profiler") -add_python_module("tensorflow/contrib/tpu/python/tpu") -add_python_module("tensorflow/contrib/training") -add_python_module("tensorflow/contrib/training/python") -add_python_module("tensorflow/contrib/training/python/training") -add_python_module("tensorflow/contrib/util") -add_python_module("tensorflow/contrib/reduce_slice_ops") -add_python_module("tensorflow/contrib/reduce_slice_ops/kernels") -add_python_module("tensorflow/contrib/reduce_slice_ops/ops") -add_python_module("tensorflow/contrib/reduce_slice_ops/python") -add_python_module("tensorflow/contrib/reduce_slice_ops/python/kernel_tests") -add_python_module("tensorflow/contrib/reduce_slice_ops/python/ops") -add_python_module("tensorflow/contrib/summary") +FILE(READ python_modules.txt python_modules) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_modules "${python_modules}") +STRING(REGEX REPLACE "\n" ";" python_modules "${python_modules}") + +foreach(python_module ${python_modules}) + add_python_module(${python_module}) +endforeach(python_module) + +add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite") +add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python") +add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E touch + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/__init__.py") +add_custom_command( + 
TARGET tf_python_copy_scripts_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E touch + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/lite.py) # Generate the tensorflow.python.platform.build_info module. set(BUILD_INFO_PY "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/platform/build_info.py") @@ -724,7 +289,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) # containing the wrappers. add_custom_command( OUTPUT ${GENERATE_PYTHON_OP_LIB_DESTINATION} - COMMAND ${tf_python_op_lib_name}_gen_python @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} + COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} DEPENDS ${tf_python_op_lib_name}_gen_python ) @@ -780,8 +345,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_data_prefetching_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_prefetching_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" @@ -804,6 +367,9 @@ GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py) GENERATE_PYTHON_OP_LIB("contrib_nccl_ops" DESTINATION 
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_periodic_resample_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/gen_periodic_resample_op.py) + GENERATE_PYTHON_OP_LIB("contrib_nearest_neighbor_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nearest_neighbor/ops/gen_nearest_neighbor_ops.py) GENERATE_PYTHON_OP_LIB("contrib_resampler_ops" @@ -876,6 +442,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor.h" @@ -886,6 +454,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.cc" "${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_reader.h" @@ -1002,6 +572,20 @@ target_link_libraries(pywrap_tensorflow_internal PRIVATE ) if(WIN32) + + # include contrib/periodic_resample as .so + # + set(tf_periodic_resample_srcs + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc" + 
"${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/ops/array_ops.cc" + ) + + AddUserOps(TARGET _periodic_resample_op + SOURCES "${tf_periodic_resample_srcs}" + DEPENDS pywrap_tensorflow_internal tf_python_ops + DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/) + # include contrib/nearest_neighbor as .so # set(tf_nearest_neighbor_srcs diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 3e3fe0cdfae3e286be6601928a922a436429bbe6..571d2b0decb5e9afcec2314f9837546f0974e90d 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -45,7 +45,7 @@ if(WIN32) $ $ ) - + set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/tensorflow.def") set_source_files_properties(${tensorflow_deffile} PROPERTIES GENERATED TRUE) @@ -95,10 +95,18 @@ if(WIN32) add_dependencies(tensorflow tensorflow_static) endif(WIN32) -install(TARGETS tensorflow +target_include_directories(tensorflow PUBLIC + $ + $) + +install(TARGETS tensorflow EXPORT tensorflow_export RUNTIME DESTINATION bin LIBRARY DESTINATION lib ARCHIVE DESTINATION lib) + +install(EXPORT tensorflow_export + FILE TensorflowConfig.cmake + DESTINATION lib/cmake) # install necessary headers # tensorflow headers diff --git a/tensorflow/contrib/cmake/tf_stream_executor.cmake b/tensorflow/contrib/cmake/tf_stream_executor.cmake index 8d95f0d3e813885c581b37cfc0b89e24d04ae6b1..91ca33f4c4d5f6c822f45b0676e6e46d2e4c2860 100644 --- a/tensorflow/contrib/cmake/tf_stream_executor.cmake +++ b/tensorflow/contrib/cmake/tf_stream_executor.cmake @@ -61,18 +61,18 @@ file(GLOB tf_stream_executor_srcs "${tensorflow_source_dir}/tensorflow/stream_executor/platform/default/*.h" ) -if (tensorflow_ENABLE_GPU) +if (tensorflow_ENABLE_GPU) file(GLOB tf_stream_executor_gpu_srcs 
"${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*.cc" ) list(APPEND tf_stream_executor_srcs ${tf_stream_executor_gpu_srcs}) -endif() +endif() #file(GLOB_RECURSE tf_stream_executor_test_srcs # "${tensorflow_source_dir}/tensorflow/stream_executor/*_test.cc" # "${tensorflow_source_dir}/tensorflow/stream_executor/*_test.h" #) -#list(REMOVE_ITEM tf_stream_executor_srcs ${tf_stream_executor_test_srcs}) +#list(REMOVE_ITEM tf_stream_executor_srcs ${tf_stream_executor_test_srcs}) if (NOT WIN32) set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lgomp") diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 5d6ba9ca8d85e9a2d19b7f3e488822a8f21c6821..94ca4b00175dffb4461fca34c5ecd79ba79be778 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -139,12 +139,15 @@ if (tensorflow_BUILD_PYTHON_TESTS) file(GLOB_RECURSE tf_test_src_py ${tf_test_rnn_src_py} + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/*.py" "${tensorflow_source_dir}/tensorflow/python/debug/cli/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/wrappers/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/*_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py" "${tensorflow_source_dir}/tensorflow/python/meta_graph_transform/*_test.py" + "${tensorflow_source_dir}/tensorflow/python/ops/quantized_conv_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/ops/quantized_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/platform/build_info_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/*_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/internal/*_test.py" @@ -153,7 +156,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py" 
"${tensorflow_source_dir}/tensorflow/contrib/image/*_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/integration_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/stateless/python/kernel_tests/*_test.py" @@ -171,7 +175,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/keras/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/distributions/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/learn/*_test.py" ) @@ -217,6 +220,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) # TFDBG grpc:// mode is not yet available on Windows. "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/lib/source_remote_test.py" # stl on windows handles overflows different "${tensorflow_source_dir}/tensorflow/python/kernel_tests/as_string_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/string_to_number_op_test.py" @@ -225,6 +229,10 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Numerical issues, calculations off. 
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/wals_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/backend_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py" # Float division by zero "${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py" # Flaky, for unknown reasons. Cannot reproduce in terminal. Revisit once we can get stack traces. @@ -233,11 +241,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" "${tensorflow_source_dir}tensorflow/python/training/localhost_cluster_performance_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Type error in testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU. 
- "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py" # IteratorGetMax OutOfRangeError @@ -261,9 +269,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg_grad_test.py" # cudaSolver handle creation fails. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops # Dataset tests - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" # Broken tensorboard test due to cmake issues. 
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 @@ -294,6 +302,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Test should only be run manually "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/svd_op_test.py" + # Depends on python/framework/test_ops + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/control_flow_util_test.py" ) endif() list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index 8c2528f548799f9facef740b0134ac56966b2b04..bae66ffd4289308f2cbfc730ec50d057b13923fb 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -19,7 +19,7 @@ from one graph to another. The copied elements are initialized inside a user-specified scope in the other graph. There are separate functions to copy ops and variables. There is also a function to retrive the copied version of an op from the -first graph inside a scope in the second graph. +first graph inside a scope in the second graph. 
@@copy_op_to_graph @@copy_variable_to_graph @@ -225,7 +225,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, new_original_op, op_def) #Use Graph's hidden methods to add the op - to_graph._add_op(new_op) + to_graph._add_op(new_op) # pylint: disable=protected-access to_graph._record_op_seen_by_control_dependencies(new_op) for device_function in reversed(to_graph._device_function_stack): new_op._set_device(device_function(new_op)) diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 9174c5eb989908d5a318e228bf231686b5117798..b47fb426a193e0fcc075deafae3eaab698f18ec9 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -23,7 +23,6 @@ import itertools import numpy as np from tensorflow.contrib.crf.python.ops import crf -from tensorflow.python.framework import dtypes from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -33,43 +32,58 @@ from tensorflow.python.platform import test class CrfTest(test.TestCase): def testCrfSequenceScore(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) - tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: - sequence_score = crf.crf_sequence_score( - inputs=array_ops.expand_dims(inputs, 0), - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params)) - sequence_score = array_ops.squeeze(sequence_score, [0]) - tf_sequence_score = sess.run(sequence_score) - expected_unary_score = sum(inputs[i][tag_indices[i]] - for i in range(sequence_lengths)) - 
expected_binary_score = sum( - transition_params[tag_indices[i], tag_indices[i + 1]] - for i in range(sequence_lengths - 1)) - expected_sequence_score = expected_unary_score + expected_binary_score - self.assertAllClose(tf_sequence_score, expected_sequence_score) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[4, 5, -3]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([1], dtype=np.int32) + ] + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + with self.test_session() as sess: + sequence_score = crf.crf_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + sequence_score = array_ops.squeeze(sequence_score, [0]) + tf_sequence_score = sess.run(sequence_score) + expected_unary_score = sum(inputs[i][tag_indices[i]] + for i in range(sequence_lengths)) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1)) + expected_sequence_score = expected_unary_score + expected_binary_score + self.assertAllClose(tf_sequence_score, expected_sequence_score) def testCrfUnaryScore(self): inputs = np.array( [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) - tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) - sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: - unary_score = crf.crf_unary_score( - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - inputs=array_ops.expand_dims(inputs, 0)) - unary_score = 
array_ops.squeeze(unary_score, [0]) - tf_unary_score = sess.run(unary_score) - expected_unary_score = sum(inputs[i][tag_indices[i]] - for i in range(sequence_lengths)) - self.assertAllClose(tf_unary_score, expected_unary_score) + for dtype in (np.int32, np.int64): + tag_indices = np.array([1, 2, 1, 0], dtype=dtype) + sequence_lengths = np.array(3, dtype=np.int32) + with self.test_session() as sess: + unary_score = crf.crf_unary_score( + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + inputs=array_ops.expand_dims(inputs, 0)) + unary_score = array_ops.squeeze(unary_score, [0]) + tf_unary_score = sess.run(unary_score) + expected_unary_score = sum(inputs[i][tag_indices[i]] + for i in range(sequence_lengths)) + self.assertAllClose(tf_unary_score, expected_unary_score) def testCrfBinaryScore(self): tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) @@ -89,38 +103,54 @@ class CrfTest(test.TestCase): self.assertAllClose(tf_binary_score, expected_binary_score) def testCrfLogNorm(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - num_words = inputs.shape[0] - num_tags = inputs.shape[1] - sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: - all_sequence_scores = [] - - # Compare the dynamic program with brute force computation. 
- for tag_indices in itertools.product( - range(num_tags), repeat=sequence_lengths): - tag_indices = list(tag_indices) - tag_indices.extend([0] * (num_words - sequence_lengths)) - all_sequence_scores.append( - crf.crf_sequence_score( - inputs=array_ops.expand_dims(inputs, 0), - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params))) - - brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores) - log_norm = crf.crf_log_norm( - inputs=array_ops.expand_dims(inputs, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params)) - log_norm = array_ops.squeeze(log_norm, [0]) - tf_brute_force_log_norm, tf_log_norm = sess.run( - [brute_force_log_norm, log_norm]) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[3, -1, 3]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + with self.test_session() as sess: + all_sequence_scores = [] + + # Compare the dynamic program with brute force computation. 
+ for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequence_scores.append( + crf.crf_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params))) + + brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores) + log_norm = crf.crf_log_norm( + inputs=array_ops.expand_dims(inputs, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + log_norm = array_ops.squeeze(log_norm, [0]) + tf_brute_force_log_norm, tf_log_norm = sess.run( + [brute_force_log_norm, log_norm]) - self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) + self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) def testCrfLogLikelihood(self): inputs = np.array( @@ -201,50 +231,66 @@ class CrfTest(test.TestCase): expected_max_sequence[:sequence_lengths]) def testCrfDecode(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - sequence_lengths = np.array(3, dtype=np.int32) - num_words = inputs.shape[0] - num_tags = inputs.shape[1] - - with self.test_session() as sess: - all_sequence_scores = [] - all_sequences = [] - - # Compare the dynamic program with brute force computation. 
- for tag_indices in itertools.product( - range(num_tags), repeat=sequence_lengths): - tag_indices = list(tag_indices) - tag_indices.extend([0] * (num_words - sequence_lengths)) - all_sequences.append(tag_indices) - sequence_score = crf.crf_sequence_score( - inputs=array_ops.expand_dims(inputs, 0), - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params)) - sequence_score = array_ops.squeeze(sequence_score, [0]) - all_sequence_scores.append(sequence_score) - - tf_all_sequence_scores = sess.run(all_sequence_scores) - - expected_max_sequence_index = np.argmax(tf_all_sequence_scores) - expected_max_sequence = all_sequences[expected_max_sequence_index] - expected_max_score = tf_all_sequence_scores[expected_max_sequence_index] - - actual_max_sequence, actual_max_score = crf.crf_decode( - array_ops.expand_dims(inputs, 0), - constant_op.constant(transition_params), - array_ops.expand_dims(sequence_lengths, 0)) - actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0]) - actual_max_score = array_ops.squeeze(actual_max_score, [0]) - tf_actual_max_sequence, tf_actual_max_score = sess.run( - [actual_max_sequence, actual_max_score]) - - self.assertAllClose(tf_actual_max_score, expected_max_score) - self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]), - expected_max_sequence[:sequence_lengths]) + # Test both the length-1 and regular cases. 
+ sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[-1, 2, 1]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + + with self.test_session() as sess: + all_sequence_scores = [] + all_sequences = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequences.append(tag_indices) + sequence_score = crf.crf_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + sequence_score = array_ops.squeeze(sequence_score, [0]) + all_sequence_scores.append(sequence_score) + + tf_all_sequence_scores = sess.run(all_sequence_scores) + + expected_max_sequence_index = np.argmax(tf_all_sequence_scores) + expected_max_sequence = all_sequences[expected_max_sequence_index] + expected_max_score = tf_all_sequence_scores[expected_max_sequence_index] + + actual_max_sequence, actual_max_score = crf.crf_decode( + array_ops.expand_dims(inputs, 0), + constant_op.constant(transition_params), + array_ops.expand_dims(sequence_lengths, 0)) + actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0]) + actual_max_score = array_ops.squeeze(actual_max_score, [0]) + tf_actual_max_sequence, tf_actual_max_score = sess.run( + [actual_max_sequence, actual_max_score]) + + self.assertAllClose(tf_actual_max_score, expected_max_score) + 
self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]), + expected_max_sequence[:sequence_lengths]) if __name__ == "__main__": diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index c8adb0369b98947d2d29374ee8ada1185815d3cd..7f5ae937b26f465076c6976429697c35924432e5 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -53,7 +53,9 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import dtypes +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn @@ -101,12 +103,29 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, Returns: sequence_scores: A [batch_size] vector of unnormalized sequence scores. """ - # Compute the scores of the given tag sequence. - unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) - binary_scores = crf_binary_score(tag_indices, sequence_lengths, - transition_params) - sequence_scores = unary_scores + binary_scores - return sequence_scores + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of the single tag. + def _single_seq_fn(): + batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0] + example_inds = array_ops.reshape( + math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1]) + return array_ops.gather_nd( + array_ops.squeeze(inputs, [1]), + array_ops.concat([example_inds, tag_indices], axis=1)) + + def _multi_seq_fn(): + # Compute the scores of the given tag sequence. 
+ unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) + binary_scores = crf_binary_score(tag_indices, sequence_lengths, + transition_params) + sequence_scores = unary_scores + binary_scores + return sequence_scores + + return utils.smart_cond( + pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1], + 1), + fn1=_single_seq_fn, + fn2=_multi_seq_fn) def crf_log_norm(inputs, sequence_lengths, transition_params): @@ -124,19 +143,32 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): # algorithm. first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1]) first_input = array_ops.squeeze(first_input, [1]) - rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1]) - # Compute the alpha values in the forward algorithm in order to get the - # partition function. - forward_cell = CrfForwardRnnCell(transition_params) - _, alphas = rnn.dynamic_rnn( - cell=forward_cell, - inputs=rest_of_input, - sequence_length=sequence_lengths - 1, - initial_state=first_input, - dtype=dtypes.float32) - log_norm = math_ops.reduce_logsumexp(alphas, [1]) - return log_norm + # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over + # the "initial state" (the unary potentials). + def _single_seq_fn(): + return math_ops.reduce_logsumexp(first_input, [1]) + + def _multi_seq_fn(): + """Forward computation of alpha values.""" + rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1]) + + # Compute the alpha values in the forward algorithm in order to get the + # partition function. 
+ forward_cell = CrfForwardRnnCell(transition_params) + _, alphas = rnn.dynamic_rnn( + cell=forward_cell, + inputs=rest_of_input, + sequence_length=sequence_lengths - 1, + initial_state=first_input, + dtype=dtypes.float32) + log_norm = math_ops.reduce_logsumexp(alphas, [1]) + return log_norm + + max_seq_len = array_ops.shape(inputs)[1] + return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1), + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) def crf_log_likelihood(inputs, @@ -193,6 +225,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs): offsets = array_ops.expand_dims( math_ops.range(batch_size) * max_seq_len * num_tags, 1) offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0) + # Use int32 or int64 based on tag_indices' dtype. + if tag_indices.dtype == dtypes.int64: + offsets = math_ops.to_int64(offsets) flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1]) unary_scores = array_ops.reshape( @@ -305,7 +340,7 @@ def viterbi_decode(score, transition_params): Returns: viterbi: A [seq_len] list of integers containing the highest scoring tag - indicies. + indices. viterbi_score: A float containing the score for the Viterbi sequence. """ trellis = np.zeros_like(score) @@ -385,7 +420,7 @@ class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell): """Initialize the CrfDecodeBackwardRnnCell. Args: - num_tags: An integer. + num_tags: An integer. The number of tags. """ self._num_tags = num_tags @@ -435,44 +470,63 @@ def crf_decode(potentials, transition_params, sequence_length): Returns: decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. - Contains the highest scoring tag indicies. + Contains the highest scoring tag indices. best_score: A [batch_size] vector, containing the score of `decode_tags`. """ - # For simplicity, in shape comments, denote: - # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). - num_tags = potentials.get_shape()[2].value - - # Computes forward decoding. 
Get last score and backpointers. - crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) - initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) - initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] - inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] - backpointers, last_score = rnn.dynamic_rnn( - crf_fwd_cell, - inputs=inputs, - sequence_length=sequence_length - 1, - initial_state=initial_state, - time_major=False, - dtype=dtypes.int32) # [B, T - 1, O], [B, O] - backpointers = gen_array_ops.reverse_sequence( - backpointers, sequence_length - 1, seq_dim=1) # [B, T-1, O] - - # Computes backward decoding. Extract tag indices from backpointers. - crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) - initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), - dtype=dtypes.int32) # [B] - initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1] - decode_tags, _ = rnn.dynamic_rnn( - crf_bwd_cell, - inputs=backpointers, - sequence_length=sequence_length - 1, - initial_state=initial_state, - time_major=False, - dtype=dtypes.int32) # [B, T - 1, 1] - decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1] - decode_tags = array_ops.concat([initial_state, decode_tags], axis=1) # [B, T] - decode_tags = gen_array_ops.reverse_sequence( - decode_tags, sequence_length, seq_dim=1) # [B, T] - - best_score = math_ops.reduce_max(last_score, axis=1) # [B] - return decode_tags, best_score + # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag + # and the max activation. 
+ def _single_seq_fn(): + squeezed_potentials = array_ops.squeeze(potentials, [1]) + decode_tags = array_ops.expand_dims( + math_ops.argmax(squeezed_potentials, axis=1), 1) + best_score = math_ops.reduce_max(squeezed_potentials, axis=1) + return math_ops.cast(decode_tags, dtype=dtypes.int32), best_score + + def _multi_seq_fn(): + """Decoding of highest scoring sequence.""" + + # For simplicity, in shape comments, denote: + # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). + num_tags = potentials.get_shape()[2].value + + # Computes forward decoding. Get last score and backpointers. + crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) + initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) + initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] + inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] + backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] + crf_fwd_cell, + inputs=inputs, + sequence_length=sequence_length - 1, + initial_state=initial_state, + time_major=False, + dtype=dtypes.int32) + backpointers = gen_array_ops.reverse_sequence( # [B, T - 1, O] + backpointers, sequence_length - 1, seq_dim=1) + + # Computes backward decoding. Extract tag indices from backpointers. 
+ crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) + initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), # [B] + dtype=dtypes.int32) + initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1] + decode_tags, _ = rnn.dynamic_rnn( # [B, T - 1, 1] + crf_bwd_cell, + inputs=backpointers, + sequence_length=sequence_length - 1, + initial_state=initial_state, + time_major=False, + dtype=dtypes.int32) + decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1] + decode_tags = array_ops.concat([initial_state, decode_tags], # [B, T] + axis=1) + decode_tags = gen_array_ops.reverse_sequence( # [B, T] + decode_tags, sequence_length, seq_dim=1) + + best_score = math_ops.reduce_max(last_score, axis=1) # [B] + return decode_tags, best_score + + return utils.smart_cond( + pred=math_ops.equal( + potentials.shape[1].value or array_ops.shape(potentials)[1], 1), + fn1=_single_seq_fn, + fn2=_multi_seq_fn) diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index d6d53d521b2024abf50cfbfec96a6e0dc538ed03..fce2c03e69bc4b8b0ac46b8e081a33c43c9d41ab 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -54,49 +54,13 @@ tf_gen_op_wrapper_py( deps = [":cudnn_rnn_ops_op_lib"], ) -tf_custom_op_py_library( - name = "cudnn_rnn_ops_py", - srcs = [ - "__init__.py", - "python/ops/cudnn_rnn_ops.py", - ], - dso = [ - ":python/ops/_cudnn_rnn_ops.so", - ], - kernels = [ - ":cudnn_rnn_kernels", - ":cudnn_rnn_ops_op_lib", - ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - ":cudnn_rnn_ops", - "//tensorflow/contrib/rnn:rnn_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:common_shapes", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers_base", - 
"//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:random_seed", - "//tensorflow/python:rnn_cell", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - ], -) - tf_custom_op_py_library( name = "cudnn_rnn_py", srcs = [ "__init__.py", "python/layers/__init__.py", "python/layers/cudnn_rnn.py", + "python/ops/cudnn_rnn_ops.py", ], dso = [ ":python/ops/_cudnn_rnn_ops.so", @@ -109,7 +73,6 @@ tf_custom_op_py_library( visibility = ["//visibility:public"], deps = [ ":cudnn_rnn_ops", - ":cudnn_rnn_ops_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -130,7 +93,7 @@ cuda_py_test( size = "large", srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"], additional_deps = [ - ":cudnn_rnn_ops_py", + ":cudnn_rnn_py", "//tensorflow/core:protos_all_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python/ops/losses:losses", diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py index 1f7efad71fb04cd754eae8ce170e696baa4d7fc3..5d8c6191f8db9f96532aa78e4790a4665d3b4877 100644 --- a/tensorflow/contrib/cudnn_rnn/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -29,19 +29,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.cudnn_rnn.python.layers import * # pylint: enable=unused-import,wildcard-import -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTMSaveable -from 
tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable - from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 9156087f338f0f59f102560d7538b1871c84e23e..5a667485beebe4bee7f051b5920920c72134987f 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -35,15 +35,11 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import rnn as rnn_lib -from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables -from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import gradient_descent from tensorflow.python.training import saver as saver_lib CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION @@ -123,45 +119,6 @@ def _CreateParamsSavable(params, return params_saveable -def _BuildCudnnForward(rnn_mode, - num_layers, - num_units, - input_data, - is_training=False): - input_data_shape = input_data.get_shape().with_rank(3) - batch_size = input_data_shape[1].value - input_size = input_data_shape[2].value - model = _CreateModel(rnn_mode, num_layers, num_units, input_size) - - # Set zero init input states - input_h = constant_op.constant( - np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - if 
has_input_c: - input_c = constant_op.constant( - np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32) - - # Set rnn params - params_size_t = model.params_size() - params = variables.Variable( - random_ops.random_uniform([params_size_t]), validate_shape=False) - args = { - "input_data": input_data, - "input_h": input_h, - "params": params, - "is_training": is_training - } - if has_input_c: - args["input_c"] = input_c - # Build cell - output_tuple = model(**args) - - # Create savable objects for params - _CreateParamsSavable(params, model) - - return output_tuple, model - - def _MinLSTMParamSize(num_layers, num_units, input_size, @@ -181,25 +138,6 @@ def _MinLSTMParamSize(num_layers, raise ValueError("%s direction is not supported.") -def _CreateCudnnCompatibleCanonicalRNN(cudnn_model, - inputs, - scope=None): - model = cudnn_model.rnn_mode - if model not in (cudnn_rnn_ops.CUDNN_LSTM, cudnn_rnn_ops.CUDNN_GRU): - raise ValueError("%s is not supported!" % model) - - num_units = cudnn_model.num_units - num_layers = cudnn_model.num_layers - # To reuse cuDNN-trained models, must use cudnn compatible rnn cells. 
- if model == cudnn_rnn_ops.CUDNN_LSTM: - single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleLSTMCell(num_units) - else: - single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units) - cell = rnn_cell_impl.MultiRNNCell([single_cell() for _ in range(num_layers)]) - return rnn_lib.dynamic_rnn( - cell, inputs, dtype=dtypes.float32, time_major=True, scope=scope) - - class CudnnRNNTestSaveRestore(TensorFlowTestCase): def _CompareWeights(self, lhs, rhs): @@ -436,143 +374,6 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase): self._testSaveRestoreOutput(rnn_mode, direction, dtype) -class CudnnRNNTestCompatibleRnnCells(TensorFlowTestCase): - - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") - def testCudnnCompatibleRnnCells(self): - configs = [ - { - "num_layers": 1, - "seq_length": 3, - "num_units": 4, - "input_size": 5, - "batch_size": 6, - }, - { - "num_layers": 2, - "seq_length": 8, - "num_units": 4, - "input_size": 8, - "batch_size": 16, - }, - { - "num_layers": 2, - "seq_length": 3, - "num_units": 4, - "input_size": 5, - "batch_size": 6, - }, - { - "num_layers": 1, - "seq_length": 2, - "num_units": 2, - "input_size": 4, - "batch_size": 1, - }, - ] - for rnn, cfg in itertools.product((cudnn_rnn_ops.CUDNN_LSTM,), configs): - self._testCudnnCompatibleRnnCells(cfg["num_layers"], cfg["seq_length"], - cfg["num_units"], cfg["input_size"], - cfg["batch_size"], rnn) - # TODO(jamesqin): Add CudnnCompatibleGRUBlockCell. 
- for rnn, cfg in itertools.product((cudnn_rnn_ops.CUDNN_GRU,), configs): - self._testCudnnCompatibleRnnCells(cfg["num_layers"], cfg["seq_length"], - cfg["num_units"], cfg["input_size"], - cfg["batch_size"], rnn) - - def _testCudnnCompatibleRnnCells(self, num_layers, seq_length, num_units, - input_size, batch_size, rnn_mode): - has_state_c = rnn_mode == cudnn_rnn_ops.CUDNN_LSTM - np.random.seed(0) - # Train graph - with ops.Graph().as_default(): - random_seed.set_random_seed(299) - input_data = array_ops.placeholder( - dtypes.float32, shape=[seq_length, batch_size, input_size]) - output_tuple, cudnn_model = _BuildCudnnForward( - rnn_mode, num_layers, num_units, input_data, is_training=True) - target_output = array_ops.placeholder(dtype=dtypes.float32, shape=None) - total_sum = sum(map(math_ops.reduce_sum, output_tuple)) - - loss_op = losses.log_loss(labels=target_output, predictions=total_sum) - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1e-2) - train_op = optimizer.minimize(loss_op) - - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - - # Train Cudnn model - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - # Train 128 steps - num_steps = 128 - for _ in range(num_steps): - inputs = np.random.rand(seq_length, batch_size, - input_size).astype(np.float32) - targets = np.random.rand() - sess.run( - train_op, feed_dict={input_data: inputs, - target_output: targets}) - - save_path = os.path.join(self.get_temp_dir(), - ("cudnn-rnn-%s-test" % rnn_mode)) - save_v = saver.save(sess, save_path) - self.assertEqual(save_path, save_v) - - # cuDNN inference graph - with ops.Graph().as_default(): - random_seed.set_random_seed(299) - cudnn_inputs = array_ops.placeholder( - dtypes.float32, shape=[seq_length, batch_size, input_size]) - (cudnn_output_tuple, cudnn_model) = _BuildCudnnForward( - rnn_mode, num_layers, num_units, cudnn_inputs, is_training=False) - 
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - - inference_input = np.random.rand(seq_length, batch_size, - input_size).astype(np.float32) - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - saver.restore(sess, save_path) - - # Cudnn inference - cudnn_output = sess.run( - cudnn_output_tuple, feed_dict={cudnn_inputs: inference_input}) - - # Canonical RNN inference graph - with ops.Graph().as_default(): - random_seed.set_random_seed(299) - cell_inputs = array_ops.placeholder( - dtypes.float32, shape=[seq_length, batch_size, input_size]) - (output, states) = _CreateCudnnCompatibleCanonicalRNN( - cudnn_model, cell_inputs) - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - saver.restore(sess, save_path) - - # BlockCell inference - output_v, states_v = sess.run( - [output, states], feed_dict={cell_inputs: inference_input}) - - # output across timestamps are packed into one tensor. 
- self.assertAllClose(cudnn_output[0], output_v, atol=1e-6, rtol=1e-6) - - for i in range(num_layers): - if has_state_c: - # output_h - self.assertAllClose( - cudnn_output[1][i, :], states_v[i].h, atol=1e-6, rtol=1e-6) - # output_c - self.assertAllClose( - cudnn_output[2][i, :], states_v[i].c, atol=1e-6, rtol=1e-6) - else: - self.assertAllClose( - cudnn_output[1][i, :], states_v[i], atol=1e-6, rtol=1e-6) - - class CudnnRNNTestParamsSize(TensorFlowTestCase): def _testOneLSTMParamsSize(self, num_layers, num_units, input_size, diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py index 5feee3d10d14020d63eec0541e5caa37e79f9f57..f09466b631f69d6234573dd5eafada650421c117 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py @@ -22,3 +22,10 @@ import sys # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.cudnn_rnn.python.layers.cudnn_rnn import * # pylint: enable=unused-import,wildcard-import + +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTMSaveable +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 9f748996934ca608838e57756a96c35c67feaac9..dcd3d4732a27ae4bec579ac12ac568dc4a53baaa 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -18,7 +18,6 @@ from __future__ import division from 
__future__ import print_function from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops -from tensorflow.contrib.rnn.python.ops import core_rnn_cell from tensorflow.contrib.rnn.python.ops import lstm_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import common_shapes @@ -29,6 +28,7 @@ from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs @@ -55,6 +55,11 @@ CUDNN_INPUT_LINEAR_MODE = "linear_input" CUDNN_INPUT_SKIP_MODE = "skip_input" CUDNN_INPUT_AUTO_MODE = "auto_select" +# pylint:disable=protected-access +_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME +_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME +# pylint:enable=protected-access + class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): """Cudnn Compatible LSTMCell. @@ -87,9 +92,9 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): Cudnn compatible GRU (from Cudnn library user guide): ```python r_t = sigma(x_t * W_r + h_t-1 * R_h + b_Wr + b_Rr) # reset gate - i_t = sigma(x_t * W_i + h_t-1 * R_i + b_Wi + b_Ru) # update gate + u_t = sigma(x_t * W_u + h_t-1 * R_u + b_Wu + b_Ru) # update gate h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_Rh) + b_Wh) # new memory gate - h_t = (1 - i_t) .* h'_t + i_t .* h_t-1 + h_t = (1 - u_t) .* h'_t + u_t .* h_t-1 ``` Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}): @@ -100,9 +105,6 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): ```python r .* (h * R) != (r .* h) * R ``` - - TODO(jamesqin): update the impl after Cudnn 7.1 when Nvidia would adopt the - canonical version compatible with other tf GRU cells. 
""" def __init__(self, num_units, reuse=None, kernel_initializer=None): @@ -112,33 +114,65 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): reuse=reuse, kernel_initializer=kernel_initializer) + def build(self, inputs_shape): + if inputs_shape[1].value is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" + % inputs_shape) + + input_depth = inputs_shape[1].value + self._gate_kernel = self.add_variable( + "gates/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + self._num_units, 2 * self._num_units], + initializer=self._kernel_initializer) + self._gate_bias = self.add_variable( + "gates/%s" % _BIAS_VARIABLE_NAME, + shape=[2 * self._num_units], + initializer=( + self._bias_initializer + if self._bias_initializer is not None + else init_ops.constant_initializer(1.0, dtype=self.dtype))) + + self._candidate_input_kernel = self.add_variable( + "candidate/input_projection/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[input_depth, self._num_units], + initializer=self._kernel_initializer) + self._candidate_hidden_kernel = self.add_variable( + "candidate/hidden_projection/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[self._num_units, self._num_units], + initializer=self._kernel_initializer) + + self._candidate_input_bias = self.add_variable( + "candidate/input_projection/%s" % _BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=( + self._bias_initializer + if self._bias_initializer is not None + else init_ops.zeros_initializer(dtype=self.dtype))) + self._candidate_hidden_bias = self.add_variable( + "candidate/hidden_projection/%s" % _BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=( + self._bias_initializer + if self._bias_initializer is not None + else init_ops.zeros_initializer(dtype=self.dtype))) + def call(self, inputs, state): """Gated recurrent unit (GRU) with nunits cells.""" - with vs.variable_scope("gates"): # Reset gate and update gate. - # We start with bias of 1.0 to not reset and not update. 
- bias_ones = self._bias_initializer - if self._bias_initializer is None: - dtype = inputs.dtype - bias_ones = init_ops.constant_initializer(1.0, dtype=dtype) - # pylint: disable=protected-access - value = math_ops.sigmoid( - core_rnn_cell._linear([inputs, state], 2 * self._num_units, True, - bias_ones, self._kernel_initializer)) - r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) - # pylint: enable=protected-access - with vs.variable_scope("candidate"): - # pylint: disable=protected-access - with vs.variable_scope("input_projection"): - hi = core_rnn_cell._linear(inputs, self._num_units, True, - self._bias_initializer, - self._kernel_initializer) - with vs.variable_scope("hidden_projection"): - hh = r * (core_rnn_cell._linear(state, self._num_units, True, - self._bias_initializer, - self._kernel_initializer)) - # pylint: enable=protected-access - c = self._activation(hi + hh) - new_h = u * state + (1 - u) * c + gate_inputs = math_ops.matmul( + array_ops.concat([inputs, state], 1), self._gate_kernel) + gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) + + value = math_ops.sigmoid(gate_inputs) + r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) + + candidate = nn_ops.bias_add( + math_ops.matmul(inputs, self._candidate_input_kernel), + self._candidate_input_bias) + candidate += r * nn_ops.bias_add( + math_ops.matmul(state, self._candidate_hidden_kernel), + self._candidate_hidden_bias) + candidate = self._activation(candidate) + new_h = (1-u) * candidate + u * state return new_h, new_h diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 7bcf5a5f4dcd6293644725a2ccf78a763da3d9eb..3b1c33063f1214b68f79560f50d56bf5d31c9560 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -17,8 +17,8 @@ py_library( deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", - "//tensorflow/contrib/data/python/ops:prefetching_py", 
"//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", @@ -27,33 +27,20 @@ py_library( tf_custom_op_library( name = "_prefetching_ops.so", - srcs = [ - "ops/prefetching_ops.cc", - ], - deps = [ - "//tensorflow/contrib/data/kernels:prefetching_kernels", - ], -) - -# TODO(mrry): Move the kernels out of the core library into this library. -tf_custom_op_library( - name = "_dataset_ops.so", - srcs = [ - "ops/dataset_ops.cc", - ], + srcs = ["ops/prefetching_ops.cc"], + deps = ["//tensorflow/contrib/data/kernels:prefetching_kernels"], ) tf_gen_op_libs( - op_lib_names = [ - "dataset_ops", - "prefetching_ops", - ], + op_lib_names = ["prefetching_ops"], ) filegroup( name = "all_files", srcs = glob( - ["**/*"], + include = [ + "**/*", + ], exclude = [ "**/METADATA", "**/OWNERS", diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 0c7e793689204ba18dcab03c87902103e5802e45..c9ad091bd44d6e3a9368e182c3df9fc1c6e48071 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -17,12 +17,14 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@Dataset +@@Counter @@Iterator @@TFRecordDataset @@FixedLengthRecordDataset @@TextLineDataset @@batch_and_drop_remainder +@@padded_batch_and_drop_remainder @@dense_to_sparse_batch @@enumerate_dataset @@group_by_window @@ -32,6 +34,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. 
@@unbatch @@parallel_interleave @@rejection_resample +@@scan @@sloppy_interleave @@get_single_element @@ -41,11 +44,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - # pylint: disable=unused-import + from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch +from tensorflow.contrib.data.python.ops.batching import padded_batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import unbatch +from tensorflow.contrib.data.python.ops.counter import Counter from tensorflow.contrib.data.python.ops.dataset_ops import Dataset from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset @@ -60,6 +65,8 @@ from tensorflow.contrib.data.python.ops.readers import SqlDataset from tensorflow.contrib.data.python.ops.readers import TextLineDataset from tensorflow.contrib.data.python.ops.readers import TFRecordDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample +from tensorflow.contrib.data.python.ops.scan_ops import scan +from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.python.data.ops.iterator_ops import Iterator # pylint: enable=unused-import diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc deleted file mode 100644 index 1574384cb2bf5578bc5ccd13d2792e30b6359996..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ /dev/null @@ -1,232 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_def_builder.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -// -------------------------------------------------------------------------- - -// The ops in this section can be composed to define an input -// pipeline. Each op produces a DT_VARIANT tensor that represents -// a DAG of "dataset" objects. An "dataset" object can be converted -// to a stateful "iterator" by passing the "dataset" to the -// "MakeIterator" op. -// -// TODO(b/65524810): DT_VARIANT tensors that represent "dataset" objects are -// not presently serializable. To avoid issues with constant folding, ensure -// that any "source dataset" ops (i.e. ops that output a dataset and do not -// take one as input) are marked "stateful". - -REGISTER_OP("IgnoreErrorsDataset") - .Input("input_dataset: variant") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that contains the elements of `input_dataset` ignoring errors. 
-)doc"); - -REGISTER_OP("MapAndBatchDataset") - .Input("input_dataset: variant") - .Input("other_arguments: Targuments") - .Input("batch_size: int64") - .Input("num_parallel_batches: int64") - .Output("handle: variant") - .Attr("f: func") - .Attr("Targuments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that applies `f` to the outputs of `input_dataset` and then -batches `batch_size` of them. - -Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up -to `batch_size * num_parallel_batches` copies of `f` in parallel. - -batch_size: A scalar representing the number of elements to accumulate in a - batch. It determines the number of concurrent invocations of `f` that process - elements from `input_dataset` in parallel. -num_parallel_batches: A scalar representing the number of batches to create in - parallel. Processing multiple batches in parallel benefits workloads prone to - stragglers. -)doc"); - -REGISTER_OP("ScanDataset") - .Input("input_dataset: variant") - .Input("initial_state: Tstate") - .Input("other_arguments: Targuments") - .Output("handle: variant") - .Attr("f: func") - .Attr("Tstate: list(type) >= 1") - .Attr("Targuments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset successively reduces `f` over the elements of `input_dataset`. 
-)doc"); - -REGISTER_OP("ParallelInterleaveDataset") - .Input("input_dataset: variant") - .Input("other_arguments: Targuments") - .Input("cycle_length: int64") - .Input("block_length: int64") - .Input("sloppy: bool") - .Output("handle: variant") - .Attr("f: func") - .Attr("Targuments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that applies `f` to the outputs of `input_dataset`. - -The resulting dataset is similar to the `InterleaveDataset`, with the exception -that if retrieving the next value from a dataset would cause the requester to -block, it will skip that input dataset. This dataset is especially useful -when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it -allows the training step to proceed so long as some data is available. - -!! WARNING !! This dataset is not deterministic! - -f: A function mapping elements of `input_dataset`, concatenated with - `other_arguments`, to a Dataset variant that contains elements matching - `output_types` and `output_shapes`. -)doc"); - -REGISTER_OP("GroupByWindowDataset") - .Input("input_dataset: variant") - .Input("key_func_other_arguments: Tkey_func_other_arguments") - .Input("reduce_func_other_arguments: Treduce_func_other_arguments") - .Input( - "window_size_func_other_arguments: Twindow_size_func_other_arguments") - .Output("handle: variant") - .Attr("key_func: func") - .Attr("reduce_func: func") - .Attr("window_size_func: func") - .Attr("Tkey_func_other_arguments: list(type) >= 0") - .Attr("Treduce_func_other_arguments: list(type) >= 0") - .Attr("Twindow_size_func_other_arguments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that computes a windowed group-by on `input_dataset`. - -// TODO(mrry): Support non-int64 keys. 
- -key_func: A function mapping an element of `input_dataset`, concatenated - with `key_func_other_arguments` to a scalar value of type DT_INT64. -)doc"); - -REGISTER_OP("DenseToSparseBatchDataset") - .Input("input_dataset: variant") - .Input("batch_size: int64") - .Input("row_shape: int64") - .Output("handle: variant") - // NOTE(mrry): the 0th and 2nd elements will be DT_INT64. - .Attr("output_types: list(type) >= 1") - // NOTE(mrry): the 1st and 2nd elements will be vectors. - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that yields a SparseTensor for each element of the input. - -input_dataset: A handle to an input dataset. Must have a single component. -batch_size: A scalar representing the number of elements to accumulate in a - batch. -row_shape: A vector representing the dense shape of each row in the produced - SparseTensor. The shape may be partially specified, using `-1` to indicate - that a particular dimension should use the maximum size of all batch elements. -)doc"); - -REGISTER_OP("SqlDataset") - .Input("driver_name: string") - .Input("data_source_name: string") - .Input("query: string") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked - // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that executes a SQL query and emits rows of the result set. - -driver_name: The database type. Currently, the only supported type is 'sqlite'. -data_source_name: A connection string to connect to the database. -query: A SQL query to execute. 
-)doc"); - -REGISTER_OP("DatasetToSingleElement") - .Input("dataset: variant") - .Output("components: output_types") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); - std::vector output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - if (output_shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "`output_shapes` must be the same length as `output_types` (", - output_shapes.size(), " vs. ", c->num_outputs()); - } - for (size_t i = 0; i < output_shapes.size(); ++i) { - shape_inference::ShapeHandle output_shape_handle; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( - output_shapes[i], &output_shape_handle)); - c->set_output(static_cast(i), output_shape_handle); - } - return Status::OK(); - }) - .Doc(R"doc( -Outputs the single element from the given dataset. - -dataset: A handle to a dataset that contains a single element. -components: The components of the single element of `input`. -)doc"); - -REGISTER_OP("SerializeIterator") - .Input("resource_handle: resource") - .Output("serialized: variant") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Converts the given `resource_handle` representing an iterator to a variant tensor. - -resource_handle: A handle to an iterator resource. -serialized: A variant tensor storing the state of the iterator contained in the - resource. -)doc"); - -REGISTER_OP("DeserializeIterator") - .Input("resource_handle: resource") - .Input("serialized: variant") - .SetShapeFn(shape_inference::NoOutputs) - .Doc(R"doc( -Converts the given variant tensor to an iterator and stores it in the given resource. - -resource_handle: A handle to an iterator resource. -serialized: A variant tensor storing the state of the iterator contained in the - resource. 
-)doc"); - -} // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 5877f42dcf9e99bca27ba0e6ce222c556dfbd159..375e3ad61293f0b4599fbdb81e79a956acb03dac 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,7 +4,7 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") py_test( name = "batch_dataset_op_test", @@ -97,8 +97,8 @@ py_test( "nomac", # b/62040583 ], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -110,7 +110,7 @@ py_test( "//tensorflow/python:resource_variable_ops", "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", + "//tensorflow/python:tensor_shape", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], @@ -123,14 +123,17 @@ py_library( "dataset_serialization_test_base.py", ], srcs_version = "PY2AND3", - visibility = ["//visibility:private"], deps = [ "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", ], ) @@ -140,7 +143,9 @@ py_test( size = "small", srcs = ["filter_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", 
"//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -152,21 +157,28 @@ py_test( ], ) -py_test( +tf_py_test( name = "flat_map_dataset_op_test", size = "small", srcs = ["flat_map_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ + ":dataset_serialization_test", + "//third_party/py/numpy", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", "//tensorflow/python:session", "//tensorflow/python:training", - "//third_party/py/numpy", + "//tensorflow/python:variable_scope", ], + grpc_enabled = True, + tags = ["no_pip"], ) py_test( @@ -178,6 +190,7 @@ py_test( "manual", # b/67958761 ], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", @@ -187,18 +200,18 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:math_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//third_party/py/numpy", ], ) -py_test( +tf_py_test( name = "iterator_ops_cluster_test", size = "small", srcs = ["iterator_ops_cluster_test.py"], - srcs_version = "PY2AND3", - tags = ["no_windows"], - deps = [ + additional_deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -212,14 +225,19 @@ py_test( "//tensorflow/python:session", "//tensorflow/python/data/ops:iterator_ops", ], + grpc_enabled = True, + tags = [ + "no_windows", + "oss_serial", + ], ) -py_test( +tf_py_test( name = "iterator_ops_test", size = "small", srcs = ["iterator_ops_test.py"], - srcs_version = "PY2AND3", - 
deps = [ + additional_deps = [ + "//third_party/py/numpy", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/core:protos_all_py", @@ -241,8 +259,8 @@ py_test( "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", ], + grpc_enabled = True, ) py_test( @@ -262,12 +280,13 @@ py_test( py_test( name = "map_dataset_op_test", - size = "small", + size = "medium", srcs = ["map_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -276,20 +295,35 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:function", "//tensorflow/python:functional_ops", "//tensorflow/python:io_ops", "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", - "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//third_party/py/numpy", ], ) +py_test( + name = "prefetch_dataset_op_test", + size = "small", + srcs = ["prefetch_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/python:platform", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "range_dataset_op_test", size = "small", @@ -318,25 +352,22 @@ py_test( py_test( name = "reader_dataset_ops_test", - size = "small", + size = "medium", srcs = ["reader_dataset_ops_test.py"], srcs_version = "PY2AND3", + tags = 
["no_pip"], deps = [ - "//tensorflow/contrib/data/python/ops:iterator_ops", + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", ], @@ -361,11 +392,30 @@ py_test( ) py_test( - name = "sequence_dataset_op_test", + name = "scan_dataset_op_test", size = "small", + srcs = ["scan_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "sequence_dataset_op_test", + size = "medium", srcs = ["sequence_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -389,12 +439,15 @@ py_test( py_test( name = "shuffle_dataset_op_test", - size = "small", + size = "medium", srcs = ["shuffle_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", 
"//tensorflow/python:constant_op", @@ -423,21 +476,32 @@ py_test( ], ) +py_test( + name = "stats_dataset_ops_test", + size = "small", + srcs = ["stats_dataset_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + ], +) + py_test( name = "zip_dataset_op_test", size = "small", srcs = ["zip_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], ) @@ -449,23 +513,29 @@ py_test( srcs_version = "PY2AND3", tags = [ "manual", - "no_oss", + "no_oss", # b/68785503 ], deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:prefetching_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", ], ) filegroup( name = "all_files", srcs = glob( - ["**/*"], + include = [ + "**/*", + ], exclude = [ "**/METADATA", "**/OWNERS", diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 670f622c3c372dd08870390298f2e28db7e85596..506eefbef0204284a103827180c13b13200a3f93 100644 --- 
a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -52,8 +52,9 @@ class BatchDatasetTest(test.TestCase): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(count).batch(batch_size).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(count).batch(batch_size).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -69,7 +70,7 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(14): - self.assertAllEqual(component[(i*14 + j) % 7]**2, + self.assertAllEqual(component[(i * 14 + j) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -84,12 +85,12 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(8): - self.assertAllEqual(component[(i*8 + j) % 7]**2, + self.assertAllEqual(component[(i * 8 + j) % 7]**2, result_component[j]) result = sess.run(get_next) for component, result_component in zip(components, result): for j in range((14 * 7) % 8): - self.assertAllEqual(component[((num_batches - 1)*8 + j) % 7]**2, + self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -103,14 +104,67 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def 
testBatchSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).batch( + 5).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(2): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testNestedBatchSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(5).batch( + 2).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0], + [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]], + values=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + dense_shape=[2, 5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testPaddedBatchDataset(self): seq_lens = array_ops.placeholder(dtypes.int32, shape=[None]) padded_shape = array_ops.placeholder(dtypes.int64, shape=[1]) - iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens) - .map(lambda x: array_ops.fill([x], x)).padded_batch( - 4, - padded_shapes=padded_shape).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(seq_lens) + 
.map(lambda x: array_ops.fill([x], x)).padded_batch( + 4, padded_shapes=padded_shape).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -118,35 +172,40 @@ class BatchDatasetTest(test.TestCase): with self.test_session() as sess: # Test with random sequence lengths, and max padding. random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) - sess.run(init_op, feed_dict={padded_shape: [-1], - seq_lens: random_seq_lens}) + sess.run( + init_op, feed_dict={ + padded_shape: [-1], + seq_lens: random_seq_lens + }) for i in range(8): result = sess.run(get_next) padded_len = np.max(result) self.assertEqual((4, padded_len), result.shape) for j in range(4): - seq_len = random_seq_lens[(i*4)+j] + seq_len = random_seq_lens[(i * 4) + j] self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Test with random sequence lengths, and constant padding. - sess.run(init_op, feed_dict={padded_shape: [25], - seq_lens: random_seq_lens}) + sess.run( + init_op, feed_dict={ + padded_shape: [25], + seq_lens: random_seq_lens + }) for i in range(8): result = sess.run(get_next) self.assertEqual((4, 25), result.shape) for j in range(4): - seq_len = random_seq_lens[(i*4)+j] + seq_len = random_seq_lens[(i * 4) + j] self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Test correct handling of empty tensors. 
- sess.run(init_op, feed_dict={padded_shape: [-1], - seq_lens: [0, 0, 0, 0]}) + sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]}) result = sess.run(get_next) self.assertAllEqual([[], [], [], []], result) with self.assertRaises(errors.OutOfRangeError): @@ -154,8 +213,7 @@ class BatchDatasetTest(test.TestCase): # Test error handling with constant sequence lengths, and # too-short padding. - sess.run(init_op, feed_dict={padded_shape: [5], - seq_lens: [6, 5, 5, 5]}) + sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]}) with self.assertRaises(errors.DataLossError): result = sess.run(get_next) @@ -166,11 +224,13 @@ class BatchDatasetTest(test.TestCase): def fill_tuple(x): filled = array_ops.fill([x], x) return (filled, string_ops.as_string(filled)) - iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple) - .padded_batch( - 4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, "")).make_initializable_iterator()) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple) + .padded_batch( + 4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, "")).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -178,15 +238,18 @@ class BatchDatasetTest(test.TestCase): with self.test_session() as sess: # Test with random sequence lengths, and max padding. 
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) - sess.run(init_op, feed_dict={padded_shape: [-1], - seq_lens: random_seq_lens}) + sess.run( + init_op, feed_dict={ + padded_shape: [-1], + seq_lens: random_seq_lens + }) for i in range(8): result = sess.run(get_next) padded_len = np.max(result[0]) self.assertEqual((4, padded_len), result[0].shape) self.assertEqual((4, padded_len), result[1].shape) for j in range(4): - seq_len = random_seq_lens[(i*4)+j] + seq_len = random_seq_lens[(i * 4) + j] self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[0][j, seq_len:], [-1] * (padded_len - seq_len)) @@ -220,20 +283,30 @@ class BatchDatasetTest(test.TestCase): constant_op.constant([-1, -1], dtype=dtypes.int64), constant_op.constant([37], dtype=dtypes.int64))) - for dataset in [dynamic_padding_from_tensor_shapes, - dynamic_padding_from_lists, - dynamic_padding_from_lists_with_minus_one, - dynamic_padding_from_tensors]: + for dataset in [ + dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists, + dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors + ]: self.assertEqual([None, None], dataset.output_shapes[0].as_list()) self.assertEqual([None, None, None], dataset.output_shapes[1].as_list()) self.assertEqual([None, 37], dataset.output_shapes[2].as_list()) + def testPaddedBatchSparseError(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + with self.assertRaises(TypeError): + _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10) + def testDenseToSparseBatchDataset(self): components = np.random.randint(12, size=(100,)).astype(np.int32) - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.fill([x], x)).apply( - batching.dense_to_sparse_batch(4, [12])) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda 
x: array_ops.fill([x], x)).apply( + batching.dense_to_sparse_batch(4, + [12])).make_initializable_iterator()) init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) @@ -242,24 +315,26 @@ class BatchDatasetTest(test.TestCase): for start in range(0, len(components), 4): results = sess.run(get_next) + self.assertAllEqual([[i, j] + for i, c in enumerate(components[start:start + 4]) + for j in range(c)], results.indices) self.assertAllEqual( - [[i, j] for i, c in enumerate(components[start:start+4]) - for j in range(c)], results.indices) - self.assertAllEqual( - [c for c in components[start:start+4] for _ in range(c)], + [c for c in components[start:start + 4] for _ in range(c)], results.values) - self.assertAllEqual( - [min(4, len(components) - start), 12], results.dense_shape) + self.assertAllEqual([min(4, + len(components) - start), 12], + results.dense_shape) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testDenseToSparseBatchDatasetWithUnknownShape(self): components = np.random.randint(5, size=(40,)).astype(np.int32) - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.fill([x, x], x)).apply( - batching.dense_to_sparse_batch( - 4, [5, -1])).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: array_ops.fill([x, x], x)).apply( + batching.dense_to_sparse_batch( + 4, [5, -1])).make_initializable_iterator()) init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) @@ -268,27 +343,30 @@ class BatchDatasetTest(test.TestCase): for start in range(0, len(components), 4): results = sess.run(get_next) - self.assertAllEqual( - [[i, j, z] for i, c in enumerate(components[start:start+4]) - for j in range(c) for z in range(c)], results.indices) - self.assertAllEqual( - [c for c in components[start:start+4] - for _ in range(c) for _ in range(c)], - results.values) - self.assertAllEqual( 
- [min(4, len(components) - start), - 5, - np.max(components[start:start+4])], - results.dense_shape) + self.assertAllEqual([[i, j, z] + for i, c in enumerate(components[start:start + 4]) + for j in range(c) + for z in range(c)], results.indices) + self.assertAllEqual([ + c + for c in components[start:start + 4] for _ in range(c) + for _ in range(c) + ], results.values) + self.assertAllEqual([ + min(4, + len(components) - start), 5, + np.max(components[start:start + 4]) + ], results.dense_shape) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testDenseToSparseBatchDatasetWithInvalidShape(self): input_tensor = array_ops.constant([[1]]) - iterator = (dataset_ops.Dataset.from_tensors(input_tensor) - .apply(batching.dense_to_sparse_batch(4, [-2])) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensors(input_tensor).apply( + batching.dense_to_sparse_batch(4, [-2])) + .make_initializable_iterator()) init_op = iterator.initializer with self.test_session() as sess: @@ -298,8 +376,10 @@ class BatchDatasetTest(test.TestCase): def testDenseToSparseBatchDatasetShapeErrors(self): input_tensor = array_ops.placeholder(dtypes.int32) - iterator = (dataset_ops.Dataset.from_tensors(input_tensor).apply( - batching.dense_to_sparse_batch(4, [12])).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensors(input_tensor).apply( + batching.dense_to_sparse_batch(4, + [12])).make_initializable_iterator()) init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) @@ -356,8 +436,7 @@ class BatchDatasetTest(test.TestCase): def testUnbatchMultiElementTupleDataset(self): data = tuple([(math_ops.range(10 * i, 10 * i + 10), - array_ops.fill([10], "hi")) - for i in range(3)]) + array_ops.fill([10], "hi")) for i in range(3)]) data = dataset_ops.Dataset.from_tensor_slices(data) expected_types = ((dtypes.int32, dtypes.string),) * 3 data = data.batch(2) @@ -370,9 +449,7 @@ class 
BatchDatasetTest(test.TestCase): with self.test_session() as sess: for i in range(10): - self.assertEqual(((i, b"hi"), - (10 + i, b"hi"), - (20 + i, b"hi")), + self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")), sess.run(op)) with self.assertRaises(errors.OutOfRangeError): @@ -385,9 +462,10 @@ class BatchDatasetTest(test.TestCase): batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = (dataset_ops.Dataset.from_tensor_slices(components).apply( - batching.batch_and_drop_remainder(batch_size)) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).apply( + batching.batch_and_drop_remainder(batch_size)) + .make_initializable_iterator()) next_element = iterator.get_next() @@ -404,14 +482,85 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testBatchAndDropRemainderSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(12).map(_sparse).apply( + batching.batch_and_drop_remainder(5)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(2): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testPaddedBatchAndDropRemainder(self): + els = [] + for length in [3, 6, 9, 4, 12, 10, 2]: + els.append((np.array(length), np.arange(length) + 1, + np.array(length * 2))) + + dataset = dataset_ops.Dataset.from_tensors(els[0]) + for el in els[1:]: + dataset = 
dataset.concatenate(dataset_ops.Dataset.from_tensors(el)) + + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = ( + dataset.apply( + batching.padded_batch_and_drop_remainder( + batch_size, ([], [None], []))).make_initializable_iterator()) + + next_element = iterator.get_next() + + with self.test_session() as sess: + for test_batch_size in [1, 3, 7, 10]: + sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) + num_batches = 7 // test_batch_size + for i in range(num_batches): + result = sess.run(next_element) + for component_idx, result_component in enumerate(result): + for j in range(test_batch_size): + data_idx = i * test_batch_size + j + comp = result_component[j] + unpadded = comp[comp > 0] + if np.isscalar(comp): + # The boolean mask indexing above adds a dim back. Rm it. + unpadded = unpadded[0] + self.assertAllEqual(els[data_idx][component_idx], unpadded) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPaddedBatchAndDropRemainderSparseError(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + with self.assertRaises(TypeError): + _ = dataset_ops.Dataset.range(10).map(_map_fn).apply( + batching.padded_batch_and_drop_remainder(5)) + def testBatchAndDropRemainderShapeInference(self): - components = (array_ops.placeholder(dtypes.int32), (array_ops.placeholder( - dtypes.int32, shape=[None]), array_ops.placeholder( - dtypes.int32, shape=[20, 30]))) + components = (array_ops.placeholder(dtypes.int32), + (array_ops.placeholder(dtypes.int32, shape=[None]), + array_ops.placeholder(dtypes.int32, shape=[20, 30]))) # Test with a statically known batch size. 
- dataset = (dataset_ops.Dataset.from_tensor_slices(components).apply( - batching.batch_and_drop_remainder(128))) + dataset = ( + dataset_ops.Dataset.from_tensor_slices(components).apply( + batching.batch_and_drop_remainder(128))) self.assertIs(None, dataset.output_shapes[0].ndims) self.assertEqual([128], dataset.output_shapes[1][0].as_list()) @@ -420,14 +569,15 @@ class BatchDatasetTest(test.TestCase): # Test with a dynamic batch size: the static shape will be unknown, because # `batch_size` is a placeholder. batch_size = array_ops.placeholder(dtypes.int64) - dataset = (dataset_ops.Dataset.from_tensor_slices(components).apply( - batching.batch_and_drop_remainder(batch_size))) + dataset = ( + dataset_ops.Dataset.from_tensor_slices(components).apply( + batching.batch_and_drop_remainder(batch_size))) self.assertIs(None, dataset.output_shapes[0].ndims) self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def testBatchAndMapDataset(self): + def _testBatchAndMapDatasetHelper(self, num_parallel_batches=1): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> BatchAndMapDataset(square_3, batch_size). 
@@ -441,9 +591,13 @@ class BatchDatasetTest(test.TestCase): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = (dataset_ops.Dataset.from_tensor_slices(components).repeat(count) - .apply(batching.map_and_batch(_map_fn, batch_size)) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_batches=num_parallel_batches)) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -459,7 +613,7 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(14): - self.assertAllEqual(component[(i*14 + j) % 7]**2, + self.assertAllEqual(component[(i * 14 + j) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -474,9 +628,13 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(8): - self.assertAllEqual(component[(i*8 + j) % 7]**2, + self.assertAllEqual(component[(i * 8 + j) % 7]**2, result_component[j]) - # The last batch should fail with `OutOfRange`. 
+ result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range((14 * 7) % 8): + self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2, + result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -489,14 +647,45 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + def testBatchAndMapDataset(self): + return self._testBatchAndMapDatasetHelper() + + def testBatchAndMapDatasetWithParallelBatching(self): + return self._testBatchAndMapDatasetHelper(num_parallel_batches=10) + + def testMapAndBatchSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).apply( + batching.map_and_batch(_sparse, 5)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(2): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testBatchAndMapDatasetFails(self): """Test a dataset that maps a TF function across its input elements.""" dataset = dataset_ops.Dataset.from_tensors( array_ops.check_numerics( constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = (dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) - .make_initializable_iterator()) + iterator = ( + dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) + .make_initializable_iterator()) init_op 
= iterator.initializer with self.test_session() as sess: with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): @@ -504,6 +693,7 @@ class BatchDatasetTest(test.TestCase): def testBatchAndMapDatasetShapeMismatch(self): """Test a dataset that maps a TF function across its input elements.""" + def generator(): yield [1] yield [2] @@ -546,5 +736,41 @@ class BatchDatasetSerializationTest( num_outputs) +class PaddedBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testPaddedBatch(self): + + def build_dataset(seq_lens): + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + lambda x: array_ops.fill([x], x)).padded_batch( + 4, padded_shapes=[-1]) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + def testPaddedBatchNonDefaultPadding(self): + + def build_dataset(seq_lens): + + def fill_tuple(x): + filled = array_ops.fill([x], x) + return (filled, string_ops.as_string(filled)) + + padded_shape = [-1] + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + fill_tuple).padded_batch( + 4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, "")) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index c3d6bfc097798530008f186cce68906b6af8fe47..55a1d3b95b212466b262ad3c26f1efd7ed0e067e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ 
b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -17,14 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import threading import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.contrib.data.python.ops import iterator_ops from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.util import nest @@ -32,16 +31,16 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib class DatasetConstructorTest(test.TestCase): - def testTensorDataset(self): + def testFromTensors(self): """Test an dataset that represents a single tuple of tensors.""" components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) @@ -61,7 +60,75 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testTensorSliceDataset(self): + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testFromTensorsSparse(self): + """Test a dataset that represents a single tuple of tensors.""" + components = (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), +
sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1]]), + values=np.array([-1, 1]), + dense_shape=np.array([2, 2]))) + + iterator = ( + dataset_ops.Dataset.from_tensors(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual( + [tensor_shape.TensorShape(c.dense_shape) for c in components], + [shape for shape in iterator.output_shapes]) + + with self.test_session() as sess: + sess.run(init_op) + results = sess.run(get_next) + for component, result_component in zip(components, results): + self.assertSparseValuesEqual(component, result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromTensorsMixed(self): + """Test a dataset that represents a single tuple of tensors.""" + components = (np.array(1), np.array([1, 2, 3]), np.array(37.0), + sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1]]), + values=np.array([-1, 1]), + dense_shape=np.array([2, 2]))) + + iterator = ( + dataset_ops.Dataset.from_tensors(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([ + tensor_shape.TensorShape(c.dense_shape) + if sparse_tensor.is_sparse(c) else c.shape for c in components + ], [shape for shape in iterator.output_shapes]) + + with self.test_session() as sess: + sess.run(init_op) + results = sess.run(get_next) + for component, result_component in zip(components, results): + if sparse_tensor.is_sparse(component): + self.assertSparseValuesEqual(component, result_component) + else: + self.assertAllEqual(component, result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromTensorSlices(self): """Test an dataset that represents the slices from a tuple of tensors.""" components = (
np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( @@ -86,7 +153,127 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testTensorSliceDatasetWithDict(self): + def testFromTensorSlicesSparse(self): + """Test a dataset that represents the slices from a tuple of tensors.""" + components = (sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 0], [2, 0]]), + values=np.array([0, 0, 0]), + dense_shape=np.array([3, 1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1], [2, 2]]), + values=np.array([1, 2, 3]), + dense_shape=np.array([3, 3]))) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual( + [tensor_shape.TensorShape(c.dense_shape[1:]) for c in components], + [shape for shape in iterator.output_shapes]) + + with self.test_session() as sess: + sess.run(init_op) + expected = [ + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([1]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[1]]), + values=np.array([2]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[2]]), + values=np.array([3]), + dense_shape=np.array([3]))), + ] + for i in range(3): + results = sess.run(get_next) + for component, result_component in zip(expected[i], results): + self.assertSparseValuesEqual(component, result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def
testFromTensorSlicesMixed(self): + """Test a dataset that represents the slices from a tuple of tensors.""" + components = (np.tile(np.array([[1], [2], [3]]), 20), + np.tile(np.array([[12], [13], [14]]), 22), + np.array([37.0, 38.0, 39.0]), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 0], [2, 0]]), + values=np.array([0, 0, 0]), + dense_shape=np.array([3, 1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1], [2, 2]]), + values=np.array([1, 2, 3]), + dense_shape=np.array([3, 3]))) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([ + tensor_shape.TensorShape(c.dense_shape[1:]) + if sparse_tensor.is_sparse(c) else c.shape[1:] for c in components + ], [shape for shape in iterator.output_shapes]) + + with self.test_session() as sess: + sess.run(init_op) + expected = [ + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([1]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[1]]), + values=np.array([2]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[2]]), + values=np.array([3]), + dense_shape=np.array([3]))), + ] + for i in range(3): + results = sess.run(get_next) + for component, result_component in zip( + (zip(*components[:3])[i] + expected[i]), results): + if sparse_tensor.is_sparse(component): + self.assertSparseValuesEqual(component, result_component) + else: + self.assertAllEqual(component, result_component) + with
self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromTensorSlicesWithDict(self): components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} iterator = (dataset_ops.Dataset.from_tensor_slices(components) .make_initializable_iterator()) @@ -107,7 +294,7 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testSparseTensorSliceDataset(self): + def testFromSparseTensorSlices(self): """Test a dataset based on slices of a `tf.SparseTensor`.""" st = array_ops.sparse_placeholder(dtypes.float64) iterator = (dataset_ops.Dataset.from_sparse_tensor_slices(st) @@ -574,135 +761,63 @@ class DatasetConstructorTest(test.TestCase): new = batching._RestructuredDataset(dataset, new_types, new_shape_lists) # pylint: enable=protected-access - def _iterator_checkpoint_prefix(self): - return os.path.join(self.get_temp_dir(), "iterator") - def _testSaveRestoreFromTensorsUtility(self, start, break_range, stop): - path = self._iterator_checkpoint_prefix() - step = 0 - meta_filename = path + "-%d.meta" % step +class DatasetConstructorSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) + def _build_tensor_dataset(self, variable_array): + components = (variable_array, np.array([1, 2, 3]), np.array(37.0)) - with ops.Graph().as_default() as g: - iterator = ( - dataset_ops.Dataset.from_tensors(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - saveable = iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - for t in nest.flatten(get_next): - ops.add_to_collection("get_next", t) - saver = saver_lib.Saver() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(start, break_range): - result = sess.run(get_next) - for component, result_component in 
zip(components, result): - self.assertAllEqual(component, result_component) - saver.save(sess, path, step) - - with ops.Graph().as_default() as g: - saver = saver_lib.import_meta_graph(meta_filename) - with self.test_session(graph=g) as sess: - get_next = nest.pack_sequence_as(("a", "b", "c"), - ops.get_collection("get_next")) - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for _ in range(break_range, stop): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + return dataset_ops.Dataset.from_tensors(components) - def testRestoreFromTensors(self): - self._testSaveRestoreFromTensorsUtility(0, 0, 1) + def testFromTensorsCore(self): + # Equal length components + arr = np.array(1) + num_outputs = 1 + diff_arr = np.array(2) + self.run_core_tests(lambda: self._build_tensor_dataset(arr), + lambda: self._build_tensor_dataset(diff_arr), + num_outputs) - def testRestoreExhuatedIteratorFromTensors(self): - self._testSaveRestoreFromTensorsUtility(0, 1, 1) + def _build_tensor_slices_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components) - def _build_graph_tensor_slices(self, components): - iterator = dataset_ops.Dataset.from_tensor_slices( - components).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - saveable = iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - for t in nest.flatten(get_next): - ops.add_to_collection("get_next", t) - return init_op, get_next - - def _testSaveRestoreFromTensorSlicesUtility(self, start, break_range, stop): - path = self._iterator_checkpoint_prefix() - step = 0 - meta_filename = path + "-%d.meta" % step - - components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( - np.array([[12], [13], [14], [15]]), 22), 
+ def testFromTensorSlicesCore(self): + # Equal length components + components = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 22), np.array([37.0, 38.0, 39.0, 40.0])) - with ops.Graph().as_default() as g: - init_op, get_next = self._build_graph_tensor_slices(components) - saver = saver_lib.Saver() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for i in range(start, break_range): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i], result_component) - saver.save(sess, path, step) - - with ops.Graph().as_default() as g: - saver = saver_lib.import_meta_graph(meta_filename) - with self.test_session(graph=g) as sess: - get_next = nest.pack_sequence_as(("a", "b", "c"), - ops.get_collection("get_next")) - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(break_range, stop): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRestoreFromTensorSlices(self): - self._testSaveRestoreFromTensorSlicesUtility(0, 4, 2) - - def testRestoreExhaustedIteratorFromTensorSlices(self): - self._testSaveRestoreFromTensorSlicesUtility(0, 4, 4) - - def tesRestoreFromTensorSlicesWithDict(self): - - path = self._iterator_checkpoint_prefix() - step = 0 - meta_filename = path + "-%d.meta" % step - - components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} - - with ops.Graph().as_default() as g: - init_op, get_next = self._build_graph_tensor_slices(components) - saver = saver_lib.Saver() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for i in range(2): - results = sess.run(get_next) - self.assertEqual(components["foo"][i], results["foo"]) - self.assertEqual(components["bar"][i], results["bar"]) - saver.save(sess, path, step) - 
- with ops.Graph().as_default() as g: - saver = saver_lib.import_meta_graph(meta_filename) - with self.test_session(graph=g) as sess: - get_next = nest.pack_sequence_as(("a", "b"), - ops.get_collection("get_next")) - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(2, 3): - results = sess.run(get_next) - self.assertEqual(components["foo"][i], results["foo"]) - self.assertEqual(components["bar"][i], results["bar"]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[5], [6], [7], [8]]), 22), + np.array([1.0, 2.0, 3.0, 4.0])) + + dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} + + self.run_core_tests(lambda: self._build_tensor_slices_dataset(components), + lambda: self._build_tensor_slices_dataset(diff_comp), 4) + self.run_core_tests( + lambda: self._build_tensor_slices_dataset(dict_components), None, 3) + + def _build_sparse_tensor_slice_dataset(self, slices): + indices = np.array( + [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))], + dtype=np.int64) + values = np.array([val for s in slices for val in s], dtype=np.float64) + dense_shape = np.array( + [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64) + sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape) + return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components) + + def testFromSparseTensorSlicesCore(self): + slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] + diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []] + + self.run_core_tests( + lambda: self._build_sparse_tensor_slice_dataset(slices), + lambda: self._build_sparse_tensor_slice_dataset(diff_slices), + 9, + sparse_tensors=True) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py 
b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index df9147af6c03925ac9f372c561000eaa6e7f328e..bf25cc60a1c0efc09bed6501fd2d6f4ccb07764b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -23,8 +23,11 @@ import os import numpy as np from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.training import saver as saver_lib @@ -32,12 +35,12 @@ from tensorflow.python.util import nest class DatasetSerializationTestBase(test.TestCase): - """Base class for testing finite serializable datasets.""" + """Base class for testing serializable datasets.""" def tearDown(self): self._delete_ckpt() - def run_core_tests(self, ds_fn1, ds_fn2, num_outputs): + def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False): """Runs the core tests. Args: @@ -45,32 +48,53 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn2: 0-argument function that returns a Dataset different from ds_fn1. If None, verify_restore_in_modified_graph test is not run. num_outputs: Total number of outputs expected from this Dataset. + sparse_tensors: Whether dataset is built from SparseTensor(s). Raises: AssertionError if any test fails. 
""" - self.verify_unused_iterator(ds_fn1, num_outputs) - self.verify_fully_used_iterator(ds_fn1, num_outputs) - self.verify_exhausted_iterator(ds_fn1, num_outputs) - self.verify_init_before_restore(ds_fn1, num_outputs) - self.verify_multiple_breaks(ds_fn1, num_outputs) - self.verify_reset_restored_iterator(ds_fn1, num_outputs) + self.verify_unused_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_fully_used_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_exhausted_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_init_before_restore( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_multiple_breaks( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_reset_restored_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_restore_in_empty_graph( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) if ds_fn2: - self.verify_restore_in_modified_graph(ds_fn1, ds_fn2, num_outputs) + self.verify_restore_in_modified_graph( + ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors) - def verify_unused_iterator(self, ds_fn, num_outputs): + def verify_unused_iterator(self, + ds_fn, + num_outputs, + sparse_tensors=False, + verify_exhausted=True): """Verifies that saving and restoring an unused iterator works. Args: ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. """ - self.verify_run_with_breaks(ds_fn, [0], num_outputs) + self.verify_run_with_breaks( + ds_fn, [0], + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) - def verify_fully_used_iterator(self, ds_fn, num_outputs): + def verify_fully_used_iterator(self, ds_fn, num_outputs, + sparse_tensors=False): """Verifies that saving and restoring a fully used iterator works. 
Note that this only checks saving and restoring an iterator from which @@ -81,13 +105,15 @@ class DatasetSerializationTestBase(test.TestCase): Args: ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. Raises: AssertionError if test fails. """ - self.verify_run_with_breaks(ds_fn, [num_outputs], num_outputs) + self.verify_run_with_breaks( + ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors) - def verify_exhausted_iterator(self, ds_fn, num_outputs): + def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False): """Verifies that saving and restoring an exhausted iterator works. An exhausted iterator is one which has returned an OutOfRange error. @@ -95,21 +121,36 @@ class DatasetSerializationTestBase(test.TestCase): Args: ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. Raises: AssertionError if any test fails. """ - self.gen_outputs(ds_fn, [], num_outputs, verify_exhausted=True) + self.gen_outputs( + ds_fn, [], + num_outputs, + verify_exhausted=True, + sparse_tensors=sparse_tensors) actual = self.gen_outputs( - ds_fn, [], 0, ckpt_saved=True, verify_exhausted=True) + ds_fn, [], + 0, + ckpt_saved=True, + verify_exhausted=True, + sparse_tensors=sparse_tensors) self.assertEqual(len(actual), 0) - def verify_init_before_restore(self, ds_fn, num_outputs): - """Verifies that retoring into an already initilized iterator works. + def verify_init_before_restore(self, + ds_fn, + num_outputs, + sparse_tensors=False, + verify_exhausted=True): + """Verifies that restoring into an already initialized iterator works. Args: ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails.
@@ -118,9 +159,16 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn, self.gen_break_points(num_outputs), num_outputs, - init_before_restore=True) + init_before_restore=True, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) - def verify_multiple_breaks(self, ds_fn, num_outputs, num_breaks=10): + def verify_multiple_breaks(self, + ds_fn, + num_outputs, + num_breaks=10, + sparse_tensors=False, + verify_exhausted=True): """Attempts to save/restore at multiple break points. Args: @@ -128,16 +176,25 @@ class DatasetSerializationTestBase(test.TestCase): num_outputs: See `run_core_tests`. num_breaks: The number of break points. These are uniformly spread in [0, num_outputs] both inclusive. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. """ - self.verify_run_with_breaks(ds_fn, - self.gen_break_points(num_outputs, num_breaks), - num_outputs) - - def verify_reset_restored_iterator(self, ds_fn, num_outputs, - break_point=None): + self.verify_run_with_breaks( + ds_fn, + self.gen_break_points(num_outputs, num_breaks), + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + def verify_reset_restored_iterator(self, + ds_fn, + num_outputs, + break_point=None, + sparse_tensors=False, + verify_exhausted=True): """Attempts to re-initialize a restored iterator. This is useful when restoring a training checkpoint during validation. @@ -146,6 +203,8 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. @@ -153,30 +212,43 @@ class DatasetSerializationTestBase(test.TestCase): break_point = num_outputs // 2 if not break_point else break_point # Collect ground truth containing all outputs. 
- expected = self.gen_outputs(ds_fn, [], num_outputs, verify_exhausted=True) + expected = self.gen_outputs( + ds_fn, [], + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) # Skip some items and save checkpoint. - self.gen_outputs(ds_fn, [], break_point, verify_exhausted=False) + self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) actual = [] # Restore from checkpoint and then run init_op. with ops.Graph().as_default() as g: saver = self._import_meta_graph() - init_op, get_next_op = self._get_iterator_ops_from_collection(ds_fn) + init_op, get_next_op = self._get_iterator_ops_from_collection( + ds_fn, sparse_tensors=sparse_tensors) with self.test_session(graph=g) as sess: self._restore(saver, sess) + sess.run(variables.global_variables_initializer()) sess.run(init_op) for _ in range(num_outputs): actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) self.match(expected, actual) def verify_restore_in_modified_graph(self, ds_fn1, ds_fn2, num_outputs, - break_point=None): + break_point=None, + sparse_tensors=False, + verify_exhausted=True): """Attempts to restore an iterator in a modified graph. Builds an input pipeline using ds_fn1, runs it for `break_point` steps @@ -188,6 +260,8 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn2: See `run_core_tests`. num_outputs: See `run_core_tests`. break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. @@ -196,34 +270,138 @@ class DatasetSerializationTestBase(test.TestCase): # Skip `break_point` items and store the remaining produced from ds_fn1 # in `expected`. 
- self.gen_outputs(ds_fn1, [], break_point) + self.gen_outputs( + ds_fn1, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) expected = self.gen_outputs( ds_fn1, [], num_outputs - break_point, ckpt_saved=True, - verify_exhausted=True) + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) # Generate `break_point` items from ds_fn1 and save checkpoint. - self.gen_outputs(ds_fn1, [], break_point) + self.gen_outputs( + ds_fn1, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) actual = [] # Build graph for ds_fn2 but load checkpoint for ds_fn1. with ops.Graph().as_default() as g: - _, get_next_op, saver = self._build_graph(ds_fn2) + _, get_next_op, saver = self._build_graph( + ds_fn2, sparse_tensors=sparse_tensors) with self.test_session(graph=g) as sess: self._restore(saver, sess) for _ in range(num_outputs - break_point): actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + self.match(expected, actual) + + def verify_restore_in_empty_graph(self, + ds_fn, + num_outputs, + break_point=None, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to restore an iterator in an empty graph. + + Builds an input pipeline using ds_fn, runs it for `break_point` steps + and saves a checkpoint. Then builds a new empty graph, restores + the checkpoint from ds_fn and verifies that the restore is successful. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + break_point = num_outputs // 2 if not break_point else break_point + + # Skip `break_point` items and store the remaining produced from ds_fn + # in `expected`. 
+ self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + expected = self.gen_outputs( + ds_fn, [], + num_outputs - break_point, + ckpt_saved=True, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + # Generate `break_point` items from ds_fn and save checkpoint. + self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + + actual = [] + # Build an empty graph but load checkpoint for ds_fn. + with ops.Graph().as_default() as g: + get_next_op, saver = self._build_empty_graph( + ds_fn, sparse_tensors=sparse_tensors) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) self.match(expected, actual) + def verify_error_on_save(self, + ds_fn, + num_outputs, + error, + break_point=None, + sparse_tensors=False): + """Attempts to save a non-saveable iterator. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + error: Declared error when trying to save iterator. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + + Raises: + AssertionError if any test fails. 
+ """ + + break_point = num_outputs // 2 if not break_point else break_point + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph( + ds_fn, sparse_tensors=sparse_tensors) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + with self.assertRaises(error): + self._save(sess, saver) + def verify_run_with_breaks(self, ds_fn, break_points, num_outputs, - init_before_restore=False): + init_before_restore=False, + sparse_tensors=False, + verify_exhausted=True): """Verifies that ds_fn() produces the same outputs with and without breaks. 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it @@ -238,6 +416,8 @@ class DatasetSerializationTestBase(test.TestCase): break_points: See `gen_outputs`. num_outputs: See `gen_outputs`. init_before_restore: See `gen_outputs`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. @@ -245,14 +425,18 @@ class DatasetSerializationTestBase(test.TestCase): expected = self.gen_outputs( ds_fn, [], num_outputs, - verify_exhausted=True, - init_before_restore=init_before_restore) + init_before_restore=init_before_restore, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + actual = self.gen_outputs( ds_fn, break_points, num_outputs, - verify_exhausted=True, - init_before_restore=init_before_restore) + init_before_restore=init_before_restore, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + self.match(expected, actual) def gen_outputs(self, @@ -261,7 +445,8 @@ class DatasetSerializationTestBase(test.TestCase): num_outputs, ckpt_saved=False, init_before_restore=False, - verify_exhausted=False): + sparse_tensors=False, + verify_exhausted=True): """Generates elements from input dataset while stopping at break points. 
Produces `num_outputs` outputs and saves the state of the iterator in the @@ -281,20 +466,23 @@ class DatasetSerializationTestBase(test.TestCase): init_before_restore: Whether init should be called before saver.restore. This is just so that we can verify that restoring an already initialized iterator works. + sparse_tensors: Whether dataset is built from SparseTensor(s). verify_exhausted: Whether to verify that the iterator has been exhausted after producing `num_outputs` elements. Returns: - A list if `num_outputs` items. + A list of `num_outputs` items. """ outputs = [] def get_ops(): if ckpt_saved: saver = self._import_meta_graph() - init_op, get_next_op = self._get_iterator_ops_from_collection(ds_fn) + init_op, get_next_op = self._get_iterator_ops_from_collection( + ds_fn, sparse_tensors=sparse_tensors) else: - init_op, get_next_op, saver = self._build_graph(ds_fn) + init_op, get_next_op, saver = self._build_graph( + ds_fn, sparse_tensors=sparse_tensors) return init_op, get_next_op, saver for i in range(len(break_points) + 1): @@ -303,20 +491,22 @@ class DatasetSerializationTestBase(test.TestCase): with self.test_session(graph=g) as sess: if ckpt_saved: if init_before_restore: + sess.run(variables.global_variables_initializer()) sess.run(init_op) self._restore(saver, sess) else: + sess.run(variables.global_variables_initializer()) sess.run(init_op) start = break_points[i - 1] if i > 0 else 0 end = break_points[i] if i < len(break_points) else num_outputs num_iters = end - start for _ in range(num_iters): outputs.append(sess.run(get_next_op)) - self._save(sess, saver) - ckpt_saved = True if i == len(break_points) and verify_exhausted: with self.assertRaises(errors.OutOfRangeError): sess.run(get_next_op) + self._save(sess, saver) + ckpt_saved = True return outputs @@ -343,7 +533,7 @@ class DatasetSerializationTestBase(test.TestCase): if nest.is_sequence(expected): self.assertEqual(len(expected), len(actual)) if isinstance(expected, dict): - for key1, key2 in 
sorted(expected, actual): + for key1, key2 in zip(sorted(expected), sorted(actual)): self.assertEqual(key1, key2) self.match(expected[key1], actual[key2]) else: @@ -360,34 +550,65 @@ class DatasetSerializationTestBase(test.TestCase): """Generates `num_samples` breaks points in [0, num_outputs].""" return np.linspace(0, num_outputs, num_samples, dtype=int) - def _build_graph(self, ds_fn): + def _build_graph(self, ds_fn, sparse_tensors=False): iterator = ds_fn().make_initializable_iterator() saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) init_op = iterator.initializer - get_next = iterator.get_next() - self._add_iterator_ops_to_collection(init_op, get_next) + if sparse_tensors: + get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + else: + get_next = iterator.get_next() + self._add_iterator_ops_to_collection(init_op, get_next, sparse_tensors) saver = saver_lib.Saver(allow_empty=True) return init_op, get_next, saver - def _add_iterator_ops_to_collection(self, init_op, get_next): + def _build_empty_graph(self, ds_fn, sparse_tensors=False): + iterator = iterator_ops.Iterator.from_structure( + self._get_output_types(ds_fn), self._get_output_shapes(ds_fn)) + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + if sparse_tensors: + get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + else: + get_next = iterator.get_next() + saver = saver_lib.Saver(allow_empty=True) + return get_next, saver + + def _add_iterator_ops_to_collection(self, + init_op, + get_next, + sparse_tensors=False): ops.add_to_collection("iterator_ops", init_op) # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections # do not support tuples we flatten the tensors and restore the shape in # `_get_iterator_ops_from_collection`. 
- for el in nest.flatten(get_next): - ops.add_to_collection("iterator_ops", el) + if sparse_tensors: + ops.add_to_collection("iterator_ops", get_next.indices) + ops.add_to_collection("iterator_ops", get_next.values) + ops.add_to_collection("iterator_ops", get_next.dense_shape) + else: + for el in nest.flatten(get_next): + ops.add_to_collection("iterator_ops", el) - def _get_iterator_ops_from_collection(self, ds_fn): + def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False): all_ops = ops.get_collection("iterator_ops") - return all_ops[0], nest.pack_sequence_as( - self._get_output_types(ds_fn), all_ops[1:]) + if sparse_tensors: + init_op, indices, values, dense_shape = all_ops + return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape) + else: + return all_ops[0], nest.pack_sequence_as( + self._get_output_types(ds_fn), all_ops[1:]) def _get_output_types(self, ds_fn): with ops.Graph().as_default(): return ds_fn().output_types + def _get_output_shapes(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_shapes + def _ckpt_path(self): return os.path.join(self.get_temp_dir(), "iterator") diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py index 00323da3110bb7f32b589f72e4e867f9c71e92ee..5921be2ae89ba1bbbb8d6e3a509cf49c65949544 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py @@ -19,9 +19,11 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops 
import functional_ops from tensorflow.python.ops import math_ops @@ -124,6 +126,74 @@ class FilterDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testSparse(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])), i + + def _filter_fn(_, i): + return math_ops.equal(i % 2, 0) + + iterator = ( + dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( + lambda x, i: x).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(5): + actual = sess.run(get_next) + self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertSparseValuesEqual(actual, _map_fn(i * 2)[0]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +class FilterDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_filter_range_graph(self, div): + return dataset_ops.Dataset.range(100).filter( + lambda x: math_ops.not_equal(math_ops.mod(x, div), 2)) + + def testFilterCore(self): + div = 3 + num_outputs = np.sum([x % 3 is not 2 for x in range(100)]) + self.run_core_tests(lambda: self._build_filter_range_graph(div), + lambda: self._build_filter_range_graph(div * 2), + num_outputs) + + def _build_filter_dict_graph(self): + return dataset_ops.Dataset.range(10).map( + lambda x: {"foo": x * 2, "bar": x ** 2}).filter( + lambda d: math_ops.equal(d["bar"] % 2, 0)).map( + lambda d: d["foo"] + d["bar"]) + + def testFilterDictCore(self): + num_outputs = np.sum([(x**2) % 2 == 0 for x in range(10)]) + self.run_core_tests(self._build_filter_dict_graph, None, 
num_outputs) + + def _build_sparse_filter(self): + + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + def _filter_fn(_, i): + return math_ops.equal(i % 2, 0) + + return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( + lambda x, i: x) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py index 2a582ae6620ac8276d290c7b995588640e36929c..d4fbaa5cdcdd315aa0524134b48eb0515169722c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py @@ -17,16 +17,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools import random import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.client import session +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops +from tensorflow.python.framework import function +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -123,154 +129,101 @@ class FlatMapDatasetTest(test.TestCase): sess.run(get_next) # pylint: enable=g-long-lambda + def testSparse(self): + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) 
-class InterleaveDatasetTest(test.TestCase): - - def _interleave(self, lists, cycle_length, block_length): - num_open = 0 - - # `all_iterators` acts as a queue of iterators over each element of `lists`. - all_iterators = [iter(l) for l in lists] - - # `open_iterators` are the iterators whose elements are currently being - # interleaved. - open_iterators = [] - for i in range(cycle_length): - if all_iterators: - open_iterators.append(all_iterators.pop(0)) - num_open += 1 - else: - open_iterators.append(None) - - while num_open or all_iterators: - for i in range(cycle_length): - if open_iterators[i] is None: - if all_iterators: - open_iterators[i] = all_iterators.pop(0) - num_open += 1 - else: - continue - for _ in range(block_length): - try: - yield next(open_iterators[i]) - except StopIteration: - open_iterators[i] = None - num_open -= 1 - break - - def testPythonImplementation(self): - input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], - [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] - - # Cycle length 1 acts like `Dataset.flat_map()`. - expected_elements = itertools.chain(*input_lists) - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 1, 1)): - self.assertEqual(expected, produced) - - # Cycle length > 1. - expected_elements = [4, 5, 4, 5, 4, 5, 4, - 5, 5, 6, 6, # NOTE(mrry): When we cycle back - # to a list and are already at - # the end of that list, we move - # on to the next element. - 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 2, 1)): - self.assertEqual(expected, produced) - - # Cycle length > 1 and block length > 1. - expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, - 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 2, 3)): - self.assertEqual(expected, produced) - - # Cycle length > len(input_values). 
- expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, - 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 7, 2)): - self.assertEqual(expected, produced) - - def testInterleaveDataset(self): - input_values = array_ops.placeholder(dtypes.int64, shape=[None]) - cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) - block_length = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_count = 2 - - dataset = ( - dataset_ops.Dataset.from_tensor_slices(input_values) - .repeat(repeat_count) - .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length)) - iterator = dataset.make_initializable_iterator() + def _flat_map_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + iterator = ( + dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn) + .make_initializable_iterator()) init_op = iterator.initializer - next_element = iterator.get_next() + get_next = iterator.get_next() with self.test_session() as sess: - # Cycle length 1 acts like `Dataset.flat_map()`. - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 1, block_length: 3}) - - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): - self.assertEqual(expected_element, sess.run(next_element)) - - # Cycle length > 1. - # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, - # 6, 5, 6, 5, 6, 5, 6, 5] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 1}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > 1 and block length > 1. 
- # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, - # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > len(input_values) * repeat_count. - # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, - # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 7, block_length: 2}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): - self.assertEqual(expected_element, sess.run(next_element)) + sess.run(init_op) + for i in range(10): + for j in range(2): + expected = [i, 0] if j % 2 == 0 else [0, -i] + self.assertAllEqual(expected, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + sess.run(get_next) - # Empty input. - sess.run(init_op, feed_dict={input_values: [], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - # Non-empty input leading to empty output. - sess.run(init_op, feed_dict={input_values: [0, 0, 0], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Mixture of non-empty and empty interleaved datasets. 
- sess.run(init_op, feed_dict={input_values: [4, 0, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) +class FlatMapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + # Complicated way of saying range(start, start+25). + def build_ds(start): + + def map_fn(x): + return dataset_ops.Dataset.range(x, x + 5) + + return dataset_ops.Dataset.range(start, start + 5 * 5, 5).flat_map(map_fn) + + self.run_core_tests(lambda: build_ds(0), lambda: build_ds(10), 25) + + def testMapThenFlatMap(self): + + def build_ds(): + + def flat_map_fn(_): + + def map_fn(y): + return 10 * math_ops.to_int32(y) + + return dataset_ops.Dataset.range(100).map(map_fn) + + return dataset_ops.Dataset.range(5).flat_map(flat_map_fn) + + self.run_core_tests(build_ds, None, 500) + + def testCaptureDefunInMapFn(self): + + def build_ds(): + + def map_fn(x): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.from_tensor_slices([defun_fn(x)]) + + return dataset_ops.Dataset.range(100).flat_map(map_fn) + + self.run_core_tests(build_ds, None, 100) + + def testDisallowVariableCapture(self): + + def build_ds(): + test_var = variable_scope.get_variable( + name="test_var", shape=(), use_resource=True) + return dataset_ops.Dataset.range(5).flat_map( + lambda _: dataset_ops.Dataset.from_tensor_slices([test_var])) + + self.verify_error_on_save(build_ds, 5, errors.InvalidArgumentError) + + def testDisallowCapturingStatefulOps(self): + + def build_ds(): + + def flat_map_fn(_): + + def map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map(map_fn) + + return 
dataset_ops.Dataset.range(5).flat_map(flat_map_fn) + + self.verify_error_on_save(build_ds, 500, errors.InvalidArgumentError) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 0aa9ea88de82b0851b0236d9412039d6573ab291..e66ed3f7aa2a512813ef353d2d0744ae67005884 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -22,18 +22,236 @@ import math import threading import time +import numpy as np from six.moves import zip_longest +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test +class InterleaveDatasetTest(test.TestCase): + + def _interleave(self, lists, cycle_length, block_length): + num_open = 0 + + # `all_iterators` acts as a queue of iterators over each element of `lists`. + all_iterators = [iter(l) for l in lists] + + # `open_iterators` are the iterators whose elements are currently being + # interleaved. 
+ open_iterators = [] + for i in range(cycle_length): + if all_iterators: + open_iterators.append(all_iterators.pop(0)) + num_open += 1 + else: + open_iterators.append(None) + + while num_open or all_iterators: + for i in range(cycle_length): + if open_iterators[i] is None: + if all_iterators: + open_iterators[i] = all_iterators.pop(0) + num_open += 1 + else: + continue + for _ in range(block_length): + try: + yield next(open_iterators[i]) + except StopIteration: + open_iterators[i] = None + num_open -= 1 + break + + def testPythonImplementation(self): + input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], + [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] + + # Cycle length 1 acts like `Dataset.flat_map()`. + expected_elements = itertools.chain(*input_lists) + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 1, 1)): + self.assertEqual(expected, produced) + + # Cycle length > 1. + expected_elements = [4, 5, 4, 5, 4, 5, 4, + 5, 5, 6, 6, # NOTE(mrry): When we cycle back + # to a list and are already at + # the end of that list, we move + # on to the next element. + 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5] + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 2, 1)): + self.assertEqual(expected, produced) + + # Cycle length > 1 and block length > 1. + expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, + 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 2, 3)): + self.assertEqual(expected, produced) + + # Cycle length > len(input_values). 
+ expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, + 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 7, 2)): + self.assertEqual(expected, produced) + + def testInterleaveDataset(self): + input_values = array_ops.placeholder(dtypes.int64, shape=[None]) + cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) + block_length = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_count = 2 + + dataset = ( + dataset_ops.Dataset.from_tensor_slices(input_values) + .repeat(repeat_count) + .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + next_element = iterator.get_next() + + with self.test_session() as sess: + # Cycle length 1 acts like `Dataset.flat_map()`. + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 1, block_length: 3}) + + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): + self.assertEqual(expected_element, sess.run(next_element)) + + # Cycle length > 1. + # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, + # 6, 5, 6, 5, 6, 5, 6, 5] + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 2, block_length: 1}) + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Cycle length > 1 and block length > 1. 
+ # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, + # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 2, block_length: 3}) + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Cycle length > len(input_values) * repeat_count. + # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, + # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 7, block_length: 2}) + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Empty input. + sess.run(init_op, feed_dict={input_values: [], + cycle_length: 2, block_length: 3}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Non-empty input leading to empty output. + sess.run(init_op, feed_dict={input_values: [0, 0, 0], + cycle_length: 2, block_length: 3}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Mixture of non-empty and empty interleaved datasets. 
+ sess.run(init_op, feed_dict={input_values: [4, 0, 6], + cycle_length: 2, block_length: 3}) + for expected_element in self._interleave( + [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testSparse(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + iterator = ( + dataset_ops.Dataset.range(10).map(_map_fn).interleave( + _interleave_fn, cycle_length=1).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + for j in range(2): + expected = [i, 0] if j % 2 == 0 else [0, -i] + self.assertAllEqual(expected, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +class InterleaveDatasetSeriazationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, input_values, cycle_length, block_length): + repeat_count = 2 + return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( + repeat_count).interleave( + lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length) + + def testSerializationCore(self): + input_values = np.array([4, 5, 6], dtype=np.int64) + num_outputs = np.sum(input_values) * 2 + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + lambda: self._build_iterator_graph( + input_values, cycle_length * 2, block_length * 1), + num_outputs) + # cycle_length = 1 + cycle_length = 1 + 
block_length = 3 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # pylint: enable=g-long-lambda + + class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): @@ -547,5 +765,31 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testTooManyReadersSloppy(self): self._testTooManyReaders(sloppy=True) + def testSparse(self): + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + dataset = dataset_ops.Dataset.range(10).map(_map_fn) + iterator = dataset.apply( + interleave_ops.parallel_interleave( + _interleave_fn, cycle_length=1)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + for j in range(2): + expected = [i, 0] if j % 2 == 0 else [0, -i] + self.assertAllEqual(expected, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 271d80a54b5a3e1a09cdf37e4f5e659fb67a78f9..bda9a2a4a37e9c3d35ff99041d1150ffc43f4c43 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -21,7 +21,6 @@ import os import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.contrib.data.python.ops import gen_dataset_ops from 
tensorflow.contrib.data.python.ops import readers from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session @@ -34,6 +33,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 8a1d99499be702d91f87f65f443261b47ce5c5cd..e9a07da84a8c80c09ebd4dab0b1d69febe1c9790 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -20,15 +20,18 @@ from collections import namedtuple import os import threading -from collections import namedtuple import numpy as np -from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import error_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops @@ -37,6 +40,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import script_ops +from tensorflow.python.ops import 
sparse_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -616,6 +620,182 @@ class MapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + iterator = (dataset_ops.Dataset.range(10) + .map(_sparse) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + actual = sess.run(get_next) + self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertSparseValuesEqual(actual, _sparse(i)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSparseChain(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def _check(i): + self.assertTrue(sparse_tensor.is_sparse(i)) + return sparse_ops.sparse_concat(0, [i, i]) + + iterator = ( + dataset_ops.Dataset.range(10).map(_sparse).map(_check) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + actual = sess.run(get_next) + self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval()) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testCaptureResourceInMapFn(self): + + def _build_ds(iterator): + + def _map_fn(x): + get_next = 
iterator.get_next() + return x * get_next + + return dataset_ops.Dataset.range(10).map(_map_fn) + + def _build_graph(): + captured_iterator = dataset_ops.Dataset.range( + 10).make_initializable_iterator() + ds = _build_ds(captured_iterator) + iterator = ds.make_initializable_iterator() + init_op = iterator.initializer + return captured_iterator.initializer, init_op + + with ops.Graph().as_default() as g: + captured_init_op, init_op = _build_graph() + with self.test_session(graph=g) as sess: + sess.run(captured_init_op) + with self.assertRaises(errors.UnimplementedError): + # CapturedFunction does not support capturing IteratorResource. + sess.run(init_op) + + +class MapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 14 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(self._num_epochs)) + + def testSaveRestoreCore(self): + self.run_core_tests( + self._build_ds, + lambda: self._build_ds(multiplier=15.0), + self._num_outputs) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map(_map_fn) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + 
lambda _: counter_var.assign_add(1))) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + +class IgnoreErrorsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.check_numerics(x, "message")).apply( + error_ops.ignore_errors()) + + def testIgnoreErrorsCore(self): + components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) + diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32) + num_outputs = 4 + self.run_core_tests(lambda: self._build_ds(components), + lambda: self._build_ds(diff_components), num_outputs) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3d120a3071ef730f21221e3291d8c84385b51aa3 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class PrefetchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, seed): + return dataset_ops.Dataset.range(100).prefetch(10).shuffle( + buffer_size=10, seed=seed, reshuffle_each_iteration=False) + + def testCore(self): + num_outputs = 100 + self.run_core_tests(lambda: self.build_dataset(10), + lambda: self.build_dataset(20), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 329dc80ba5a29ade74ae8dfd12d37e5c1e2a9f73..8e6ad061a11752ab7b1ffc13c90b4fa52f67d6aa 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -19,9 +19,9 @@ from __future__ import print_function import os +from tensorflow.contrib.data.python.ops import counter from 
tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import enumerate_ops -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -30,6 +30,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables @@ -194,6 +195,27 @@ class RangeDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testCounter(self): + """Test dataset construction using `count`.""" + iterator = (counter.Counter(start=3, step=4) + .make_one_shot_iterator()) + get_next = iterator.get_next() + self.assertEqual([], get_next.shape.as_list()) + self.assertEqual(dtypes.int64, get_next.dtype) + + negative_iterator = (counter.Counter(start=0, step=-1) + .make_one_shot_iterator()) + negative_get_next = negative_iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(3, sess.run(get_next)) + self.assertEqual(3 + 4, sess.run(get_next)) + self.assertEqual(3 + 2 * 4, sess.run(get_next)) + + self.assertEqual(0, sess.run(negative_get_next)) + self.assertEqual(-1, sess.run(negative_get_next)) + self.assertEqual(-2, sess.run(negative_get_next)) + def _iterator_checkpoint_prefix(self): return os.path.join(self.get_temp_dir(), "iterator") diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 8033f1d38806767ce08043d10c42dd376087765c..1c42a3d855bc16c21e385d7108c3106884ae4f5e 100644 --- 
a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -21,8 +21,7 @@ import gzip import os import zlib -from tensorflow.contrib.data.python.ops import gen_dataset_ops -from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 @@ -31,17 +30,14 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops -from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import compat -class TextLineDatasetTest(test.TestCase): +class TextLineDatasetTestBase(test.TestCase): def _lineText(self, f, l): return compat.as_bytes("%d: %d" % (f, l)) @@ -79,6 +75,9 @@ class TextLineDatasetTest(test.TestCase): return filenames + +class TextLineDatasetTest(TextLineDatasetTestBase): + def _testTextLineDataset(self, compression_type=None): test_filenames = self._createFiles( 2, 5, crlf=True, compression_type=compression_type) @@ -165,282 +164,37 @@ class TextLineDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) - def _ckpt_path(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _latest_ckpt(self): - return saver_lib.latest_checkpoint(self.get_temp_dir()) - - def _save(self, saver, sess): - saver.save(sess, self._ckpt_path()) - - 
def _restore(self, saver, sess): - saver.restore(sess, self._latest_ckpt()) - def _import_meta_graph(self): - meta_file_path = self._ckpt_path() + ".meta" - return saver_lib.import_meta_graph(meta_file_path) +class TextLineDatasetSerializationTest( + TextLineDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): - def _build_graph(self, - test_filenames, - compression_type=None, - build_saveable=True): - ds = readers.TextLineDataset( + def _build_iterator_graph(self, test_filenames, compression_type=None): + return readers.TextLineDataset( test_filenames, compression_type=compression_type, buffer_size=10) - iterator = ds.make_initializable_iterator() - if build_saveable: - saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - init_op = iterator.initializer - get_next = iterator.get_next() - ops.add_to_collection("iterator_ops", init_op) - ops.add_to_collection("iterator_ops", get_next) - saver = saver_lib.Saver(allow_empty=True) - return init_op, get_next, saver - - def _testReadWithBreaks(self, breaks, num_files=5, lines_per_file=5): - """Tests reading from input pipeline with regular breaks. - - At each break point the iterator state gets saved using Saver and reloaded - in a new Graph and session. - - Args: - breaks: List of counts of records after reading which iterator state is - checkpointed. Must to in non-decreasing order. - num_files: Total number of files. - lines_per_file: Total number of lines per file. 
- """ + + def testTextLineCore(self): compression_types = [None, "GZIP", "ZLIB"] + num_files = 5 + lines_per_file = 5 + num_outputs = num_files * lines_per_file for compression_type in compression_types: test_filenames = self._createFiles( num_files, lines_per_file, crlf=True, compression_type=compression_type) + # pylint: disable=cell-var-from-loop + self.run_core_tests( + lambda: self._build_iterator_graph(test_filenames, compression_type), + lambda: self._build_iterator_graph(test_filenames), num_outputs) + # pylint: enable=cell-var-from-loop - # Collect ground truth. - total_records = num_files * lines_per_file - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph( - test_filenames, compression_type=compression_type) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(total_records): - expected_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Simulate run with breaks. 
- actual_records = [] - next_record_index = 0 - load_from_ckpt = False - breaks.append(total_records) - for break_index in breaks: - with ops.Graph().as_default() as g: - if not load_from_ckpt: - init_op, get_next, saver = self._build_graph( - test_filenames, compression_type=compression_type) - else: - saver = self._import_meta_graph() - init_op, get_next = ops.get_collection("iterator_ops") - with self.test_session(graph=g) as sess: - if not load_from_ckpt: - sess.run(init_op) - else: - self._restore(saver, sess) - while next_record_index != break_index: - actual_records.append(sess.run(get_next)) - next_record_index += 1 - if break_index == total_records: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self._save(saver, sess) - load_from_ckpt = True - self.assertEqual(actual_records, expected_records) - - def testSaveAtFileBoundary(self): - self._testReadWithBreaks([10]) - - def testSaveWithinFile(self): - self._testReadWithBreaks([12]) - - def testSaveUnusedIterator(self): - self._testReadWithBreaks([0]) - - def testSaveRestoreIdempotence(self): - # Attempt to save an iterator immediately after it has been - # restored. 
- self._testReadWithBreaks([0, 0]) - self._testReadWithBreaks([10, 10]) - self._testReadWithBreaks([12, 12]) - - def testMultipleBreaks(self): - self._testReadWithBreaks([0, 4, 20]) - - def testRestoreExhaustedIterator(self): - num_files = 2 - lines_per_file = 5 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(num_files * lines_per_file): - sess.run(get_next) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self._save(saver, sess) - - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - saver = self._import_meta_graph() - self._restore(saver, sess) - _, get_next = ops.get_collection("iterator_ops") - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testInitThenRestore(self): - num_files = 5 - lines_per_file = 5 - total_records = num_files * lines_per_file - break_record = 8 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_record): - sess.run(get_next) - self._save(saver, sess) - for _ in range(total_records - break_record): - expected_records.append(sess.run(get_next)) - - actual_records = [] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - saver = self._import_meta_graph() - init_op, get_next = ops.get_collection("iterator_ops") - sess.run(init_op) - self._restore(saver, sess) - for _ in range(total_records - break_record): - actual_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(actual_records, expected_records) - - def 
testRestoreInModifiedGraph(self): - num_files = 5 - lines_per_file = 5 - total_records = num_files * lines_per_file - break_record = 8 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_record): - sess.run(get_next) - self._save(saver, sess) - for _ in range(total_records - break_record): - expected_records.append(sess.run(get_next)) - - actual_records = [] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - init_op, get_next, saver = self._build_graph( - test_filenames, compression_type="GZIP") - self._restore(saver, sess) - for _ in range(total_records - break_record): - actual_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(actual_records, expected_records) - - def testRestoreInModifiedGraphThenInit(self): - num_files = 5 - lines_per_file = 5 - total_records = num_files * lines_per_file - break_record = 8 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_record): - expected_records.append(sess.run(get_next)) - self._save(saver, sess) - for _ in range(total_records - break_record): - expected_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test that calling the init_op overrides the restored iterator. The - # iterator for the old graph was build to read uncompressed files and - # would fail when trying to read the new files. 
- actual_records = [] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - test_filenames = self._createFiles( - num_files, lines_per_file, crlf=True, compression_type="GZIP") - init_op, get_next, saver = self._build_graph( - test_filenames, compression_type="GZIP") - self._restore(saver, sess) - sess.run(init_op) - for _ in range(total_records): - actual_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(actual_records, expected_records) - - def testDoNotRestoreIterator(self): - num_files = 5 - lines_per_file = 5 - total_records = num_files * lines_per_file - break_record = 8 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_record): - expected_records.append(sess.run(get_next)) - self._save(saver, sess) - for _ in range(total_records - break_record): - expected_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - actual_records = [] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - init_op, get_next, saver = self._build_graph( - test_filenames, build_saveable=False) - self._restore(saver, sess) - with self.assertRaises(errors.FailedPreconditionError): - sess.run(get_next) - sess.run(init_op) - for _ in range(total_records): - actual_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(actual_records, expected_records) - - -class FixedLengthRecordReaderTest(test.TestCase): +class FixedLengthRecordReaderTestBase(test.TestCase): def setUp(self): - super(FixedLengthRecordReaderTest, self).setUp() + super(FixedLengthRecordReaderTestBase, self).setUp() self._num_files = 2 
self._num_records = 7 self._header_bytes = 5 @@ -462,6 +216,9 @@ class FixedLengthRecordReaderTest(test.TestCase): f.write(b"F" * self._footer_bytes) return filenames + +class FixedLengthRecordReaderTest(FixedLengthRecordReaderTestBase): + def testFixedLengthRecordDataset(self): test_filenames = self._createFiles() filenames = array_ops.placeholder(dtypes.string, shape=[None]) @@ -547,304 +304,29 @@ class FixedLengthRecordReaderTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) - def _iterator_checkpoint_path(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _save_op(self, iterator_resource): - iterator_state_variant = gen_dataset_ops.serialize_iterator( - iterator_resource) - save_op = io_ops.write_file( - self._iterator_checkpoint_path(), - parsing_ops.serialize_tensor(iterator_state_variant)) - return save_op - - def _restore_op(self, iterator_resource): - iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant) - restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, - iterator_state_variant) - return restore_op - - def _build_iterator_graph(self, num_epochs): + +class FixedLengthRecordDatasetSerializationTest( + FixedLengthRecordReaderTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, num_epochs, compression_type=None): filenames = self._createFiles() - dataset = (readers.FixedLengthRecordDataset( - filenames, self._record_bytes, self._header_bytes, self._footer_bytes) - .repeat(num_epochs)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next_op = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next_op, save_op, restore_op - - def _restore_iterator(self): - output_types = dtypes.string - output_shapes = 
tensor_shape.scalar() - iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) - get_next = iterator.get_next() - restore_op = self._restore_op(iterator._iterator_resource) - return restore_op, get_next - - def testSaveRestore(self): - num_epochs = 10 - epoch_break = 5 - file_break = self._num_files // 2 - record_break = self._num_records // 2 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch == epoch_break and f == file_break and - r == record_break): - sess.run(save_op) - break - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - else: - continue - break - else: - continue - break - else: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch < epoch_break or - (epoch == epoch_break and f < file_break) or - (epoch == epoch_break and f == file_break and - r < record_break)): - continue - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testInitThenRestore(self): - # Note: Calling init_op before restore_op is redundant. This test just makes - # sure we do not fail if restore is called on an already initialized - # iterator resource. 
- num_epochs = 10 - epoch_break = 5 - file_break = self._num_files // 2 - record_break = self._num_records // 2 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch == epoch_break and f == file_break and - r == record_break): - sess.run(save_op) - break - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - else: - continue - break - else: - continue - break - else: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch < epoch_break or - (epoch == epoch_break and f < file_break) or - (epoch == epoch_break and f == file_break and - r < record_break)): - continue - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testRestoreInModifiedGraph(self): - num_epochs = 10 - num_epochs_1 = 20 - epoch_break = 5 - file_break = self._num_files // 2 - record_break = self._num_records // 2 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. 
- with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch == epoch_break and f == file_break and - r == record_break): - sess.run(save_op) - break - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - else: - continue - break - else: - continue - break - else: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs_1) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch < epoch_break or - (epoch == epoch_break and f < file_break) or - (epoch == epoch_break and f == file_break and - r < record_break)): - continue - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testRestoreWithoutBuildingDatasetGraph(self): - num_epochs = 10 - epoch_break = 5 - file_break = self._num_files // 2 - record_break = self._num_records // 2 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. 
- with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch == epoch_break and f == file_break and - r == record_break): - sess.run(save_op) - break - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - else: - continue - break - else: - continue - break - else: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - restore_op, get_next_op = self._restore_iterator() - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch < epoch_break or - (epoch == epoch_break and f < file_break) or - (epoch == epoch_break and f == file_break and - r < record_break)): - continue - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testRestoreUnusedIterator(self): - num_epochs = 10 - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - # Save unused iterator. 
- sess.run(save_op) - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for _ in range(num_epochs * self._num_files * self._num_records): - sess.run(get_next_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testRestoreExhaustedIterator(self): - num_epochs = 10 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for _ in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - -class TFRecordDatasetTest(test.TestCase): + return readers.FixedLengthRecordDataset( + filenames, self._record_bytes, self._header_bytes, + self._footer_bytes).repeat(num_epochs) + + def testFixedLengthRecordCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + +class TFRecordDatasetTestBase(test.TestCase): def setUp(self): - super(TFRecordDatasetTest, self).setUp() + super(TFRecordDatasetTestBase, self).setUp() self._num_files = 2 self._num_records = 7 @@ -880,6 
+362,9 @@ class TFRecordDatasetTest(test.TestCase): writer.close() return filenames + +class TFRecordDatasetTest(TFRecordDatasetTestBase): + def testReadOneEpoch(self): with self.test_session() as sess: # Basic test: read from file 0. @@ -1001,6 +486,74 @@ class TFRecordDatasetTest(test.TestCase): sess.run(iterator.get_next()) +class TFRecordDatasetSerializationTest( + TFRecordDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, + num_epochs, + batch_size=1, + compression_type=None, + buffer_size=None): + filenames = self._createFiles() + if compression_type == "ZLIB": + zlib_files = [] + for i, fn in enumerate(filenames): + with open(fn, "rb") as f: + cdata = zlib.compress(f.read()) + zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) + with open(zfn, "wb") as f: + f.write(cdata) + zlib_files.append(zfn) + filenames = zlib_files + + elif compression_type == "GZIP": + gzip_files = [] + for i, fn in enumerate(filenames): + with open(fn, "rb") as f: + gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) + with gzip.GzipFile(gzfn, "wb") as gzf: + gzf.write(f.read()) + gzip_files.append(gzfn) + filenames = gzip_files + + return readers.TFRecordDataset( + filenames, compression_type, + buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) + + def testTFRecordWithoutBufferCore(self): + num_epochs = 5 + batch_size = num_epochs + num_outputs = num_epochs * self._num_files * self._num_records // batch_size + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, batch_size, + buffer_size=0), + lambda: self._build_iterator_graph(num_epochs * 2, batch_size), + num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None, + num_outputs * batch_size) + # pylint: enable=g-long-lambda + + def testTFRecordWithBufferCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files
* self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + def testTFRecordWithCompressionCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + + class ReadBatchFeaturesTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py index 91615e9f6205cc95ff531b98683ff485964f714e..1a26da82e533ec01106ea10525c1cd96627c34fb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -207,5 +208,82 @@ class SequenceDatasetTest(test.TestCase): sess.run(get_next) +class SequenceDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_skip_dataset(self, count): + components = (np.arange(10),) + return dataset_ops.Dataset.from_tensor_slices(components).skip(count) + + def testSkipFewerThanInputs(self): + count = 4 + num_outputs = 10 - count + self.run_core_tests(lambda: self._build_skip_dataset(count), + lambda: self._build_skip_dataset(count + 2), + num_outputs) + + def testSkipVarious(self): + # Skip more than 
inputs + self.run_core_tests(lambda: self._build_skip_dataset(20), None, 0) + # Skip exactly the input size + self.run_core_tests(lambda: self._build_skip_dataset(10), None, 0) + self.run_core_tests(lambda: self._build_skip_dataset(-1), None, 0) + # Skip nothing + self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10) + + def _build_take_dataset(self, count): + components = (np.arange(10),) + return dataset_ops.Dataset.from_tensor_slices(components).take(count) + + def testTakeFewerThanInputs(self): + count = 4 + self.run_core_tests( + lambda: self._build_take_dataset(count), + lambda: self._build_take_dataset(count + 2), + count, + ) + + def testTakeVarious(self): + # Take more than inputs + self.run_core_tests(lambda: self._build_take_dataset(20), None, 10) + # Take exactly the input size + self.run_core_tests(lambda: self._build_take_dataset(10), None, 10) + # Take all + self.run_core_tests(lambda: self._build_take_dataset(-1), None, 10) + # Take nothing + self.run_core_tests(lambda: self._build_take_dataset(0), None, 0) + + def _build_repeat_dataset(self, count, take_count=3): + components = (np.arange(10),) + return dataset_ops.Dataset.from_tensor_slices(components).take( + take_count).repeat(count) + + def testFiniteRepeat(self): + count = 10 + self.run_core_tests(lambda: self._build_repeat_dataset(count), + lambda: self._build_repeat_dataset(count + 2), + 3 * count) + + def testEmptyRepeat(self): + self.run_core_tests(lambda: self._build_repeat_dataset(0), None, 0) + + def testInfiniteRepeat(self): + self.verify_unused_iterator( + lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False) + self.verify_init_before_restore( + lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False) + self.verify_multiple_breaks( + lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False) + self.verify_reset_restored_iterator( + lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False) + self.verify_restore_in_modified_graph( 
+ lambda: self._build_repeat_dataset(-1), + lambda: self._build_repeat_dataset(2), + 20, + verify_exhausted=False) + # Test repeat empty dataset + self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 6b5b53cc0f8f2d1df5622a5bc5e2f8ef04c6342a..ba1be0690ff3d72df9fe40980c0f5d53b33e41c5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -22,8 +22,10 @@ import os import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -156,6 +158,13 @@ class ShuffleDatasetTest(test.TestCase): for i in range(5): self.assertEqual(10, counts[i]) + def testSeedNoneSeed2NonNone(self): + with self.assertRaises(ValueError): + dataset_ops.ShuffleDataset(dataset_ops.Dataset.range(5), + buffer_size=1, + seed=None, + seed2=10) + class ShuffleDatasetSerializationTest(test.TestCase): @@ -474,5 +483,76 @@ class ShuffleDatasetSerializationTest(test.TestCase): self.assertEqual(expected_outputs_sorted, sorted(actual)) +class ShuffleAndRepeatTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed, count=5): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed)) + + def testCorrectOutput(self): + output = self.gen_outputs(lambda: 
self._build_ds(10), [], 100) + self.assertSequenceEqual( + sorted(output), sorted( + np.array([range(20) for _ in range(5)]).flatten())) + for i in range(5): + self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20)) + + def testReshuffling(self): + # Check that the output orders of different epochs are indeed different. + output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + for i in range(4): + epoch1 = output[i * 20:(i + 1) * 20] + epoch2 = output[(i + 1) * 20:(i + 2) * 20] + self.assertNotEqual(epoch1, epoch2) + + def testSameOrderForSameSeeds(self): + output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + self.assertEqual(output1, output2) + + def testDifferentOrderForDifferentSeeds(self): + output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountNone(self): + output1 = self.gen_outputs( + lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False) + output2 = self.gen_outputs( + lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountMinusOne(self): + output1 = self.gen_outputs( + lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False) + output2 = self.gen_outputs( + lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testInfiniteOutputs(self): + # Asserting that the iterator is exhausted after producing 100 items should + # fail. 
+ with self.assertRaises(AssertionError): + self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100) + with self.assertRaises(AssertionError): + self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100) + + +class ShuffleAndRepeatSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) + + def testCore(self): + self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), + 100) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8f24d6b2f612cff662aa8a36085bc69a9ea1a290 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -0,0 +1,213 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline statistics gathering ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.ops import stats_ops +from tensorflow.core.framework import summary_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class StatsDatasetTest(test.TestCase): + + def _assertSummaryHasCount(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.num) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def _assertSummaryHasSum(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.sum) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def testBytesProduced(self): + dataset = dataset_ops.Dataset.range(100).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( + stats_ops.bytes_produced_stats("bytes_produced")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator.initializer, stats_aggregator_subscriber]) + expected_sum = 0.0 + for i in range(100): + 
self.assertAllEqual( + np.array([i] * i, dtype=np.int64), sess.run(next_element)) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1)) + expected_sum += i * 8.0 + self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0) + self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum) + + def testLatencyStats(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator.initializer, stats_aggregator_subscriber]) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) + + def testReinitialize(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(stats_aggregator_subscriber) + for j in range(5): + sess.run(iterator.initializer) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float((j * 100) + 
i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", (j + 1) * 100.0) + + def testNoAggregatorRegistered(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testMultipleTags(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.latency_stats("record_latency_2")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator.initializer, stats_aggregator_subscriber]) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(i + 1)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency_2", float(i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency_2", 100.0) + + def testRepeatedTags(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = 
iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator.initializer, stats_aggregator_subscriber]) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(2 * (i + 1))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) + + def testMultipleIteratorsSameAggregator(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator_0 = dataset.make_initializable_iterator() + iterator_1 = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscribers = [stats_aggregator.subscribe(iterator_0), + stats_aggregator.subscribe(iterator_1)] + next_element = iterator_0.get_next() + iterator_1.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator_0.initializer, iterator_1.initializer, + stats_aggregator_subscribers]) + for i in range(100): + self.assertEqual(i * 2, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(2 * (i + 1))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) + + def testMultipleStatsAggregatorsSameIteratorFail(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + stats_aggregator_0 = stats_ops.StatsAggregator() + stats_aggregator_1 = stats_ops.StatsAggregator() + + with self.test_session() as sess: + sess.run(stats_aggregator_0.subscribe(iterator)) + # TODO(mrry): Consider making this allowable (and also allowing + # aggregators to unsubscribe). 
+ with self.assertRaises(errors.FailedPreconditionError): + sess.run(stats_aggregator_1.subscribe(iterator)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py index b0e72183019e4d53756542e2a2ef071111120dcd..5d34b0024c472d0393544ff3dad8acea7964345f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -110,5 +111,31 @@ class ZipDatasetTest(test.TestCase): sess.run(get_next) +class ZipDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, arr): + components = [ + np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 22), + np.array(arr) + ] + datasets = [ + dataset_ops.Dataset.from_tensor_slices(component) + for component in components + ] + return dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2]))) + + def testCore(self): + # Equal length components + arr = [37.0, 38.0, 39.0, 40.0] + num_outputs = len(arr) + self.run_core_tests(lambda: self._build_dataset(arr), None, num_outputs) + # Variable length components + diff_size_arr = [1.0, 2.0] + self.run_core_tests(lambda: self._build_dataset(diff_size_arr), + lambda: self._build_dataset(arr), 2) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 727c5d1c38ba30c32968a3cf33f7c03163f060d4..1f35ee056b7f897ce5e7488b205ecf5a05ef0268 100644 --- 
a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -11,6 +11,22 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +py_library( + name = "dataset_ops", + srcs = [ + "counter.py", + "dataset_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":transformation_ops", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + py_library( name = "iterator_ops", srcs = [ @@ -24,6 +40,25 @@ py_library( ], ) +py_library( + name = "random_ops", + srcs = [ + "random_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:constant_op", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + py_library( name = "readers", srcs = [ @@ -46,6 +81,19 @@ py_library( ], ) +py_library( + name = "shuffle_ops", + srcs = [ + "shuffle_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":random_ops", + ":transformation_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_library( name = "transformation_ops", srcs = [ @@ -56,10 +104,10 @@ py_library( "interleave_ops.py", "resampling.py", "scan_ops.py", + "stats_ops.py", ], srcs_version = "PY2AND3", deps = [ - ":gen_dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dataset_ops_gen", @@ -71,8 +119,10 @@ py_library( "//tensorflow/python:random_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", + "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", "//third_party/py/numpy", ], ) @@ -104,39 
+154,7 @@ tf_custom_op_py_library( deps = [ ":prefetching_ops", "//tensorflow/contrib/util:util_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], -) - -tf_gen_op_wrapper_py( - name = "gen_dataset_ops", - out = "gen_dataset_ops.py", - deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"], -) - -tf_custom_op_py_library( - name = "dataset_ops", - srcs = ["dataset_ops.py"], - dso = ["//tensorflow/contrib/data:_dataset_ops.so"], - kernels = [ - "//tensorflow/contrib/data:dataset_ops_op_lib", - ], - srcs_version = "PY2AND3", - deps = [ - ":gen_dataset_ops", - ":transformation_ops", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:platform", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", ], ) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index e6e5f716b62b8d715eecf0c5a79d1c22d34c06b2..e8b2d44a8b57d471f11b128622b6121f699fbf85 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -17,14 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops from 
tensorflow.python.ops import math_ops @@ -103,6 +104,48 @@ def unbatch(): return _apply_fn +def filter_irregular_batches(batch_size): + """Transformation that filters out batches that are not of size batch_size.""" + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + tensor_batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name="batch_size") + + flattened = _RestructuredDataset( + dataset, + tuple(nest.flatten(dataset.output_types)), + output_classes=tuple(nest.flatten(dataset.output_classes))) + + def _predicate(*xs): + """Return `True` if this element is a full batch.""" + # Extract the dynamic batch size from the first component of the flattened + # batched element. + first_component = xs[0] + first_component_batch_size = array_ops.shape( + first_component, out_type=dtypes.int64)[0] + + return math_ops.equal(first_component_batch_size, tensor_batch_size) + + filtered = flattened.filter(_predicate) + + maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size) + + def _set_first_dimension(shape): + return shape.merge_with( + tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:])) + + known_shapes = nest.map_structure(_set_first_dimension, + dataset.output_shapes) + return _RestructuredDataset( + filtered, + dataset.output_types, + known_shapes, + output_classes=dataset.output_classes) + + return _apply_fn + + def batch_and_drop_remainder(batch_size): """A batching transformation that omits the final small batch (if present). 
@@ -135,34 +178,43 @@ def batch_and_drop_remainder(batch_size): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - tensor_batch_size = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") + batched = dataset.batch(batch_size) + return filter_irregular_batches(batch_size)(batched) - batched = dataset.batch(tensor_batch_size) - flattened = _RestructuredDataset(batched, - tuple(nest.flatten(batched.output_types))) + return _apply_fn - def _predicate(*xs): - """Return `True` if this element is a full batch.""" - # Extract the dynamic batch size from the first component of the flattened - # batched element. - first_component = xs[0] - first_component_batch_size = array_ops.shape( - first_component, out_type=dtypes.int64)[0] - return math_ops.equal(first_component_batch_size, tensor_batch_size) +def padded_batch_and_drop_remainder(batch_size, + padded_shapes, + padding_values=None): + """A batching and padding transformation that omits the final small batch. - filtered = flattened.filter(_predicate) + Like @{tf.data.Dataset.padded_batch}, this transformation combines + consecutive elements of this dataset into batches. However, if the batch + size does not evenly divide the input dataset size, this transformation will + drop the final smaller element. - maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size) + See `@{tf.contrib.data.batch_and_drop_remainder}` for more details. - def _set_first_dimension(shape): - return shape.merge_with( - tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:])) + Args: + batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of this dataset to combine in a single batch. + padded_shapes: A nested structure of `tf.TensorShape` or + `tf.int64` vector tensor-like objects. See + @{tf.data.Dataset.padded_batch} for details. + padding_values: (Optional.) 
A nested structure of scalar-shaped + `tf.Tensor`. See @{tf.data.Dataset.padded_batch} for details. - known_shapes = nest.map_structure(_set_first_dimension, - batched.output_shapes) - return _RestructuredDataset(filtered, batched.output_types, known_shapes) + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply} + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + batched = dataset.padded_batch( + batch_size, padded_shapes=padded_shapes, padding_values=padding_values) + return filter_irregular_batches(batch_size)(batched) return _apply_fn @@ -191,6 +243,10 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): output_shapes=self.output_shapes, output_types=self.output_types) + @property + def output_classes(self): + return (ops.Tensor, ops.Tensor, ops.Tensor) + @property def output_shapes(self): num_elements = tensor_shape.Dimension(None) @@ -206,7 +262,11 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): class _RestructuredDataset(dataset_ops.Dataset): """An internal helper for changing the structure and shape of a dataset.""" - def __init__(self, dataset, output_types, output_shapes=None): + def __init__(self, + dataset, + output_types, + output_shapes=None, + output_classes=None): """Creates a new dataset with the given output types and shapes. The given `dataset` must have a structure that is convertible: @@ -222,6 +282,8 @@ class _RestructuredDataset(dataset_ops.Dataset): output_types: A nested structure of `tf.DType` objects. output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects. If omitted, the shapes will be inherited from `dataset`. + output_classes: (Optional.) A nested structure of class types. + If omitted, the class types will be inherited from `dataset`. 
Raises: ValueError: If either `output_types` or `output_shapes` is not compatible @@ -261,10 +323,21 @@ class _RestructuredDataset(dataset_ops.Dataset): output_shapes)) self._output_shapes = nest.map_structure_up_to( output_types, tensor_shape.as_shape, output_shapes) + if output_classes is None: + # Inherit class types from the original `dataset`. + self._output_classes = nest.pack_sequence_as(output_types, + nest.flatten( + dataset.output_classes)) + else: + self._output_classes = output_classes def _as_variant_tensor(self): return self._dataset._as_variant_tensor() # pylint: disable=protected-access + @property + def output_classes(self): + return self._output_classes + @property def output_types(self): return self._output_types @@ -280,7 +353,6 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) - self._batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") self._num_parallel_batches = ops.convert_to_tensor( @@ -295,8 +367,10 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): f=self._map_func, batch_size=self._batch_size, num_parallel_batches=self._num_parallel_batches, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) # pylint: enable=protected-access @property @@ -316,17 +390,12 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset - and then combines them into a batch. 
Similarly to `batch_and_drop_remainder`, - if the batch size does not evenly divide the input dataset size, this - transformation will drop the final smaller element. - - - Functionally, it is equivalent to `map` followed by - `batch_and_drop_remainder`. However, by fusing the two transformations - together, the implementation can be more efficient. This transformation is a - stop gap solution for performance critical workloads. Once automatic input - pipeline optimization are implemented, the fusing of map and batch will not - need to be exposed at the API level and this method will be removed. + and then combines them into a batch. Functionally, it is equivalent to `map` + followed by `batch`. However, by fusing the two transformations together, the + implementation can be more efficient. Surfacing this transformation in the API + is temporary. Once automatic input pipeline optimization is implemented, + the fusing of `map` and `batch` will happen automatically and this API will be + deprecated. Args: map_func: A function mapping a nested structure of tensors to another diff --git a/tensorflow/contrib/data/python/ops/counter.py b/tensorflow/contrib/data/python/ops/counter.py new file mode 100644 index 0000000000000000000000000000000000000000..63226fe78163c59025623a362d17c400fbe57c67 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/counter.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Counter Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import scan_ops + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops + + +def Counter(start=0, step=1, dtype=dtypes.int64): + """Creates a `Dataset` of a `step`-separated count starting from `start`. + + For example: + + ```python + Dataset.count() == [0, 1, 2, ...) + Dataset.count(2) == [2, 3, ...) + Dataset.count(2, 5) == [2, 7, 12, ...) + Dataset.count(0, -1) == [0, -1, -2, ...) + Dataset.count(10, -1) == [10, 9, ...) + ``` + + Args: + start: starting value for count. + step: step size. + dtype: counter data type. + + Returns: + A `Dataset` of scalar elements.
+ """ + with ops.name_scope("counter"): + start = ops.convert_to_tensor(start, dtype=dtype, name="start") + step = ops.convert_to_tensor(step, dtype=dtype, name="step") + return dataset_ops.Dataset.from_tensors(0).repeat(None).apply( + scan_ops.scan(start, lambda state, _: (state + step, state))) diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index c4c4426809aa7b5a1c80a0d6f797b9e140be4dea..626a9e0edcea5928b1636c1a2a86e83657c966a5 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -20,21 +20,14 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import enumerate_ops from tensorflow.contrib.data.python.ops import error_ops -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import grouping - -from tensorflow.contrib.util import loader from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_io_ops -from tensorflow.python.platform import resource_loader from tensorflow.python.util import deprecation -_dataset_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("../../_dataset_ops.so")) - - class Dataset(dataset_ops.Dataset): """Represents a potentially large set of elements. 
@@ -54,6 +47,10 @@ class Dataset(dataset_ops.Dataset): def _as_variant_tensor(self): return self._dataset._as_variant_tensor() # pylint: disable=protected-access + @property + def output_classes(self): + return self._dataset.output_classes + @property def output_shapes(self): return self._dataset.output_shapes diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 51a279107235f95eba2030291aab9d294f6d2b2d..aa629cba479102ee4244884e7c546615b28cf4e5 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -17,9 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.ops import gen_dataset_ops def ignore_errors(): @@ -62,8 +63,14 @@ class IgnoreErrorsDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.ignore_errors_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return self._input_dataset.output_classes @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 1c7c94b3c84a8c48ba9237c323fc13777d25f43d..ef91c56726e969053fdad667dda3e89430045652 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -17,12 +17,13 @@ 
from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops def group_by_window(key_func, @@ -87,15 +88,21 @@ def group_by_window(key_func, class _VariantDataset(dataset_ops.Dataset): """A Dataset wrapper for a tf.variant-typed function argument.""" - def __init__(self, dataset_variant, output_types, output_shapes): + def __init__(self, dataset_variant, output_types, output_shapes, + output_classes): super(_VariantDataset, self).__init__() self._dataset_variant = dataset_variant self._output_types = output_types self._output_shapes = output_shapes + self._output_classes = output_classes def _as_variant_tensor(self): return self._dataset_variant + @property + def output_classes(self): + return self._output_classes + @property def output_shapes(self): return self._output_shapes @@ -137,13 +144,21 @@ class GroupByWindowDataset(dataset_ops.Dataset): def _make_key_func(self, key_func, input_dataset): """Make wrapping Defun for key_func.""" - @function.Defun(*nest.flatten(input_dataset.output_types)) + @function.Defun(*nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes))) def tf_key_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the input_dataset. 
- for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): + dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes) + for arg, shape in zip(args, nest.flatten(dense_shapes)): arg.set_shape(shape) + nested_args = nest.pack_sequence_as(input_dataset.output_types, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) # pylint: disable=protected-access if dataset_ops._should_unpack_args(nested_args): ret = key_func(*nested_args) @@ -165,14 +180,15 @@ class GroupByWindowDataset(dataset_ops.Dataset): def tf_reduce_func(key, window_dataset_variant): """A wrapper for Defun that facilitates shape inference.""" key.set_shape([]) - window_dataset = _VariantDataset(window_dataset_variant, - input_dataset.output_types, - input_dataset.output_shapes) + window_dataset = _VariantDataset( + window_dataset_variant, input_dataset.output_types, + input_dataset.output_shapes, input_dataset.output_classes) if not isinstance(window_dataset, dataset_ops.Dataset): raise TypeError("`window_dataset` must return a `Dataset` object.") output_dataset = reduce_func(key, window_dataset) if not isinstance(output_dataset, dataset_ops.Dataset): raise TypeError("`reduce_func` must return a `Dataset` object.") + self._output_classes = output_dataset.output_classes self._output_types = output_dataset.output_types self._output_shapes = output_dataset.output_shapes return output_dataset._as_variant_tensor() # pylint: disable=protected-access @@ -180,6 +196,10 @@ class GroupByWindowDataset(dataset_ops.Dataset): self._reduce_func = tf_reduce_func self._reduce_func.add_to_graph(ops.get_default_graph()) + @property + def output_classes(self): + return self._output_classes + @property def output_shapes(self): return self._output_shapes @@ -197,5 +217,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): key_func=self._key_func, 
reduce_func=self._reduce_func, window_size_func=self._window_size_func, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index ce23e95697c9116635e6335dc7b1fdc6de514732..53324e06e7f1dc249388410f0e14e42336630cd1 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -17,12 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.util import deprecation @@ -35,16 +36,22 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): super(ParallelInterleaveDataset, self).__init__() self._input_dataset = input_dataset - @function.Defun(*nest.flatten(input_dataset.output_types)) + @function.Defun(*nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes))) def tf_map_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the input_dataset. 
- for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): + dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes) + for arg, shape in zip(args, nest.flatten(dense_shapes)): arg.set_shape(shape) nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - - if nest.is_sequence(nested_args): + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) + if dataset_ops._should_unpack_args(nested_args): # pylint: disable=protected-access dataset = map_func(*nested_args) else: dataset = map_func(nested_args) @@ -52,6 +59,7 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): if not isinstance(dataset, dataset_ops.Dataset): raise TypeError("`map_func` must return a `Dataset` object.") + self._output_classes = dataset.output_classes self._output_types = dataset.output_types self._output_shapes = dataset.output_shapes @@ -75,8 +83,14 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): self._block_length, self._sloppy, f=self._map_func, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + @property + def output_classes(self): + return self._output_classes @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index 32d2f42c9352fa35e3671ed549ad85efce2546d7..d736029fb035e573b70e8b19570e4e8ceca3c005 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -17,8 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from 
tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.training import saver diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..7d727165feabb101549567f28a2dfa07083de244 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -0,0 +1,67 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Datasets for random number generators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +class RandomDataset(dataset_ops.Dataset): + """A `Dataset` of pseudorandom values.""" + + def __init__(self, seed=None): + """A `Dataset` of pseudorandom values.""" + super(RandomDataset, self).__init__() + seed, seed2 = random_seed.get_seed(seed) + if seed is None: + self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") + else: + self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") + if seed2 is None: + self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") + else: + self._seed2 = ops.convert_to_tensor( + seed2, dtype=dtypes.int64, name="seed2") + + def _as_variant_tensor(self): + return gen_dataset_ops.random_dataset( + seed=self._seed, + seed2=self._seed2, + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.int64 diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 
f22298b757c73dac096603335b475119e5971df4..347e5edc7b0d479dfa260e8cec500ffaaba375be 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -18,14 +18,13 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation @@ -156,8 +155,7 @@ def read_batch_features(file_pattern, features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. reader: A function or class that can be called with a `filenames` tensor - and (optional) `reader_args` and returns a `Dataset` of serialized - Examples. + and (optional) `reader_args` and returns a `Dataset` of Examples. reader_args: Additional arguments to pass to the reader class. randomize_input: Whether the input should be randomized. num_epochs: Integer specifying the number of times to read through the @@ -166,7 +164,7 @@ def read_batch_features(file_pattern, shuffling but would increase memory usage and startup time. Returns: - A dict from keys in features to Tensor or SparseTensor objects. + A dict from keys in features to `Tensor` or `SparseTensor` objects. 
""" filenames = _get_file_names(file_pattern, randomize_input) if reader_args: @@ -174,32 +172,17 @@ def read_batch_features(file_pattern, else: dataset = reader(filenames) if dataset.output_types == (dtypes.string, dtypes.string): - dataset = dataset.map(lambda unused_k, v: v) - elif dataset.output_types != dtypes.string: - raise TypeError("`reader` must be a dataset of `tf.string` values, " - "or `(tf.string, tf.string)` key-value pairs.") + dataset = dataset.map(lambda _, v: v) if num_epochs != 1: dataset = dataset.repeat(num_epochs) if randomize_input: dataset = dataset.shuffle(capacity) dataset = dataset.batch(batch_size) - dataset = dataset.map(lambda x: _parse_example(x, features)) + dataset = dataset.map(lambda x: parsing_ops.parse_example(x, features)) + dataset = dataset.prefetch(1) iterator = dataset.make_one_shot_iterator() outputs = iterator.get_next() - index = 0 - result = {} - for key in sorted(features.keys()): - feature = features[key] - if isinstance(feature, parsing_ops.FixedLenFeature): - result[key] = outputs[index] - index += 1 - else: - result[key] = sparse_tensor_lib.SparseTensor( - indices=outputs[index], - values=outputs[index + 1], - dense_shape=outputs[index + 2]) - index += 3 - return result + return outputs def _get_file_names(file_pattern, randomize_input): @@ -233,18 +216,6 @@ def _get_file_names(file_pattern, randomize_input): return file_names -def _parse_example(serialized, features): - parsed = parsing_ops.parse_example(serialized, features) - result = [] - for key in sorted(features.keys()): - val = parsed[key] - if isinstance(val, sparse_tensor_lib.SparseTensor): - result.extend([val.indices, val.values, val.dense_shape]) - else: - result.append(val) - return tuple(result) - - class SqlDataset(contrib_dataset_ops.Dataset): def __init__(self, driver_name, data_source_name, query, output_types): @@ -299,6 +270,10 @@ class _SqlDataset(dataset_ops.Dataset): nest.flatten(self.output_types), nest.flatten(self.output_shapes)) + 
@property + def output_classes(self): + return nest.map_structure(lambda _: ops.Tensor, self._output_types) + @property def output_shapes(self): return nest.map_structure(lambda _: tensor_shape.TensorShape([]), diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 87bbbb7d19b15955b507308ce2ea286f602efd37..2744786e9eec4c9268ba854df6ea761339bb0b4e 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -19,11 +19,12 @@ from __future__ import print_function import collections -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops class _ScanDataset(dataset_ops.Dataset): @@ -43,6 +44,7 @@ class _ScanDataset(dataset_ops.Dataset): # Compute initial values for the state shapes and types based on # the initial state. These will be refined by running # `tf_scan_func` one or more times below. + # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor. self._state_shapes = nest.pack_sequence_as( self._initial_state, [t.shape for t in nest.flatten(self._initial_state)]) @@ -51,6 +53,7 @@ class _ScanDataset(dataset_ops.Dataset): [t.dtype for t in nest.flatten(self._initial_state)]) # Will be populated by calling `tf_scan_func`. 
+ self._output_classes = None self._output_shapes = None self._output_types = None @@ -65,14 +68,17 @@ class _ScanDataset(dataset_ops.Dataset): # Create a list in which `tf_scan_func` will store the s flat_new_state_shapes = [] - @function.Defun( - *(flat_state_types + nest.flatten(input_dataset.output_types))) + @function.Defun(*(flat_state_types + nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes)))) def tf_scan_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the state and input_dataset. - for arg, shape in zip( - args, - flat_state_shapes + nest.flatten(input_dataset.output_shapes)): + # TODO(b/69424092): Check that neither inputs nor outputs are sparse. + dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes) + for arg, shape in zip(args, + flat_state_shapes + nest.flatten(dense_shapes)): arg.set_shape(shape) pivot = len(flat_state_shapes) @@ -106,6 +112,8 @@ class _ScanDataset(dataset_ops.Dataset): "state. Expected %s; got %s." 
% (self._state_types, nest.pack_sequence_as( self._state_types, [t.dtype for t in flat_new_state]))) + self._output_classes = nest.pack_sequence_as( + output_value, [ops.Tensor for _ in flat_output_value]) self._output_types = nest.pack_sequence_as( output_value, [t.dtype for t in flat_output_value]) @@ -144,8 +152,14 @@ class _ScanDataset(dataset_ops.Dataset): nest.flatten(self._initial_state), self._scan_func.captured_inputs, f=self._scan_func, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + @property + def output_classes(self): + return self._output_classes @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..460732d65e4e652058ad821fbed45d365b4f41c1 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Experimental shuffle ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import random_ops +from tensorflow.python.data.ops import dataset_ops + + +def shuffle_and_repeat(buffer_size, count=None, seed=None): + """Shuffles and repeats a Dataset returning a new permutation for each epoch. + + `dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count))` + + is equivalent to + + `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)` + + The difference is that the latter dataset is not serializable. So, + if you need to checkpoint an input pipeline with reshuffling you must use + this implementation. + + Args: + buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the + maximum number of elements that will be buffered when shuffling. + count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + number of times the dataset should be repeated. The default behavior + (if `count` is `None` or `-1`) is for the dataset to be repeated + indefinitely. + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + random seed that will be used to create the distribution. See + @{tf.set_random_seed} for behavior. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}.
+ """ + def _apply_fn(dataset): # pylint: disable=missing-docstring + random_ds = random_ops.RandomDataset(seed).apply( + batching.batch_and_drop_remainder(2)) + if count is not None and count != -1: # compare by value, not identity + random_ds = random_ds.take(count) + + def map_fn(seeds): + return dataset_ops.ShuffleDataset( + input_dataset=dataset, + buffer_size=buffer_size, + seed=seeds[0], + reshuffle_each_iteration=False, + seed2=seeds[1]) + + return random_ds.flat_map(map_fn) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..b8875bd533ddc9e2c195646619dccf3aab5225e4 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -0,0 +1,177 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +"""Experimental API for gathering statistics from `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +class StatsAggregator(object): + """A stateful resource that aggregates statistics from one or more iterators. + + To record statistics, use one of the custom transformation functions defined + in this module when defining your @{tf.data.Dataset}. All statistics will be + aggregated by the `StatsAggregator` that is associated with a particular + iterator (see below). For example, to record the total number of bytes + produced by iterating over a dataset: + + ```python + dataset = ... + dataset = dataset.apply(stats_ops.bytes_produced_stats("total_bytes")) + ``` + + To associate a `StatsAggregator` with a @{tf.data.Iterator} object, use + the following pattern: + + ```python + dataset = ... + iterator = dataset.make_one_shot_iterator() + stats_aggregator = stats_ops.StatsAggregator() + set_op = stats_aggregator.subscribe(iterator) + + with tf.Session() as sess: + # Running `set_op` will associate `iterator` with `stats_aggregator`. + sess.run(set_op) + ``` + + To get a protocol buffer summary of the currently aggregated statistics, + use the `StatsAggregator.get_summary()` tensor. The easiest way to do this + is to add the returned tensor to the @{tf.GraphKeys.SUMMARIES} collection, + so that the summaries will be included with any existing summaries. 
+ + ```python + stats_aggregator = stats_ops.StatsAggregator() + stats_summary = stats_aggregator.get_summary() + tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary) + ``` + + Note: This interface is experimental and expected to change. In particular, + we expect to add other implementations of `StatsAggregator` that provide + different ways of exporting statistics, and add more types of statistics. + """ + + def __init__(self): + """Creates a `StatsAggregator`.""" + self._resource = gen_dataset_ops.stats_aggregator_handle() + + def get_summary(self): + """Returns a string @{tf.Tensor} that summarizes the aggregated statistics. + + The returned tensor will contain a serialized @{tf.summary.Summary} protocol + buffer, which can be used with the standard TensorBoard logging facilities. + + Returns: + A scalar string @{tf.Tensor} that summarizes the aggregated statistics. + """ + return gen_dataset_ops.stats_aggregator_summary(self._resource) + + def subscribe(self, iterator): + """Returns a @{tf.Operation} to associate this aggregator with `iterator`. + + Note: Each @{tf.data.Iterator} can be associated with at most one + `StatsAggregator`. After running the operation that this function + returns, all statistics recorded in the iteration of `iterator` + will be stored in this `StatsAggregator`. + + Args: + iterator: A @{tf.data.Iterator} object. + + Returns: + A @{tf.Operation} that, when run, associates this aggregator with + `iterator`. + """ + if not isinstance(iterator, iterator_ops.Iterator): + raise TypeError("`iterator` must be a `tf.data.Iterator` object.") + return gen_dataset_ops.iterator_set_stats_aggregator( + iterator._iterator_resource, self._resource) # pylint: disable=protected-access + + +def bytes_produced_stats(tag): + """Records the number of bytes produced by each element of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with an iterator + over the output dataset. + + Args: + tag: String. 
All statistics recorded by the returned transformation will + be associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.bytes_produced_stats_dataset, + tag) + + return _apply_fn + + +def latency_stats(tag): + """Records the latency of producing each element of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with an iterator + over the output dataset. + + Args: + tag: String. All statistics recorded by the returned transformation will + be associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.latency_stats_dataset, tag) + + return _apply_fn + + +class _StatsDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and also records statistics.""" + + def __init__(self, input_dataset, op_function, tag): + super(_StatsDataset, self).__init__() + self._input_dataset = input_dataset + self._op_function = op_function + self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string) + + def _as_variant_tensor(self): + return self._op_function( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._tag, + output_shapes=nest.flatten(self.output_shapes), + output_types=nest.flatten(self.output_types)) + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_classes(self): + return self._input_dataset.output_classes diff --git a/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig index 
d3d201afd5761e7c5c136301c779222bedc68492..cafb9314caee1c4907786b8101e7c71bd7095306 100644 --- a/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig +++ b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig @@ -2,7 +2,7 @@ %include "net/proto/swig/protofunc.swig" -#ifndef MUST_USE_RESULT +#ifndef ABSL_MUST_USE_RESULT #error Use this file only as a %include or %import after google.swig. #endif diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 145b9495ff40f8095b50d00e576333fdf5d7acdf..b2c641f8ab3ea23c5135042e4b1223d487ae8cbc 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -204,6 +204,24 @@ cuda_py_test( ], ) +cuda_py_test( + name = "half_normal_test", + size = "medium", + srcs = ["python/kernel_tests/half_normal_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "inverse_gamma_test", srcs = ["python/kernel_tests/inverse_gamma_test.py"], diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 0d12d838932e3a46e07f4a4242b889296c6e13c4..66827179e9fa1bea852f55246c263c4696cf3bdc 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -36,6 +36,7 @@ from tensorflow.contrib.distributions.python.ops.distribution_util import softpl from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag from tensorflow.contrib.distributions.python.ops.estimator import * from tensorflow.contrib.distributions.python.ops.geometric import * +from 
tensorflow.contrib.distributions.python.ops.half_normal import * from tensorflow.contrib.distributions.python.ops.independent import * from tensorflow.contrib.distributions.python.ops.inverse_gamma import * from tensorflow.contrib.distributions.python.ops.logistic import * @@ -107,6 +108,7 @@ _allowed_symbols = [ 'Gamma', 'GammaWithSoftplusConcentrationRate', 'Geometric', + 'HalfNormal', 'Independent', 'InverseGamma', 'InverseGammaWithSoftplusConcentrationRate', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py index 25a9b6f5fe2ed6d218d6b44650fce17fa89c0664..288d9d8dd6f17cd6348d3d72aea4408e26913ebd 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py @@ -22,9 +22,9 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import test_util from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import _gen_mask from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import masked_autoregressive_default_template from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import MaskedAutoregressiveFlow -from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import _gen_mask from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index 38b3a23c2d684a6f89b7c4be4a763c649bf4de15..49451446b56d290f130c5db90c13b94974d92dc9 100644 
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -28,8 +28,19 @@ from tensorflow.python.ops.distributions.bijector_test_util import assert_biject from tensorflow.python.platform import test -class ReshapeBijectorTest(test.TestCase): - """Tests correctness of the reshape transformation.""" +class _ReshapeBijectorTest(object): + """Base class for testing the reshape transformation. + + Methods defined in this class call a method self.build_shapes() that + is implemented by subclasses defined below, returning respectively + ReshapeBijectorTestStatic: static shapes, + ReshapeBijectorTestDynamic: shape placeholders of known ndims, and + ReshapeBijectorTestDynamicNdims: shape placeholders of unspecified ndims, + so that each test in this base class is automatically run over all + three cases. The subclasses also implement assertRaisesError to test + for either Python exceptions (in the case of static shapes) or + TensorFlow op errors (dynamic shapes). 
+ """ def setUp(self): self._rng = np.random.RandomState(42) @@ -40,9 +51,10 @@ class ReshapeBijectorTest(test.TestCase): expected_y = np.reshape(expected_x, [4, 6]) with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,]) bijector = Reshape( - event_shape_out=[6,], - event_shape_in=[3, 2], + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) (x_, y_, @@ -52,66 +64,23 @@ class ReshapeBijectorTest(test.TestCase): bijector.forward(expected_x), bijector.forward_log_det_jacobian(expected_x), bijector.inverse_log_det_jacobian(expected_y), - )) + ), feed_dict=feed_dict) self.assertEqual("reshape", bijector.name) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) self.assertAllClose(0., fldj_, rtol=1e-6, atol=0) self.assertAllClose(0., ildj_, rtol=1e-6, atol=0) - def testEventShapeDynamicNdims(self): - """Check forward/inverse shape methods with dynamic ndims.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_in_ph = array_ops.placeholder(dtype=dtypes.int32) - - shape_out = tensor_shape.TensorShape([2, 3]) - shape_out_ph = array_ops.placeholder(dtype=dtypes.int32) - - bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, validate_args=True) - - # using the _tensor methods, we should always get a fully-specified - # result since these are evaluated at graph runtime. 
- with self.test_session() as sess: - (shape_out_, - shape_in_) = sess.run(( - bijector.forward_event_shape_tensor(shape_in), - bijector.inverse_event_shape_tensor(shape_out), - ), feed_dict={ - shape_in_ph: shape_in, - shape_out_ph: shape_out, - }) - self.assertAllEqual(shape_out, shape_out_) - self.assertAllEqual(shape_in, shape_in_) - - def testEventShapeDynamic(self): - """Check shape methods with static ndims but dynamic shape.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_in_partial = tensor_shape.TensorShape([None,]) - shape_in_ph = array_ops.placeholder( - shape=[1,], dtype=dtypes.int32) - - shape_out = tensor_shape.TensorShape([2, 3]) - shape_out_partial = tensor_shape.TensorShape([None, None]) - shape_out_ph = array_ops.placeholder( - shape=[2,], dtype=dtypes.int32) + def testEventShapeTensor(self): + """Test event_shape_tensor methods when even ndims may be dynamic.""" + shape_in_static = [2, 3] + shape_out_static = [6,] + shape_in, shape_out, feed_dict = self.build_shapes(shape_in_static, + shape_out_static) bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, - validate_args=True) - - # if event shapes are not statically available, should - # return partially-specified TensorShapes. - self.assertAllEqual( - bijector.forward_event_shape(shape_in).as_list(), - shape_out_partial.as_list()) - self.assertAllEqual( - bijector.inverse_event_shape(shape_out).as_list(), - shape_in_partial.as_list()) + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) # using the _tensor methods, we should always get a fully-specified # result since these are evaluated at graph runtime. 
@@ -120,42 +89,9 @@ class ReshapeBijectorTest(test.TestCase): shape_in_) = sess.run(( bijector.forward_event_shape_tensor(shape_in), bijector.inverse_event_shape_tensor(shape_out), - ), feed_dict={ - shape_in_ph: shape_in, - shape_out_ph: shape_out, - }) - self.assertAllEqual(shape_out, shape_out_) - self.assertAllEqual(shape_in, shape_in_) - - def testEventShapeStatic(self): - """Check shape methods when shape is statically known.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_out = tensor_shape.TensorShape([2, 3]) - - bijector_static = Reshape( - event_shape_out=shape_out, - event_shape_in=shape_in, - validate_args=True) - - # test that forward_ and inverse_event_shape do sensible things - # when shapes are statically known. - self.assertEqual( - bijector_static.forward_event_shape(shape_in), - shape_out) - self.assertEqual( - bijector_static.inverse_event_shape(shape_out), - shape_in) - - with self.test_session() as sess: - (shape_out_static_, - shape_in_static_, - ) = sess.run(( - bijector_static.forward_event_shape_tensor(shape_in), - bijector_static.inverse_event_shape_tensor(shape_out), - )) - self.assertAllEqual(shape_out, shape_out_static_) - self.assertAllEqual(shape_in, shape_in_static_) + ), feed_dict=feed_dict) + self.assertAllEqual(shape_out_static, shape_out_) + self.assertAllEqual(shape_in_static, shape_in_) def testScalarReshape(self): """Test reshaping to and from a scalar shape ().""" @@ -166,11 +102,11 @@ class ReshapeBijectorTest(test.TestCase): expected_x_scalar = np.random.randn(1,) expected_y_scalar = expected_x_scalar[0] + shape_in, shape_out, feed_dict = self.build_shapes([], [1,]) with self.test_session() as sess: bijector = Reshape( - event_shape_out=[], - event_shape_in=[1,], validate_args=True) - + event_shape_out=shape_in, + event_shape_in=shape_out, validate_args=True) (x_, y_, x_scalar_, @@ -180,53 +116,178 @@ class ReshapeBijectorTest(test.TestCase): bijector.forward(expected_x), bijector.inverse(expected_y_scalar), 
bijector.forward(expected_x_scalar), - )) + ), feed_dict=feed_dict) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) self.assertAllClose(expected_y_scalar, y_scalar_, rtol=1e-6, atol=0) self.assertAllClose(expected_x_scalar, x_scalar_, rtol=1e-6, atol=0) - def testRaisesOpError(self): - x1 = np.random.randn(4, 2, 3) - x2 = np.random.randn(4, 3, 2) - x3 = np.random.randn(4, 5, 1, 1) + def testMultipleUnspecifiedDimensionsOpError(self): with self.test_session() as sess: - shape_in_ph = array_ops.placeholder(shape=[2,], dtype=dtypes.int32) - shape_out_ph = array_ops.placeholder(shape=[3,], dtype=dtypes.int32) + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [4, -1, -1,]) bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) - with self.assertRaisesOpError( + with self.assertRaisesError( + "elements must have at most one `-1`."): + sess.run(bijector.forward_event_shape_tensor(shape_in), + feed_dict=feed_dict) + + def testInvalidDimensionsOpError(self): + + with self.test_session() as sess: + + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 2, -2,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + with self.assertRaisesError( + "elements must be either positive integers or `-1`."): + sess.run(bijector.forward_event_shape_tensor(shape_in), + feed_dict=feed_dict) + + def testValidButNonMatchingInputOpError(self): + x = np.random.randn(4, 3, 2) + + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 6, 1,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + # Here we pass in a tensor (x) whose shape is compatible with + # the output shape, so tf.reshape will throw no error, but + # doesn't match the expected input shape. 
+ with self.assertRaisesError( "Input `event_shape` does not match `event_shape_in`."): - sess.run(bijector.forward(x2), - feed_dict={shape_out_ph: [1, 6, 1], - shape_in_ph: [2, 3]}) + sess.run(bijector.forward(x), + feed_dict=feed_dict) - with self.assertRaisesOpError( - "event_shape_out entries must be positive."): - sess.run(bijector.forward(x1), - feed_dict={shape_out_ph: [-1, -1, 6], - shape_in_ph: [2, 3]}) + def testValidButNonMatchingInputPartiallySpecifiedOpError(self): + x = np.random.randn(4, 3, 2) + + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([2, -1], [1, 6, 1,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + with self.assertRaisesError( + "Input `event_shape` does not match `event_shape_in`."): + sess.run(bijector.forward(x), + feed_dict=feed_dict) + + def testInputOutputMismatchOpError(self): + x1 = np.random.randn(4, 2, 3) + x2 = np.random.randn(4, 1, 1, 5) + + with self.test_session() as sess: + shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3], + [1, 1, 5]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) # test that *all* methods check basic assertions - fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]} - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): + with self.assertRaisesError( + "Input to reshape is a tensor with"): sess.run(bijector.forward(x1), feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.inverse(x3), feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.inverse_log_det_jacobian(x3), - feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.forward_log_det_jacobian(x1), - feed_dict=fd_mismatched) + with 
self.assertRaisesError( + "Input to reshape is a tensor with"): + sess.run(bijector.inverse(x2), feed_dict=fd_mismatched) + + def testOneShapePartiallySpecified(self): + expected_x = np.random.randn(4, 6) + expected_y = np.reshape(expected_x, [4, 2, 3]) + + with self.test_session() as sess: + # one of input/output shapes is partially specified + shape_in, shape_out, feed_dict = self.build_shapes([-1,], [2, 3]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def testBothShapesPartiallySpecified(self): + expected_x = np.random.randn(4, 2, 3) + expected_y = np.reshape(expected_x, [4, 3, 2]) + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([-1, 3], [-1, 2]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def testDefaultVectorShape(self): + expected_x = np.random.randn(4, 4) + expected_y = np.reshape(expected_x, [4, 2, 2]) + with self.test_session() as sess: + _, shape_out, feed_dict = self.build_shapes([-1,], [-1, 2]) + bijector = Reshape(shape_out, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def build_shapes(self, *args, **kwargs): + raise NotImplementedError("Subclass failed to implement `build_shapes`.") + + +class 
ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_static = shape_in + shape_out_static = shape_out + feed_dict = {} + return shape_in_static, shape_out_static, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesRegexp(Exception, msg) + + def testEventShape(self): + shape_in_static = tensor_shape.TensorShape([2, 3]) + shape_out_static = tensor_shape.TensorShape([6,]) + bijector = Reshape( + event_shape_out=shape_out_static, + event_shape_in=shape_in_static, validate_args=True) + + # test that forward_ and inverse_event_shape do sensible things + # when shapes are statically known. + self.assertEqual( + bijector.forward_event_shape(shape_in_static), + shape_out_static) + self.assertEqual( + bijector.inverse_event_shape(shape_out_static), + shape_in_static) def testBijectiveAndFinite(self): x = np.random.randn(4, 2, 3) @@ -238,5 +299,32 @@ class ReshapeBijectorTest(test.TestCase): validate_args=True) assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + +class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_ph = array_ops.placeholder(shape=(len(shape_in),), + dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=(len(shape_out),), + dtype=dtypes.int32) + feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out} + return shape_in_ph, shape_out_ph, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesOpError(msg) + + +class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32) + feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out} + return shape_in_ph, shape_out_ph, feed_dict + + def assertRaisesError(self, msg): + return 
self.assertRaisesOpError(msg) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py index 7f7697357ce7c77b2a50b87271d4ba7b49cbe05e..73747db31c86b67eaad5aeab7d5e80191e12b333 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py @@ -41,6 +41,7 @@ def try_import(name): # pylint: disable=invalid-name tf_logging.warning("Could not import %s: %s" % (name, str(e))) return module + stats = try_import("scipy.stats") @@ -62,9 +63,9 @@ class CauchyTest(test.TestCase): self.assertAllEqual(expected, scale_shape.eval()) loc = array_ops.zeros(loc_shape) scale = array_ops.ones(scale_shape) - self.assertAllEqual( - expected, - array_ops.shape(cauchy_lib.Cauchy(loc, scale).sample()).eval()) + self.assertAllEqual(expected, + array_ops.shape( + cauchy_lib.Cauchy(loc, scale).sample()).eval()) def _testParamStaticShapes(self, sample_shape, expected): param_shapes = cauchy_lib.Cauchy.param_static_shapes(sample_shape) @@ -92,8 +93,7 @@ class CauchyTest(test.TestCase): cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) log_pdf = cauchy.log_prob(x) - self.assertAllEqual(cauchy.batch_shape_tensor().eval(), - log_pdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.shape) self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.eval().shape) self.assertAllEqual(cauchy.batch_shape, log_pdf.shape) @@ -115,16 +115,15 @@ class CauchyTest(test.TestCase): with self.test_session(): batch_size = 6 loc = constant_op.constant([[3.0, -3.0]] * batch_size) - scale = constant_op.constant([[np.sqrt(10.0), np.sqrt(15.0)]] * - batch_size) + scale = constant_op.constant( + [[np.sqrt(10.0), np.sqrt(15.0)]] * batch_size) x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) log_pdf = 
cauchy.log_prob(x) log_pdf_values = log_pdf.eval() self.assertEqual(log_pdf.shape, (6, 2)) - self.assertAllEqual(cauchy.batch_shape_tensor().eval(), - log_pdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.shape) self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.eval().shape) self.assertAllEqual(cauchy.batch_shape, log_pdf.shape) @@ -248,8 +247,7 @@ class CauchyTest(test.TestCase): cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) entropy = cauchy.entropy() - self.assertAllEqual(cauchy.batch_shape_tensor().eval(), - entropy.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), entropy.shape) self.assertAllEqual(cauchy.batch_shape_tensor().eval(), entropy.eval().shape) self.assertAllEqual(cauchy.batch_shape, entropy.shape) @@ -257,7 +255,7 @@ class CauchyTest(test.TestCase): if not stats: return - expected_entropy = stats.cauchy(loc, scale).entropy() + expected_entropy = stats.cauchy(loc, scale[0]).entropy().reshape((1, 3)) self.assertAllClose(expected_entropy, entropy.eval()) def testCauchyMode(self): @@ -368,8 +366,8 @@ class CauchyTest(test.TestCase): self.assertAllEqual(expected_shape, samples.shape) self.assertAllEqual(expected_shape, sample_values.shape) - expected_shape = (tensor_shape.TensorShape( - [n.eval()]).concatenate(cauchy.batch_shape)) + expected_shape = ( + tensor_shape.TensorShape([n.eval()]).concatenate(cauchy.batch_shape)) self.assertAllEqual(expected_shape, samples.shape) self.assertAllEqual(expected_shape, sample_values.shape) @@ -385,18 +383,18 @@ class CauchyTest(test.TestCase): samples = cauchy.sample(n) sample_values = samples.eval() self.assertEqual(samples.shape, (100000, batch_size, 2)) - self.assertAllClose(np.median(sample_values[:, 0, 0]), - loc_v[0], atol=1e-1) - self.assertAllClose(np.median(sample_values[:, 0, 1]), - loc_v[1], atol=1e-1) + self.assertAllClose( + np.median(sample_values[:, 0, 0]), loc_v[0], atol=1e-1) + self.assertAllClose( + np.median(sample_values[:, 0, 1]), 
loc_v[1], atol=1e-1) expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( tensor_shape.TensorShape(cauchy.batch_shape_tensor().eval())) self.assertAllEqual(expected_shape, samples.shape) self.assertAllEqual(expected_shape, sample_values.shape) - expected_shape = (tensor_shape.TensorShape( - [n.eval()]).concatenate(cauchy.batch_shape)) + expected_shape = ( + tensor_shape.TensorShape([n.eval()]).concatenate(cauchy.batch_shape)) self.assertAllEqual(expected_shape, samples.shape) self.assertAllEqual(expected_shape, sample_values.shape) @@ -428,9 +426,12 @@ class CauchyTest(test.TestCase): self.assertEqual(cauchy.event_shape, ()) self.assertAllEqual(cauchy.event_shape_tensor().eval(), []) self.assertAllEqual( - sess.run(cauchy.batch_shape_tensor(), - feed_dict={loc: 5.0, - scale: [1.0, 2.0]}), [2]) + sess.run( + cauchy.batch_shape_tensor(), + feed_dict={ + loc: 5.0, + scale: [1.0, 2.0] + }), [2]) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e75660083dc2edd1759a3a54e221d9e8a268c3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py @@ -0,0 +1,320 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the HalfNormal distribution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import numpy as np + +from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + +stats = try_import("scipy.stats") + + +class HalfNormalTest(test.TestCase): + + def setUp(self): + self._rng = np.random.RandomState(123) + + def assertAllFinite(self, tensor): + is_finite = np.isfinite(tensor.eval()) + all_true = np.ones_like(is_finite, dtype=np.bool) + self.assertAllEqual(all_true, is_finite) + + def _testParamShapes(self, sample_shape, expected): + with self.test_session(): + param_shapes = hn_lib.HalfNormal.param_shapes(sample_shape) + scale_shape = param_shapes["scale"] + self.assertAllEqual(expected, scale_shape.eval()) + scale = array_ops.ones(scale_shape) + self.assertAllEqual( + expected, + array_ops.shape(hn_lib.HalfNormal(scale).sample()).eval()) + + def _testParamStaticShapes(self, sample_shape, expected): + param_shapes = hn_lib.HalfNormal.param_static_shapes(sample_shape) + scale_shape = param_shapes["scale"] + self.assertEqual(expected, scale_shape) + + def _testBatchShapes(self, dist, tensor): + 
self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.shape) + self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.eval().shape) + self.assertAllEqual(dist.batch_shape, tensor.shape) + self.assertAllEqual(dist.batch_shape, tensor.eval().shape) + + def testParamShapes(self): + sample_shape = [10, 3, 4] + self._testParamShapes(sample_shape, sample_shape) + self._testParamShapes(constant_op.constant(sample_shape), sample_shape) + + def testParamStaticShapes(self): + sample_shape = [10, 3, 4] + self._testParamStaticShapes(sample_shape, sample_shape) + self._testParamStaticShapes( + tensor_shape.TensorShape(sample_shape), sample_shape) + + def testHalfNormalLogPDF(self): + with self.test_session(): + batch_size = 6 + scale = constant_op.constant([3.0] * batch_size) + x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) + halfnorm = hn_lib.HalfNormal(scale=scale) + + log_pdf = halfnorm.log_prob(x) + self._testBatchShapes(halfnorm, log_pdf) + + pdf = halfnorm.prob(x) + self._testBatchShapes(halfnorm, pdf) + + if not stats: + return + expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) + + def testHalfNormalLogPDFMultidimensional(self): + with self.test_session(): + batch_size = 6 + scale = constant_op.constant([[3.0, 1.0]] * batch_size) + x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T + halfnorm = hn_lib.HalfNormal(scale=scale) + + log_pdf = halfnorm.log_prob(x) + self._testBatchShapes(halfnorm, log_pdf) + + pdf = halfnorm.prob(x) + self._testBatchShapes(halfnorm, pdf) + + if not stats: + return + expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) + + def testHalfNormalCDF(self): + with self.test_session(): + batch_size = 50 + scale = self._rng.rand(batch_size) + 1.0 + x 
= np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + halfnorm = hn_lib.HalfNormal(scale=scale) + + cdf = halfnorm.cdf(x) + self._testBatchShapes(halfnorm, cdf) + + log_cdf = halfnorm.log_cdf(x) + self._testBatchShapes(halfnorm, log_cdf) + + if not stats: + return + expected_logcdf = stats.halfnorm(scale=scale).logcdf(x) + self.assertAllClose(expected_logcdf, log_cdf.eval(), atol=0) + self.assertAllClose(np.exp(expected_logcdf), cdf.eval(), atol=0) + + def testHalfNormalSurvivalFunction(self): + with self.test_session(): + batch_size = 50 + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + halfnorm = hn_lib.HalfNormal(scale=scale) + + sf = halfnorm.survival_function(x) + self._testBatchShapes(halfnorm, sf) + + log_sf = halfnorm.log_survival_function(x) + self._testBatchShapes(halfnorm, log_sf) + + if not stats: + return + expected_logsf = stats.halfnorm(scale=scale).logsf(x) + self.assertAllClose(expected_logsf, log_sf.eval(), atol=0) + self.assertAllClose(np.exp(expected_logsf), sf.eval(), atol=0) + + def testHalfNormalQuantile(self): + with self.test_session(): + batch_size = 50 + scale = self._rng.rand(batch_size) + 1.0 + p = np.linspace(0., 1.0, batch_size).astype(np.float64) + + halfnorm = hn_lib.HalfNormal(scale=scale) + x = halfnorm.quantile(p) + self._testBatchShapes(halfnorm, x) + + if not stats: + return + expected_x = stats.halfnorm(scale=scale).ppf(p) + self.assertAllClose(expected_x, x.eval(), atol=0) + + def testFiniteGradients(self): + for dtype in [np.float32, np.float64]: + g = ops.Graph() + with g.as_default(): + scale = variables.Variable(dtype(3.0)) + dist = hn_lib.HalfNormal(scale=scale) + x = np.array([0.01, 0.1, 1., 5., 10.]).astype(dtype) + for func in [ + dist.cdf, dist.log_cdf, dist.survival_function, + dist.log_prob, dist.prob, dist.log_survival_function, + ]: + print(func.__name__) + value = func(x) + grads = gradients_impl.gradients(value, [scale]) + with 
self.test_session(graph=g): + variables.global_variables_initializer().run() + self.assertAllFinite(value) + self.assertAllFinite(grads[0]) + + def testHalfNormalEntropy(self): + with self.test_session(): + scale = np.array([[1.0, 2.0, 3.0]]) + halfnorm = hn_lib.HalfNormal(scale=scale) + + # See https://en.wikipedia.org/wiki/Half-normal_distribution for the + # entropy formula used here. + expected_entropy = 0.5 * np.log(np.pi * scale ** 2.0 / 2.0) + 0.5 + + entropy = halfnorm.entropy() + self._testBatchShapes(halfnorm, entropy) + self.assertAllClose(expected_entropy, entropy.eval()) + + def testHalfNormalMeanAndMode(self): + with self.test_session(): + scale = np.array([11., 12., 13.]) + + halfnorm = hn_lib.HalfNormal(scale=scale) + expected_mean = scale * np.sqrt(2.0) / np.sqrt(np.pi) + + self.assertAllEqual((3,), halfnorm.mean().eval().shape) + self.assertAllEqual(expected_mean, halfnorm.mean().eval()) + + self.assertAllEqual((3,), halfnorm.mode().eval().shape) + self.assertAllEqual([0., 0., 0.], halfnorm.mode().eval()) + + def testHalfNormalVariance(self): + with self.test_session(): + scale = np.array([7., 7., 7.]) + halfnorm = hn_lib.HalfNormal(scale=scale) + expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi) + + self.assertAllEqual((3,), halfnorm.variance().eval().shape) + self.assertAllEqual(expected_variance, halfnorm.variance().eval()) + + def testHalfNormalStandardDeviation(self): + with self.test_session(): + scale = np.array([7., 7., 7.]) + halfnorm = hn_lib.HalfNormal(scale=scale) + expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi) + + self.assertAllEqual((3,), halfnorm.stddev().shape) + self.assertAllEqual(np.sqrt(expected_variance), halfnorm.stddev().eval()) + + def testHalfNormalSample(self): + with self.test_session(): + scale = constant_op.constant(3.0) + n = constant_op.constant(100000) + halfnorm = hn_lib.HalfNormal(scale=scale) + + sample = halfnorm.sample(n) + + self.assertEqual(sample.eval().shape, (100000,)) + 
self.assertAllClose(sample.eval().mean(), + 3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1) + + expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( + tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval())) + self.assertAllEqual(expected_shape, sample.shape) + self.assertAllEqual(expected_shape, sample.eval().shape) + + expected_shape_static = (tensor_shape.TensorShape( + [n.eval()]).concatenate(halfnorm.batch_shape)) + self.assertAllEqual(expected_shape_static, sample.shape) + self.assertAllEqual(expected_shape_static, sample.eval().shape) + + def testHalfNormalSampleMultiDimensional(self): + with self.test_session(): + batch_size = 2 + scale = constant_op.constant([[2.0, 3.0]] * batch_size) + n = constant_op.constant(100000) + halfnorm = hn_lib.HalfNormal(scale=scale) + + sample = halfnorm.sample(n) + self.assertEqual(sample.shape, (100000, batch_size, 2)) + self.assertAllClose(sample.eval()[:, 0, 0].mean(), + 2.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1) + self.assertAllClose(sample.eval()[:, 0, 1].mean(), + 3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1) + + expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( + tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval())) + self.assertAllEqual(expected_shape, sample.shape) + self.assertAllEqual(expected_shape, sample.eval().shape) + + expected_shape_static = (tensor_shape.TensorShape( + [n.eval()]).concatenate(halfnorm.batch_shape)) + self.assertAllEqual(expected_shape_static, sample.shape) + self.assertAllEqual(expected_shape_static, sample.eval().shape) + + def testNegativeSigmaFails(self): + with self.test_session(): + halfnorm = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G") + with self.assertRaisesOpError("Condition x > 0 did not hold"): + halfnorm.mean().eval() + + def testHalfNormalShape(self): + with self.test_session(): + scale = constant_op.constant([6.0] * 5) + halfnorm = hn_lib.HalfNormal(scale=scale) + + 
self.assertEqual(halfnorm.batch_shape_tensor().eval(), [5]) + self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(halfnorm.event_shape_tensor().eval(), []) + self.assertEqual(halfnorm.event_shape, tensor_shape.TensorShape([])) + + def testHalfNormalShapeWithPlaceholders(self): + scale = array_ops.placeholder(dtype=dtypes.float32) + halfnorm = hn_lib.HalfNormal(scale=scale) + + with self.test_session() as sess: + # get_batch_shape should return an "" tensor. + self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape(None)) + self.assertEqual(halfnorm.event_shape, ()) + self.assertAllEqual(halfnorm.event_shape_tensor().eval(), []) + self.assertAllEqual( + sess.run(halfnorm.batch_shape_tensor(), + feed_dict={scale: [1.0, 2.0]}), [2]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py index ece6bc077d9e21502fdfd01300a9d3e9f2c9c380..ff6092fc260660b512e8123823c63e98a023af6d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py @@ -45,6 +45,17 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers, self.assertEqual([4, 5], x.shape) self.assertEqual([4, 5], log_prob_x.shape) + def testSampleAndLogProbBatch(self): + with self.test_session(): + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[[0.3, 0.7]]), + components_distribution=normal_lib.Normal( + loc=[[-1., 1]], scale=[[0.1, 0.5]])) + x = gm.sample([4, 5], seed=42) + log_prob_x = gm.log_prob(x) + self.assertEqual([4, 5, 1], x.shape) + self.assertEqual([4, 5, 1], log_prob_x.shape) + def testSampleAndLogProbShapesBroadcastMix(self): mix_probs = np.float32([.3, .7]) bern_probs = np.float32([[.4, .6], [.25, 
def testCastLogDetJacobian(self):
  """Exercise log_prob/prob/entropy when the Jacobian dtype is int32."""

  def _int32_zero(unused_x):
    # Constant log-det-Jacobian of zero, deliberately typed as int32 so it
    # does not match the floating-point log_prob of the base distribution.
    return math_ops.cast(0, dtypes.int32)

  with self.test_session():
    # Identity bijector whose forward and inverse Jacobians are both int32.
    identity_with_int_jacobian = bs.Inline(
        forward_fn=array_ops.identity,
        inverse_fn=array_ops.identity,
        inverse_log_det_jacobian_fn=_int32_zero,
        forward_log_det_jacobian_fn=_int32_zero,
        is_constant_jacobian=True)
    base = ds.Normal(loc=0., scale=1.)
    transformed = self._cls()(
        distribution=base,
        bijector=identity_with_int_jacobian,
        validate_args=True)

    draw = transformed.sample()
    # Evaluating each quantity is the test: no assertion on the values.
    transformed.log_prob(draw).eval()
    transformed.prob(draw).eval()
    transformed.entropy().eval()
class AbsoluteValue(bijector.Bijector):
  """Element-wise absolute value: `Y = g(X) = Abs(X)`.

  The map sends `(-inf, inf)` onto `[0, inf)` and is not injective, so the
  inverse is reported as a pair of branches rather than a single value.

  * For `y in (0, inf)`, `AbsoluteValue.inverse(y)` is the tuple `-y, y`,
    i.e. the set inverse `{x in (-inf, inf) : |x| = y}`.
  * `AbsoluteValue.inverse(0)` is `0, 0`. This is not the set inverse (that
    is the singleton `{0}`), but it "works" in conjunction with
    `TransformedDistribution` to produce a left semi-continuous pdf.
  * For `y < 0`, `AbsoluteValue.inverse(y)` still returns `-y, y`, which is
    wrong; the check is skipped for efficiency. With `validate_args == True`
    a negative `y` raises an exception instead.


  ```python
  tfd = tf.contrib.distributions

  abs = tfd.bijectors.AbsoluteValue()

  abs.forward([-1., 0., 1.])
  ==> [1., 0., 1.]

  abs.inverse(1.)
  ==> [-1., 1.]

  # The |dX/dY| is constant, == 1.  So Log|dX/dY| == 0.
  abs.inverse_log_det_jacobian(1.)
  ==> [0., 0.]

  # Special case handling of 0.
  abs.inverse(0.)
  ==> [0., 0.]

  abs.inverse_log_det_jacobian(0.)
  ==> [0., 0.]
  ```

  """

  def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"):
    """Instantiates the `AbsoluteValue` bijector.

    Args:
      event_ndims: Python scalar indicating the number of dimensions associated
        with a particular draw from the distribution.  Currently only zero is
        supported.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness, in particular whether inputs to `inverse` and
        `inverse_log_det_jacobian` are non-negative.
      name: Python `str` name given to ops managed by this object.

    Raises:
      ValueError:  If `event_ndims` is not zero.
    """
    # Set before the base-class constructor runs so that `_name_scope` below
    # can already be used.
    self._graph_parents = []
    self._name = name

    event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
    static_event_ndims = tensor_util.constant_value(event_ndims)
    if static_event_ndims is not None and static_event_ndims not in (0,):
      # Value is known statically and is wrong: fail at construction time.
      raise ValueError("event_ndims(%s) was not 0" % static_event_ndims)
    elif validate_args:
      # Otherwise attach a runtime check to the tensor itself.
      is_zero = check_ops.assert_equal(
          event_ndims, 0, message="event_ndims was not 0")
      event_ndims = control_flow_ops.with_dependencies([is_zero], event_ndims)

    with self._name_scope("init"):
      super(AbsoluteValue, self).__init__(
          event_ndims=event_ndims,
          validate_args=validate_args,
          name=name)

  def _forward(self, x):
    return math_ops.abs(x)

  def _inverse(self, y):
    if not self.validate_args:
      return -y, y
    non_negative = check_ops.assert_non_negative(
        y, message="Argument y was negative")
    y = control_flow_ops.with_dependencies([non_negative], y)
    return -y, y

  def _inverse_log_det_jacobian(self, y):
    # The two inverse branches are (-y, y) with derivatives (-1, 1), so
    # Log|DF^{-1}(y)| is zero on both branches.
    batch_ndims = array_ops.rank(y) - self.event_ndims
    batch_shape = array_ops.shape(y)[:batch_ndims]
    ildj = array_ops.zeros(batch_shape, dtype=y.dtype)
    if self.validate_args:
      non_negative = check_ops.assert_non_negative(
          y, message="Argument y was negative")
      ildj = control_flow_ops.with_dependencies([non_negative], ildj)
    return ildj, ildj

  @property
  def _is_injective(self):
    return False
-# ============================================================================== -"""AbsoluteValue bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "AbsoluteValue", -] - - -class AbsoluteValue(bijector.Bijector): - """Computes `Y = g(X) = Abs(X)`, element-wise. - - This non-injective bijector allows for transformations of scalar distributions - with the absolute value function, which maps `(-inf, inf)` to `[0, inf)`. - - * For `y in (0, inf)`, `AbsoluteValue.inverse(y)` returns the set inverse - `{x in (-inf, inf) : |x| = y}` as a tuple, `-y, y`. - * `AbsoluteValue.inverse(0)` returns `0, 0`, which is not the set inverse - (the set inverse is the singleton `{0}`), but "works" in conjunction with - `TransformedDistribution` to produce a left semi-continuous pdf. - * For `y < 0`, `AbsoluteValue.inverse(y)` happily returns the - wrong thing, `-y, y`. This is done for efficiency. If - `validate_args == True`, `y < 0` will raise an exception. - - - ```python - abs = ds.bijectors.AbsoluteValue() - - abs.forward([-1., 0., 1.]) - ==> [1., 0., 1.] - - abs.inverse(1.) - ==> [-1., 1.] - - # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0. - abs.inverse_log_det_jacobian(1.) - ==> [0., 0.] - - # Special case handling of 0. - abs.inverse(0.) - ==> [0., 0.] - - abs.inverse_log_det_jacobian(0.) - ==> [0., 0.] - ``` - - """ - - def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): - """Instantiates the `AbsoluteValue` bijector. 
- - Args: - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. Currently only zero is - supported. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness, in particular whether inputs to `inverse` and - `inverse_log_det_jacobian` are non-negative. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: If `event_ndims` is not zero. - """ - self._graph_parents = [] - self._name = name - - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0,): - raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) - else: - if validate_args: - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_equal( - event_ndims, 0, message="event_ndims was not 0")], - event_ndims) - - with self._name_scope("init"): - super(AbsoluteValue, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - return math_ops.abs(x) - - def _inverse(self, y): - if self.validate_args: - y = control_flow_ops.with_dependencies( - [check_ops.assert_non_negative(y, message="Argument y was negative")], - y) - return -y, y - - def _inverse_log_det_jacobian(self, y): - # If event_ndims = 2, - # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), - # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. 
class Affine(bijector.Bijector):
  """Compute `Y = g(X; shift, scale) = scale @ X + shift`.

  Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`.

  In TF parlance, the `scale` term is logically equivalent to:

  ```python
  scale = (
    scale_identity_multiplier * tf.diag(tf.ones(d)) +
    tf.diag(scale_diag) +
    scale_tril +
    scale_perturb_factor @ diag(scale_perturb_diag) @
      tf.transpose([scale_perturb_factor])
  )
  ```

  The `scale` term is applied without necessarily materializing constituent
  matrices, i.e., the matmul is [matrix-free](
  https://en.wikipedia.org/wiki/Matrix-free_methods) when possible.

  Examples:

  ```python
  # Y = X
  b = Affine()

  # Y = X + shift
  b = Affine(shift=[1., 2, 3])

  # Y = 2 * I @ X.T + shift
  b = Affine(shift=[1., 2, 3],
             scale_identity_multiplier=2.)

  # Y = tf.diag(d1) @ X.T + shift
  b = Affine(shift=[1., 2, 3],
             scale_diag=[-1., 2, 1])         # Implicitly 3x3.

  # Y = (I + v * v.T) @ X.T + shift
  b = Affine(shift=[1., 2, 3],
             scale_perturb_factor=[[1., 0],
                                   [0, 1],
                                   [1, 1]])

  # Y = (diag(d1) + v * diag(d2) * v.T) @ X.T + shift
  b = Affine(shift=[1., 2, 3],
             scale_diag=[1., 3, 3],          # Implicitly 3x3.
             scale_perturb_diag=[2., 1],     # Implicitly 2x2.
             scale_perturb_factor=[[1., 0],
                                   [0, 1],
                                   [1, 1]])

  ```

  """

  def __init__(self,
               shift=None,
               scale_identity_multiplier=None,
               scale_diag=None,
               scale_tril=None,
               scale_perturb_factor=None,
               scale_perturb_diag=None,
               event_ndims=1,
               validate_args=False,
               name="affine"):
    """Instantiates the `Affine` bijector.

    This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments,
    giving the forward operation:

    ```none
    Y = g(X) = scale @ X + shift
    ```

    where the `scale` term is logically equivalent to:

    ```python
    scale = (
      scale_identity_multiplier * tf.diag(tf.ones(d)) +
      tf.diag(scale_diag) +
      scale_tril +
      scale_perturb_factor @ diag(scale_perturb_diag) @
        tf.transpose([scale_perturb_factor])
    )
    ```

    If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are
    specified then `scale += IdentityMatrix`. Otherwise specifying a
    `scale` argument has the semantics of `scale += Expand(arg)`, i.e.,
    `scale_diag != None` means `scale += tf.diag(scale_diag)`.

    Args:
      shift: Floating-point `Tensor`. If this is set to `None`, no shift is
        applied.
      scale_identity_multiplier: floating point rank 0 `Tensor` representing a
        scaling done to the identity matrix.
        When `scale_identity_multiplier = scale_diag = scale_tril = None` then
        `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added
        to `scale`.
      scale_diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ...  k], which represents a k x k
        diagonal matrix.
        When `None` no diagonal term is added to `scale`.
      scale_tril: Floating-point `Tensor` representing the lower triangular
        matrix. `scale_tril` has shape [N1, N2, ...  k, k], which represents a
        k x k lower triangular matrix.
        When `None` no `scale_tril` term is added to `scale`.
        The upper triangular elements above the diagonal are ignored.
      scale_perturb_factor: Floating-point `Tensor` representing factor matrix
        with last two dimensions of shape `(k, r)`. When `None`, no rank-r
        update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing the diagonal
        matrix. `scale_perturb_diag` has shape [N1, N2, ...  r], which
        represents an `r x r` diagonal matrix. When `None` low rank updates will
        take the form `scale_perturb_factor * scale_perturb_factor.T`.
      event_ndims: Scalar `int` `Tensor` indicating the number of dimensions
        associated with a particular draw from the distribution. Must be 0 or 1.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      ValueError: if `perturb_diag` is specified but not `perturb_factor`.
      ValueError: if `event_ndims` is statically known and is not 0 or 1.
      ValueError: if `event_ndims` is statically 0 and any scale argument other
        than `scale_identity_multiplier` is given.
      TypeError: if `shift` has different `dtype` from `scale` arguments.
    """
    self._graph_parents = []
    self._name = name
    self._validate_args = validate_args

    # Ambiguous definition of low rank update.
    if scale_perturb_diag is not None and scale_perturb_factor is None:
      raise ValueError("When scale_perturb_diag is specified, "
                       "scale_perturb_factor must be specified.")

    # Special case, only handling a scaled identity matrix. We don't know its
    # dimensions, so this is special cased.
    # We don't check identity_multiplier, since below we set it to 1. if all
    # other scale args are None.
    self._is_only_identity_multiplier = (scale_tril is None and
                                         scale_diag is None and
                                         scale_perturb_factor is None)

    with self._name_scope("init", values=[
        shift, scale_identity_multiplier, scale_diag, scale_tril,
        scale_perturb_diag, scale_perturb_factor]):
      event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
      event_ndims_const = tensor_util.constant_value(event_ndims)
      if event_ndims_const is not None and event_ndims_const not in (0, 1):
        raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const)
      else:
        if validate_args:
          # Shape tool will catch if event_ndims is negative.
          event_ndims = control_flow_ops.with_dependencies(
              [check_ops.assert_less(
                  event_ndims, 2, message="event_ndims must be 0 or 1")],
              event_ndims)

      if event_ndims_const == 0 and not self._is_only_identity_multiplier:
        raise ValueError(
            "If event_ndims == 0, the only scale argument you can pass is "
            "scale_identity_multiplier.  All others operate on vectors.")

      # In the absence of `shift` and `scale`, we'll assume `dtype` is
      # `float32`.
      dtype = dtypes.float32

      if shift is not None:
        shift = ops.convert_to_tensor(shift, name="shift")
        dtype = shift.dtype.base_dtype
      self._shift = shift

      # When no args are specified, pretend the scale matrix is the identity
      # matrix.
      if (self._is_only_identity_multiplier and
          scale_identity_multiplier is None):
        scale_identity_multiplier = ops.convert_to_tensor(1., dtype=dtype)

      # self._create_scale_operator returns a LinearOperator in all cases
      # except if self._is_only_identity_multiplier; in which case it
      # returns a scalar Tensor.
      scale = self._create_scale_operator(
          identity_multiplier=scale_identity_multiplier,
          diag=scale_diag,
          tril=scale_tril,
          perturb_diag=scale_perturb_diag,
          perturb_factor=scale_perturb_factor,
          shift=shift,
          validate_args=validate_args)

      if scale.dtype is not None:
        dtype = scale.dtype.base_dtype

      if scale is not None and not self._is_only_identity_multiplier:
        if (shift is not None and
            shift.dtype.base_dtype != scale.dtype.base_dtype):
          raise TypeError(
              "shift.dtype({}) is incompatible with scale.dtype({}).".format(
                  shift.dtype, scale.dtype))

        if scale.tensor_rank is not None:
          batch_ndims = scale.tensor_rank - 2
        else:
          batch_ndims = scale.tensor_rank_tensor() - 2
      else:
        # We won't need shape inference when scale is None or when scale is a
        # scalar.
        batch_ndims = 0
      self._scale = scale
      self._shaper = _DistributionShape(
          batch_ndims=batch_ndims,
          event_ndims=event_ndims,
          validate_args=validate_args)

      # Build the parent list explicitly. Written as a single expression of
      # chained `a + b if c else d + e if f else []`, the conditional
      # expressions bind more loosely than `+`, which silently dropped `shift`
      # (tensor scale), `event_ndims` (operator scale), or every parent
      # (operator scale with no shift).
      graph_parents = [event_ndims]
      if tensor_util.is_tensor(self._scale):
        graph_parents.append(self._scale)
      else:
        graph_parents.extend(self._scale.graph_parents)
      if self._shift is not None:
        graph_parents.append(self._shift)

      super(Affine, self).__init__(
          event_ndims=event_ndims,
          graph_parents=graph_parents,
          is_constant_jacobian=True,
          dtype=dtype,
          validate_args=validate_args,
          name=name)

  def _create_scale_operator(self, identity_multiplier, diag, tril,
                             perturb_diag, perturb_factor, shift,
                             validate_args):
    """Construct `scale` from various components.

    Args:
      identity_multiplier: floating point rank 0 `Tensor` representing a scaling
        done to the identity matrix.
      diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ...  k], which represents a k x k
        diagonal matrix.
      tril: Floating-point `Tensor` representing the lower triangular matrix.
        `scale_tril` has shape [N1, N2, ...  k, k], which represents a k x k
        lower triangular matrix.
      perturb_diag: Floating-point `Tensor` representing the diagonal matrix of
        the low rank update.
      perturb_factor: Floating-point `Tensor` representing factor matrix.
      shift: Floating-point `Tensor` representing `shift` in
        `scale @ X + shift`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.

    Returns:
      scale. In the case of scaling by a constant, scale is a
      floating point `Tensor`. Otherwise, scale is a `LinearOperator`.

    Raises:
      ValueError: if all of `tril`, `diag` and `identity_multiplier` are `None`.
    """
    identity_multiplier = _as_tensor(identity_multiplier, "identity_multiplier")
    diag = _as_tensor(diag, "diag")
    tril = _as_tensor(tril, "tril")
    perturb_diag = _as_tensor(perturb_diag, "perturb_diag")
    perturb_factor = _as_tensor(perturb_factor, "perturb_factor")

    # If possible, use the low rank update to infer the shape of
    # the identity matrix, when scale represents a scaled identity matrix
    # with a low rank update.
    shape_hint = None
    if perturb_factor is not None:
      shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2)

    if self._is_only_identity_multiplier:
      if validate_args:
        # A zero multiplier would make the transformation non-invertible.
        return control_flow_ops.with_dependencies(
            [check_ops.assert_none_equal(
                identity_multiplier,
                array_ops.zeros([], identity_multiplier.dtype),
                ["identity_multiplier should be non-zero."])],
            identity_multiplier)
      return identity_multiplier

    scale = distribution_util.make_tril_scale(
        loc=shift,
        scale_tril=tril,
        scale_diag=diag,
        scale_identity_multiplier=identity_multiplier,
        validate_args=validate_args,
        assert_positive=False,
        shape_hint=shape_hint)

    if perturb_factor is not None:
      return linalg.LinearOperatorLowRankUpdate(
          scale,
          u=perturb_factor,
          diag_update=perturb_diag,
          is_diag_update_positive=perturb_diag is None,
          is_non_singular=True,  # Implied by is_positive_definite=True.
          is_self_adjoint=True,
          is_positive_definite=True,
          is_square=True)

    return scale

  @property
  def shift(self):
    """The `shift` `Tensor` in `Y = scale @ X + shift`."""
    return self._shift

  @property
  def scale(self):
    """The `scale` `LinearOperator` in `Y = scale @ X + shift`."""
    return self._scale

  def _forward(self, x):
    y = x
    if self._is_only_identity_multiplier:
      # Scalar scale: plain broadcasted multiply, no reshaping needed.
      y *= self._scale
      if self.shift is not None:
        return y + self.shift
      return y
    y, sample_shape = self._shaper.make_batch_of_event_sample_matrices(
        y, expand_batch_dim=False)
    with ops.control_dependencies(self._maybe_check_scale() if
                                  self.validate_args else []):
      y = self.scale.matmul(y)
    y = self._shaper.undo_make_batch_of_event_sample_matrices(
        y, sample_shape, expand_batch_dim=False)
    if self.shift is not None:
      y += self.shift
    return y

  def _inverse(self, y):
    x = y
    if self.shift is not None:
      x -= self.shift
    if self._is_only_identity_multiplier:
      return x / self._scale

    x, sample_shape = self._shaper.make_batch_of_event_sample_matrices(
        x, expand_batch_dim=False)
    # Solve fails if the op is singular so we may safely skip this assertion.
    x = self.scale.solve(x)
    x = self._shaper.undo_make_batch_of_event_sample_matrices(
        x, sample_shape, expand_batch_dim=False)
    return x

  def _inverse_log_det_jacobian(self, y):
    return -self._forward_log_det_jacobian(y)

  def _forward_log_det_jacobian(self, x):
    if self._is_only_identity_multiplier:
      # We don't pad in this case and instead let the fldj be applied
      # via broadcast.
      event_size = distribution_util.pick_vector(
          math_ops.equal(self._shaper.event_ndims, 0),
          [1], array_ops.shape(x))[-1]
      event_size = math_ops.cast(event_size, dtype=self._scale.dtype)
      return math_ops.log(math_ops.abs(self._scale)) * event_size
    return self.scale.log_abs_determinant()

  def _maybe_check_scale(self):
    """Returns `[assert_non_singular op]`, or `[]` if it is not implemented."""
    try:
      return [self.scale.assert_non_singular()]
    except NotImplementedError:
      pass
    return []
-# ============================================================================== -"""Affine bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import linalg -from tensorflow.contrib.distributions.python.ops import distribution_util -from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Affine", -] - - -def _as_tensor(x, name): - """Convenience to convert to `Tensor` or leave as `None`.""" - return None if x is None else ops.convert_to_tensor(x, name=name) - - -class Affine(bijector.Bijector): - """Compute `Y = g(X; shift, scale) = scale @ X + shift`. - - Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`. - - In TF parlance, the `scale` term is logically equivalent to: - - ```python - scale = ( - scale_identity_multiplier * tf.diag(tf.ones(d)) + - tf.diag(scale_diag) + - scale_tril + - scale_perturb_factor @ diag(scale_perturb_diag) @ - tf.transpose([scale_perturb_factor]) - ) - ``` - - The `scale` term is applied without necessarily materializing constituent - matrices, i.e., the matmul is [matrix-free]( - https://en.wikipedia.org/wiki/Matrix-free_methods) when possible. - - Examples: - - ```python - # Y = X - b = Affine() - - # Y = X + shift - b = Affine(shift=[1., 2, 3]) - - # Y = 2 * I @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_identity_multiplier=2.) - - # Y = tf.diag(d1) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_diag=[-1., 2, 1]) # Implicitly 3x3. 
- - # Y = (I + v * v.T) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_perturb_factor=[[1., 0], - [0, 1], - [1, 1]]) - - # Y = (diag(d1) + v * diag(d2) * v.T) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_diag=[1., 3, 3], # Implicitly 3x3. - scale_perturb_diag=[2., 1], # Implicitly 2x2. - scale_perturb_factor=[[1., 0], - [0, 1], - [1, 1]]) - - ``` - - """ - - def __init__(self, - shift=None, - scale_identity_multiplier=None, - scale_diag=None, - scale_tril=None, - scale_perturb_factor=None, - scale_perturb_diag=None, - event_ndims=1, - validate_args=False, - name="affine"): - """Instantiates the `Affine` bijector. - - This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments, - giving the forward operation: - - ```none - Y = g(X) = scale @ X + shift - ``` - - where the `scale` term is logically equivalent to: - - ```python - scale = ( - scale_identity_multiplier * tf.diag(tf.ones(d)) + - tf.diag(scale_diag) + - scale_tril + - scale_perturb_factor @ diag(scale_perturb_diag) @ - tf.transpose([scale_perturb_factor]) - ) - ``` - - If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are - specified then `scale += IdentityMatrix`. Otherwise specifying a - `scale` argument has the semantics of `scale += Expand(arg)`, i.e., - `scale_diag != None` means `scale += tf.diag(scale_diag)`. - - Args: - shift: Floating-point `Tensor`. If this is set to `None`, no shift is - applied. - scale_identity_multiplier: floating point rank 0 `Tensor` representing a - scaling done to the identity matrix. - When `scale_identity_multiplier = scale_diag = scale_tril = None` then - `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added - to `scale`. - scale_diag: Floating-point `Tensor` representing the diagonal matrix. - `scale_diag` has shape [N1, N2, ... k], which represents a k x k - diagonal matrix. - When `None` no diagonal term is added to `scale`. - scale_tril: Floating-point `Tensor` representing the diagonal matrix. 
- `scale_diag` has shape [N1, N2, ... k, k], which represents a k x k - lower triangular matrix. - When `None` no `scale_tril` term is added to `scale`. - The upper triangular elements above the diagonal are ignored. - scale_perturb_factor: Floating-point `Tensor` representing factor matrix - with last two dimensions of shape `(k, r)`. When `None`, no rank-r - update is added to `scale`. - scale_perturb_diag: Floating-point `Tensor` representing the diagonal - matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which - represents an `r x r` diagonal matrix. When `None` low rank updates will - take the form `scale_perturb_factor * scale_perturb_factor.T`. - event_ndims: Scalar `int` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `perturb_diag` is specified but not `perturb_factor`. - TypeError: if `shift` has different `dtype` from `scale` arguments. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - - # Ambiguous definition of low rank update. - if scale_perturb_diag is not None and scale_perturb_factor is None: - raise ValueError("When scale_perturb_diag is specified, " - "scale_perturb_factor must be specified.") - - # Special case, only handling a scaled identity matrix. We don't know its - # dimensions, so this is special cased. - # We don't check identity_multiplier, since below we set it to 1. if all - # other scale args are None. 
- self._is_only_identity_multiplier = (scale_tril is None and - scale_diag is None and - scale_perturb_factor is None) - - with self._name_scope("init", values=[ - shift, scale_identity_multiplier, scale_diag, scale_tril, - scale_perturb_diag, scale_perturb_factor]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0, 1): - raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - - if event_ndims_const == 0 and not self._is_only_identity_multiplier: - raise ValueError( - "If event_ndims == 0, the only scale argument you can pass is " - "scale_identity_multiplier. All others operate on vectors.") - - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. - dtype = dtypes.float32 - - if shift is not None: - shift = ops.convert_to_tensor(shift, name="shift") - dtype = shift.dtype.base_dtype - self._shift = shift - - # When no args are specified, pretend the scale matrix is the identity - # matrix. - if (self._is_only_identity_multiplier and - scale_identity_multiplier is None): - scale_identity_multiplier = ops.convert_to_tensor(1., dtype=dtype) - - # self._create_scale_operator returns a LinearOperator in all cases - # except if self._is_only_identity_multiplier; in which case it - # returns a scalar Tensor. 
- scale = self._create_scale_operator( - identity_multiplier=scale_identity_multiplier, - diag=scale_diag, - tril=scale_tril, - perturb_diag=scale_perturb_diag, - perturb_factor=scale_perturb_factor, - shift=shift, - validate_args=validate_args) - - if scale.dtype is not None: - dtype = scale.dtype.base_dtype - - if scale is not None and not self._is_only_identity_multiplier: - if (shift is not None and - shift.dtype.base_dtype != scale.dtype.base_dtype): - raise TypeError( - "shift.dtype({}) is incompatible with scale.dtype({}).".format( - shift.dtype, scale.dtype)) - - if scale.tensor_rank is not None: - batch_ndims = scale.tensor_rank - 2 - else: - batch_ndims = scale.tensor_rank_tensor() - 2 - else: - # We won't need shape inference when scale is None or when scale is a - # scalar. - batch_ndims = 0 - self._scale = scale - self._shaper = _DistributionShape( - batch_ndims=batch_ndims, - event_ndims=event_ndims, - validate_args=validate_args) - super(Affine, self).__init__( - event_ndims=event_ndims, - graph_parents=( - [event_ndims] + - [self._scale] if tensor_util.is_tensor(self._scale) - else self._scale.graph_parents + - [self._shift] if self._shift is not None else []), - is_constant_jacobian=True, - dtype=dtype, - validate_args=validate_args, - name=name) - - def _create_scale_operator(self, identity_multiplier, diag, tril, - perturb_diag, perturb_factor, shift, - validate_args): - """Construct `scale` from various components. - - Args: - identity_multiplier: floating point rank 0 `Tensor` representing a scaling - done to the identity matrix. - diag: Floating-point `Tensor` representing the diagonal matrix. - `scale_diag` has shape [N1, N2, ... k], which represents a k x k - diagonal matrix. - tril: Floating-point `Tensor` representing the diagonal matrix. - `scale_tril` has shape [N1, N2, ... k], which represents a k x k lower - triangular matrix. - perturb_diag: Floating-point `Tensor` representing the diagonal matrix of - the low rank update. 
- perturb_factor: Floating-point `Tensor` representing factor matrix. - shift: Floating-point `Tensor` representing `shift in `scale @ X + shift`. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - - Returns: - scale. In the case of scaling by a constant, scale is a - floating point `Tensor`. Otherwise, scale is a `LinearOperator`. - - Raises: - ValueError: if all of `tril`, `diag` and `identity_multiplier` are `None`. - """ - identity_multiplier = _as_tensor(identity_multiplier, "identity_multiplier") - diag = _as_tensor(diag, "diag") - tril = _as_tensor(tril, "tril") - perturb_diag = _as_tensor(perturb_diag, "perturb_diag") - perturb_factor = _as_tensor(perturb_factor, "perturb_factor") - - # If possible, use the low rank update to infer the shape of - # the identity matrix, when scale represents a scaled identity matrix - # with a low rank update. - shape_hint = None - if perturb_factor is not None: - shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2) - - if self._is_only_identity_multiplier: - if validate_args: - return control_flow_ops.with_dependencies( - [check_ops.assert_none_equal( - identity_multiplier, - array_ops.zeros([], identity_multiplier.dtype), - ["identity_multiplier should be non-zero."])], - identity_multiplier) - return identity_multiplier - - scale = distribution_util.make_tril_scale( - loc=shift, - scale_tril=tril, - scale_diag=diag, - scale_identity_multiplier=identity_multiplier, - validate_args=validate_args, - assert_positive=False, - shape_hint=shape_hint) - - if perturb_factor is not None: - return linalg.LinearOperatorLowRankUpdate( - scale, - u=perturb_factor, - diag_update=perturb_diag, - is_diag_update_positive=perturb_diag is None, - is_non_singular=True, # Implied by is_positive_definite=True. 
- is_self_adjoint=True, - is_positive_definite=True, - is_square=True) - - return scale - - @property - def shift(self): - """The `shift` `Tensor` in `Y = scale @ X + shift`.""" - return self._shift - - @property - def scale(self): - """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" - return self._scale - - def _forward(self, x): - y = x - if self._is_only_identity_multiplier: - y *= self._scale - if self.shift is not None: - return y + self.shift - return y - y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - y, expand_batch_dim=False) - with ops.control_dependencies(self._maybe_check_scale() if - self.validate_args else []): - y = self.scale.matmul(y) - y = self._shaper.undo_make_batch_of_event_sample_matrices( - y, sample_shape, expand_batch_dim=False) - if self.shift is not None: - y += self.shift - return y - - def _inverse(self, y): - x = y - if self.shift is not None: - x -= self.shift - if self._is_only_identity_multiplier: - return x / self._scale - - x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - x, expand_batch_dim=False) - # Solve fails if the op is singular so we may safely skip this assertion. - x = self.scale.solve(x) - x = self._shaper.undo_make_batch_of_event_sample_matrices( - x, sample_shape, expand_batch_dim=False) - return x - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - - def _forward_log_det_jacobian(self, x): - if self._is_only_identity_multiplier: - # We don't pad in this case and instead let the fldj be applied - # via broadcast. 
- event_size = distribution_util.pick_vector( - math_ops.equal(self._shaper.event_ndims, 0), - [1], array_ops.shape(x))[-1] - event_size = math_ops.cast(event_size, dtype=self._scale.dtype) - return math_ops.log(math_ops.abs(self._scale)) * event_size - return self.scale.log_abs_determinant() - - def _maybe_check_scale(self): - try: - return [self.scale.assert_non_singular()] - except NotImplementedError: - pass - return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py index aca04a89df7c3ee09d5f7cc10f6779e33fa7aa66..89043b1410370074f11f2cfa59b6b6663fa62521 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py @@ -18,12 +18,214 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.linalg import linear_operator -_allowed_symbols = ["AffineLinearOperator"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "AffineLinearOperator", +] + + +class AffineLinearOperator(bijector.Bijector): + """Compute `Y = g(X; shift, scale) = scale @ X + shift`. 
+ + `shift` is a numeric `Tensor` and `scale` is a `LinearOperator`. + + If `X` is a scalar then the forward transformation is: `scale * X + shift` + where `*` denotes the scalar product. + + Note: we don't always simply transpose `X` (but write it this way for + brevity). Actually the input `X` undergoes the following transformation + before being premultiplied by `scale`: + + 1. If there are no sample dims, we call `X = tf.expand_dims(X, 0)`, i.e., + `new_sample_shape = [1]`. Otherwise do nothing. + 2. The sample shape is flattened to have one dimension, i.e., + `new_sample_shape = [n]` where `n = tf.reduce_prod(old_sample_shape)`. + 3. The sample dim is cyclically rotated left by 1, i.e., + `new_shape = [B1,...,Bb, k, n]` where `n` is as above, `k` is the + event_shape, and `B1,...,Bb` are the batch shapes for each of `b` batch + dimensions. + + (For more details see `shape.make_batch_of_event_sample_matrices`.) + + The result of the above transformation is that `X` can be regarded as a batch + of matrices where each column is a draw from the distribution. After + premultiplying by `scale`, we take the inverse of this procedure. The input + `Y` also undergoes the same transformation before/after premultiplying by + `inv(scale)`. 
+ + Example Use: + + ```python + linalg = tf.linalg + + x = [1., 2, 3] + + shift = [-1., 0., 1] + diag = [1., 2, 3] + scale = linalg.LinearOperatorDiag(diag) + affine = AffineLinearOperator(shift, scale) + # In this case, `forward` is equivalent to: + # y = scale @ x + shift + y = affine.forward(x) # [0., 4, 10] + + shift = [2., 3, 1] + tril = [[1., 0, 0], + [2, 1, 0], + [3, 2, 1]] + scale = linalg.LinearOperatorLowerTriangular(tril) + affine = AffineLinearOperator(shift, scale) + # In this case, `forward` is equivalent to: + # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift + y = affine.forward(x) # [3., 7, 11] + ``` + + """ + + def __init__(self, + shift=None, + scale=None, + event_ndims=1, + validate_args=False, + name="affine_linear_operator"): + """Instantiates the `AffineLinearOperator` bijector. + + Args: + shift: Floating-point `Tensor`. + scale: Subclass of `LinearOperator`. Represents the (batch) positive + definite matrix `M` in `R^{k x k}`. + event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. Must be 0 or 1. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `event_ndims` is not 0 or 1. + TypeError: if `scale` is not a `LinearOperator`. + TypeError: if `shift.dtype` does not match `scale.dtype`. + ValueError: if not `scale.is_non_singular`. 
+ """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + graph_parents = [] + with self._name_scope("init", values=[shift]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + if tensor_util.constant_value(event_ndims) is not None: + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims not in (0, 1): + raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims)) + else: + if validate_args: + # Shape tool will catch if event_ndims is negative. + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_less( + event_ndims, 2, message="event_ndims must be 0 or 1")], + event_ndims) + graph_parents += [event_ndims] + + # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. + dtype = dtypes.float32 + + if shift is not None: + shift = ops.convert_to_tensor(shift, name="shift") + graph_parents += [shift] + dtype = shift.dtype.base_dtype + self._shift = shift + + if scale is not None: + if (shift is not None and + shift.dtype.base_dtype != scale.dtype.base_dtype): + raise TypeError( + "shift.dtype({}) is incompatible with scale.dtype({}).".format( + shift.dtype, scale.dtype)) + if not isinstance(scale, linear_operator.LinearOperator): + raise TypeError("scale is not an instance of tf.LinearOperator") + if validate_args and not scale.is_non_singular: + raise ValueError("Scale matrix must be non-singular.") + graph_parents += scale.graph_parents + if scale.tensor_rank is not None: + batch_ndims = scale.tensor_rank - 2 + else: + batch_ndims = scale.tensor_rank_tensor() - 2 + graph_parents += [batch_ndims] + if scale.dtype is not None: + dtype = scale.dtype.base_dtype + else: + batch_ndims = 0 # We won't need shape inference when scale is None. 
+ self._scale = scale + self._shaper = _DistributionShape( + batch_ndims=batch_ndims, + event_ndims=event_ndims, + validate_args=validate_args) + super(AffineLinearOperator, self).__init__( + event_ndims=event_ndims, + graph_parents=graph_parents, + is_constant_jacobian=True, + dtype=dtype, + validate_args=validate_args, + name=name) + + @property + def shift(self): + """The `shift` `Tensor` in `Y = scale @ X + shift`.""" + return self._shift + + @property + def scale(self): + """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" + return self._scale + + def _forward(self, x): + y = x + if self.scale is not None: + y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + y, expand_batch_dim=False) + with ops.control_dependencies(self._maybe_collect_assertions() if + self.validate_args else []): + y = self.scale.matmul(y) + y = self._shaper.undo_make_batch_of_event_sample_matrices( + y, sample_shape, expand_batch_dim=False) + if self.shift is not None: + y += self.shift + return y + + def _inverse(self, y): + x = y + if self.shift is not None: + x -= self.shift + if self.scale is not None: + x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + x, expand_batch_dim=False) + # Solve fails if the op is singular so we may safely skip this assertion. 
+ x = self.scale.solve(x) + x = self._shaper.undo_make_batch_of_event_sample_matrices( + x, sample_shape, expand_batch_dim=False) + return x + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(y) + + def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument + if self.scale is None: + return constant_op.constant(0, dtype=x.dtype.base_dtype) + with ops.control_dependencies(self._maybe_collect_assertions() if + self.validate_args else []): + return self.scale.log_abs_determinant() + + def _maybe_collect_assertions(self): + try: + return [self.scale.assert_non_singular()] + except NotImplementedError: + pass + return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py deleted file mode 100644 index 89043b1410370074f11f2cfa59b6b6663fa62521..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""AffineLinearOperator bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.linalg import linear_operator - - -__all__ = [ - "AffineLinearOperator", -] - - -class AffineLinearOperator(bijector.Bijector): - """Compute `Y = g(X; shift, scale) = scale @ X + shift`. - - `shift` is a numeric `Tensor` and `scale` is a `LinearOperator`. - - If `X` is a scalar then the forward transformation is: `scale * X + shift` - where `*` denotes the scalar product. - - Note: we don't always simply transpose `X` (but write it this way for - brevity). Actually the input `X` undergoes the following transformation - before being premultiplied by `scale`: - - 1. If there are no sample dims, we call `X = tf.expand_dims(X, 0)`, i.e., - `new_sample_shape = [1]`. Otherwise do nothing. - 2. The sample shape is flattened to have one dimension, i.e., - `new_sample_shape = [n]` where `n = tf.reduce_prod(old_sample_shape)`. - 3. The sample dim is cyclically rotated left by 1, i.e., - `new_shape = [B1,...,Bb, k, n]` where `n` is as above, `k` is the - event_shape, and `B1,...,Bb` are the batch shapes for each of `b` batch - dimensions. - - (For more details see `shape.make_batch_of_event_sample_matrices`.) - - The result of the above transformation is that `X` can be regarded as a batch - of matrices where each column is a draw from the distribution. 
After - premultiplying by `scale`, we take the inverse of this procedure. The input - `Y` also undergoes the same transformation before/after premultiplying by - `inv(scale)`. - - Example Use: - - ```python - linalg = tf.linalg - - x = [1., 2, 3] - - shift = [-1., 0., 1] - diag = [1., 2, 3] - scale = linalg.LinearOperatorDiag(diag) - affine = AffineLinearOperator(shift, scale) - # In this case, `forward` is equivalent to: - # y = scale @ x + shift - y = affine.forward(x) # [0., 4, 10] - - shift = [2., 3, 1] - tril = [[1., 0, 0], - [2, 1, 0], - [3, 2, 1]] - scale = linalg.LinearOperatorLowerTriangular(tril) - affine = AffineLinearOperator(shift, scale) - # In this case, `forward` is equivalent to: - # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift - y = affine.forward(x) # [3., 7, 11] - ``` - - """ - - def __init__(self, - shift=None, - scale=None, - event_ndims=1, - validate_args=False, - name="affine_linear_operator"): - """Instantiates the `AffineLinearOperator` bijector. - - Args: - shift: Floating-point `Tensor`. - scale: Subclass of `LinearOperator`. Represents the (batch) positive - definite matrix `M` in `R^{k x k}`. - event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `event_ndims` is not 0 or 1. - TypeError: if `scale` is not a `LinearOperator`. - TypeError: if `shift.dtype` does not match `scale.dtype`. - ValueError: if not `scale.is_non_singular`. 
- """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - graph_parents = [] - with self._name_scope("init", values=[shift]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - if tensor_util.constant_value(event_ndims) is not None: - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims not in (0, 1): - raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims)) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - graph_parents += [event_ndims] - - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. - dtype = dtypes.float32 - - if shift is not None: - shift = ops.convert_to_tensor(shift, name="shift") - graph_parents += [shift] - dtype = shift.dtype.base_dtype - self._shift = shift - - if scale is not None: - if (shift is not None and - shift.dtype.base_dtype != scale.dtype.base_dtype): - raise TypeError( - "shift.dtype({}) is incompatible with scale.dtype({}).".format( - shift.dtype, scale.dtype)) - if not isinstance(scale, linear_operator.LinearOperator): - raise TypeError("scale is not an instance of tf.LinearOperator") - if validate_args and not scale.is_non_singular: - raise ValueError("Scale matrix must be non-singular.") - graph_parents += scale.graph_parents - if scale.tensor_rank is not None: - batch_ndims = scale.tensor_rank - 2 - else: - batch_ndims = scale.tensor_rank_tensor() - 2 - graph_parents += [batch_ndims] - if scale.dtype is not None: - dtype = scale.dtype.base_dtype - else: - batch_ndims = 0 # We won't need shape inference when scale is None. 
- self._scale = scale - self._shaper = _DistributionShape( - batch_ndims=batch_ndims, - event_ndims=event_ndims, - validate_args=validate_args) - super(AffineLinearOperator, self).__init__( - event_ndims=event_ndims, - graph_parents=graph_parents, - is_constant_jacobian=True, - dtype=dtype, - validate_args=validate_args, - name=name) - - @property - def shift(self): - """The `shift` `Tensor` in `Y = scale @ X + shift`.""" - return self._shift - - @property - def scale(self): - """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" - return self._scale - - def _forward(self, x): - y = x - if self.scale is not None: - y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - y, expand_batch_dim=False) - with ops.control_dependencies(self._maybe_collect_assertions() if - self.validate_args else []): - y = self.scale.matmul(y) - y = self._shaper.undo_make_batch_of_event_sample_matrices( - y, sample_shape, expand_batch_dim=False) - if self.shift is not None: - y += self.shift - return y - - def _inverse(self, y): - x = y - if self.shift is not None: - x -= self.shift - if self.scale is not None: - x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - x, expand_batch_dim=False) - # Solve fails if the op is singular so we may safely skip this assertion. 
- x = self.scale.solve(x) - x = self._shaper.undo_make_batch_of_event_sample_matrices( - x, sample_shape, expand_batch_dim=False) - return x - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - - def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument - if self.scale is None: - return constant_op.constant(0, dtype=x.dtype.base_dtype) - with ops.control_dependencies(self._maybe_collect_assertions() if - self.validate_args else []): - return self.scale.log_abs_determinant() - - def _maybe_collect_assertions(self): - try: - return [self.scale.assert_non_singular()] - except NotImplementedError: - pass - return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 0db10fb75c8483a8209f39370362b05a03d047ca..3ce7c26213034c7345a20faa803c94a1bfa8d579 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -18,12 +18,151 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.chain_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import itertools -_allowed_symbols = ["Chain"] +from tensorflow.python.framework import constant_op +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Chain", +] + + +class Chain(bijector.Bijector): + """Bijector which applies a sequence of bijectors. 
+ + Example Use: + + ```python + chain = Chain([Exp(), Softplus()], name="one_plus_exp") + ``` + + Results in: + + * Forward: + + ```python + exp = Exp() + softplus = Softplus() + Chain([exp, softplus]).forward(x) + = exp.forward(softplus.forward(x)) + = tf.exp(tf.log(1. + tf.exp(x))) + = 1. + tf.exp(x) + ``` + + * Inverse: + + ```python + exp = Exp() + softplus = Softplus() + Chain([exp, softplus]).inverse(y) + = softplus.inverse(exp.inverse(y)) + = tf.log(tf.exp(tf.log(y)) - 1.) + = tf.log(y - 1.) + ``` + + """ + + def __init__(self, bijectors=None, validate_args=False, name=None): + """Instantiates `Chain` bijector. + + Args: + bijectors: Python `list` of bijector instances. An empty list makes this + bijector equivalent to the `Identity` bijector. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. Default: + E.g., `Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`. + + Raises: + ValueError: if bijectors have different dtypes. 
+ """ + if bijectors is None: + bijectors = () + self._bijectors = bijectors + + for a_bijector in bijectors: + if not a_bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijector ({})".format( + a_bijector.name)) + + dtype = list(set([b.dtype for b in bijectors])) + if len(dtype) > 2: + raise ValueError("incompatible dtypes: %s" % dtype) + elif len(dtype) == 2: + dtype = dtype[1] if dtype[0] is None else dtype[0] + event_ndims = bijectors[0].event_ndims + elif len(dtype) == 1: + dtype = dtype[0] + event_ndims = bijectors[0].event_ndims + else: + dtype = None + event_ndims = None + + super(Chain, self).__init__( + graph_parents=list(itertools.chain.from_iterable( + b.graph_parents for b in bijectors)), + is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), + validate_args=validate_args, + dtype=dtype, + event_ndims=event_ndims, + name=name or ("identity" if not bijectors else + "_of_".join(["chain"] + [b.name for b in bijectors]))) + + @property + def bijectors(self): + return self._bijectors + + def _shape_helper(self, func_name, input_shape, reverse): + new_shape = input_shape + for b in reversed(self.bijectors) if reverse else self.bijectors: + func = getattr(b, func_name, None) + if func is None: + raise ValueError("unable to call %s on bijector %s (%s)" % + (func_name, b.name, func)) + new_shape = func(new_shape) + return new_shape + + def _forward_event_shape(self, input_shape): + return self._shape_helper("forward_event_shape", input_shape, + reverse=True) + + def _forward_event_shape_tensor(self, input_shape): + return self._shape_helper( + "forward_event_shape_tensor", input_shape, reverse=True) + + def _inverse_event_shape(self, output_shape): + return self._shape_helper("inverse_event_shape", output_shape, + reverse=False) + + def _inverse_event_shape_tensor(self, output_shape): + return self._shape_helper("inverse_event_shape_tensor", output_shape, + 
reverse=False) + + def _inverse(self, y, **kwargs): + for b in self.bijectors: + y = b.inverse(y, **kwargs.get(b.name, {})) + return y + + def _inverse_log_det_jacobian(self, y, **kwargs): + ildj = constant_op.constant(0., dtype=y.dtype, + name="inverse_log_det_jacobian") + for b in self.bijectors: + ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) + y = b.inverse(y, **kwargs.get(b.name, {})) + return ildj + + def _forward(self, x, **kwargs): + for b in reversed(self.bijectors): + x = b.forward(x, **kwargs.get(b.name, {})) + return x + + def _forward_log_det_jacobian(self, x, **kwargs): + fldj = constant_op.constant(0., dtype=x.dtype, + name="forward_log_det_jacobian") + for b in reversed(self.bijectors): + fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) + x = b.forward(x, **kwargs.get(b.name, {})) + return fldj diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py deleted file mode 100644 index 3ce7c26213034c7345a20faa803c94a1bfa8d579..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Chain bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import itertools - -from tensorflow.python.framework import constant_op -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Chain", -] - - -class Chain(bijector.Bijector): - """Bijector which applies a sequence of bijectors. - - Example Use: - - ```python - chain = Chain([Exp(), Softplus()], name="one_plus_exp") - ``` - - Results in: - - * Forward: - - ```python - exp = Exp() - softplus = Softplus() - Chain([exp, softplus]).forward(x) - = exp.forward(softplus.forward(x)) - = tf.exp(tf.log(1. + tf.exp(x))) - = 1. + tf.exp(x) - ``` - - * Inverse: - - ```python - exp = Exp() - softplus = Softplus() - Chain([exp, softplus]).inverse(y) - = softplus.inverse(exp.inverse(y)) - = tf.log(tf.exp(tf.log(y)) - 1.) - = tf.log(y - 1.) - ``` - - """ - - def __init__(self, bijectors=None, validate_args=False, name=None): - """Instantiates `Chain` bijector. - - Args: - bijectors: Python `list` of bijector instances. An empty list makes this - bijector equivalent to the `Identity` bijector. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. Default: - E.g., `Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`. - - Raises: - ValueError: if bijectors have different dtypes. 
- """ - if bijectors is None: - bijectors = () - self._bijectors = bijectors - - for a_bijector in bijectors: - if not a_bijector._is_injective: # pylint: disable=protected-access - raise NotImplementedError( - "Invert is not implemented for non-injective bijector ({})".format( - a_bijector.name)) - - dtype = list(set([b.dtype for b in bijectors])) - if len(dtype) > 2: - raise ValueError("incompatible dtypes: %s" % dtype) - elif len(dtype) == 2: - dtype = dtype[1] if dtype[0] is None else dtype[0] - event_ndims = bijectors[0].event_ndims - elif len(dtype) == 1: - dtype = dtype[0] - event_ndims = bijectors[0].event_ndims - else: - dtype = None - event_ndims = None - - super(Chain, self).__init__( - graph_parents=list(itertools.chain.from_iterable( - b.graph_parents for b in bijectors)), - is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), - validate_args=validate_args, - dtype=dtype, - event_ndims=event_ndims, - name=name or ("identity" if not bijectors else - "_of_".join(["chain"] + [b.name for b in bijectors]))) - - @property - def bijectors(self): - return self._bijectors - - def _shape_helper(self, func_name, input_shape, reverse): - new_shape = input_shape - for b in reversed(self.bijectors) if reverse else self.bijectors: - func = getattr(b, func_name, None) - if func is None: - raise ValueError("unable to call %s on bijector %s (%s)" % - (func_name, b.name, func)) - new_shape = func(new_shape) - return new_shape - - def _forward_event_shape(self, input_shape): - return self._shape_helper("forward_event_shape", input_shape, - reverse=True) - - def _forward_event_shape_tensor(self, input_shape): - return self._shape_helper( - "forward_event_shape_tensor", input_shape, reverse=True) - - def _inverse_event_shape(self, output_shape): - return self._shape_helper("inverse_event_shape", output_shape, - reverse=False) - - def _inverse_event_shape_tensor(self, output_shape): - return self._shape_helper("inverse_event_shape_tensor", output_shape, - 
reverse=False) - - def _inverse(self, y, **kwargs): - for b in self.bijectors: - y = b.inverse(y, **kwargs.get(b.name, {})) - return y - - def _inverse_log_det_jacobian(self, y, **kwargs): - ildj = constant_op.constant(0., dtype=y.dtype, - name="inverse_log_det_jacobian") - for b in self.bijectors: - ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) - y = b.inverse(y, **kwargs.get(b.name, {})) - return ildj - - def _forward(self, x, **kwargs): - for b in reversed(self.bijectors): - x = b.forward(x, **kwargs.get(b.name, {})) - return x - - def _forward_log_det_jacobian(self, x, **kwargs): - fldj = constant_op.constant(0., dtype=x.dtype, - name="forward_log_det_jacobian") - for b in reversed(self.bijectors): - fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) - x = b.forward(x, **kwargs.get(b.name, {})) - return fldj diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 4686af8bc42a3232cb3a34f2cfcce8323c5896dd..cbd60f92a60612c6cf791b2c7708a3310c6e2b6b 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -18,12 +18,219 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["CholeskyOuterProduct"] +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from 
tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "CholeskyOuterProduct", +] + + +class CholeskyOuterProduct(bijector.Bijector): + """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. + + `event_ndims` must be 0 or 2, i.e., scalar or matrix. + + Note: the upper-triangular part of X is ignored (whether or not its zero). + + The surjectivity of g as a map from the set of n x n positive-diagonal + lower-triangular matrices to the set of SPD matrices follows immediately from + executing the Cholesky factorization algorithm on an SPD matrix A to produce a + positive-diagonal lower-triangular matrix L such that `A = L @ L.T`. + + To prove the injectivity of g, suppose that L_1 and L_2 are lower-triangular + with positive diagonals and satisfy `A = L_1 @ L_1.T = L_2 @ L_2.T`. Then + `inv(L_1) @ A @ inv(L_1).T = [inv(L_1) @ L_2] @ [inv(L_1) @ L_2].T = I`. + Setting `L_3 := inv(L_1) @ L_2`, that L_3 is a positive-diagonal + lower-triangular matrix follows from `inv(L_1)` being positive-diagonal + lower-triangular (which follows from the diagonal of a triangular matrix being + its spectrum), and that the product of two positive-diagonal lower-triangular + matrices is another positive-diagonal lower-triangular matrix. + + A simple inductive argument (proceding one column of L_3 at a time) shows + that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- + diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. + + Examples: + + ```python + bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 2], [2, 5]], i.e., x @ x.T + + bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]]) + # Result: [[1., 0], [2, 1]], i.e., cholesky(y). 
+ ``` + + """ + + def __init__(self, event_ndims=2, validate_args=False, + name="cholesky_outer_product"): + """Instantiates the `CholeskyOuterProduct` bijector. + + Args: + event_ndims: `constant` `int32` scalar `Tensor` indicating the number of + dimensions associated with a particular draw from the distribution. Must + be 0 or 2. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if event_ndims is neither 0 or 2. + """ + self._graph_parents = [] + self._name = name + with self._name_scope("init", values=[event_ndims]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims is None or event_ndims not in [0, 2]: + raise ValueError("`event_ndims` must be a TF constant which is 0 or 2") + self._static_event_ndims = event_ndims + super(CholeskyOuterProduct, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + if self._static_event_ndims == 0: + return math_ops.square(x) + if self.validate_args: + is_matrix = check_ops.assert_rank_at_least(x, 2) + shape = array_ops.shape(x) + is_square = check_ops.assert_equal(shape[-2], shape[-1]) + x = control_flow_ops.with_dependencies([is_matrix, is_square], x) + # For safety, explicitly zero-out the upper triangular part. + x = array_ops.matrix_band_part(x, -1, 0) + return math_ops.matmul(x, x, adjoint_b=True) + + def _inverse(self, y): + return (math_ops.sqrt(y) if self._static_event_ndims == 0 + else linalg_ops.cholesky(y)) + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(x=self._inverse(y)) + + def _forward_log_det_jacobian(self, x): + # Let Y be a symmetric, positive definite matrix and write: + # Y = X X.T + # where X is lower-triangular. 
+ # + # Observe that, + # dY[i,j]/dX[a,b] + # = d/dX[a,b] { X[i,:] X[j,:] } + # = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] } + # + # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is + # symmetric and X is lower-triangular, we need vectors of dimension: + # d = p (p + 1) / 2 + # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e., + # k = { i (i + 1) / 2 + j i>=j + # { undef ij thus i,j!=a. + # + # Since the Jacobian is lower-triangular, we need only compute the product + # of diagonal elements: + # d vec[Y] / d vec[X] @[k(i,j), k(i,j)] + # = X[j,j] + I[i=j] X[i,j] + # = 2 X[j,j]. + # Since there is a 2 X[j,j] term for every lower-triangular element of X we + # conclude: + # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. + if self._static_event_ndims == 0: + if self.validate_args: + is_positive = check_ops.assert_positive( + x, message="All elements must be positive.") + x = control_flow_ops.with_dependencies([is_positive], x) + return np.log(2.) + math_ops.log(x) + + diag = array_ops.matrix_diag_part(x) + + # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output + # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the + # output is unchanged. + diag = self._make_columnar(diag) + + if self.validate_args: + is_matrix = check_ops.assert_rank_at_least( + x, 2, message="Input must be a (batch of) matrix.") + shape = array_ops.shape(x) + is_square = check_ops.assert_equal( + shape[-2], shape[-1], + message="Input must be a (batch of) square matrix.") + # Assuming lower-triangular means we only need check diag>0. + is_positive_definite = check_ops.assert_positive( + diag, message="Input must be positive definite.") + x = control_flow_ops.with_dependencies( + [is_matrix, is_square, is_positive_definite], x) + + # Create a vector equal to: [p, p-1, ..., 2, 1]. 
+ if x.get_shape().ndims is None or x.get_shape()[-1].value is None: + p_int = array_ops.shape(x)[-1] + p_float = math_ops.cast(p_int, dtype=x.dtype) + else: + p_int = x.get_shape()[-1].value + p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype) + exponents = math_ops.linspace(p_float, 1., p_int) + + sum_weighted_log_diag = array_ops.squeeze( + math_ops.matmul(math_ops.log(diag), + exponents[..., array_ops.newaxis]), + squeeze_dims=-1) + fldj = p_float * np.log(2.) + sum_weighted_log_diag + + return fldj + + def _make_columnar(self, x): + """Ensures non-scalar input has at least one column. + + Example: + If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`. + + If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged. + + If `x = 1` then the output is unchanged. + + Args: + x: `Tensor`. + + Returns: + columnar_x: `Tensor` with at least two dimensions. + """ + if x.get_shape().ndims is not None: + if x.get_shape().ndims == 1: + x = x[array_ops.newaxis, :] + return x + shape = array_ops.shape(x) + maybe_expanded_shape = array_ops.concat([ + shape[:-1], + distribution_util.pick_vector( + math_ops.equal(array_ops.rank(x), 1), + [1], np.array([], dtype=np.int32)), + shape[-1:], + ], 0) + return array_ops.reshape(x, maybe_expanded_shape) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py deleted file mode 100644 index cbd60f92a60612c6cf791b2c7708a3310c6e2b6b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""CholeskyOuterProduct bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = [ - "CholeskyOuterProduct", -] - - -class CholeskyOuterProduct(bijector.Bijector): - """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. - - `event_ndims` must be 0 or 2, i.e., scalar or matrix. - - Note: the upper-triangular part of X is ignored (whether or not its zero). - - The surjectivity of g as a map from the set of n x n positive-diagonal - lower-triangular matrices to the set of SPD matrices follows immediately from - executing the Cholesky factorization algorithm on an SPD matrix A to produce a - positive-diagonal lower-triangular matrix L such that `A = L @ L.T`. - - To prove the injectivity of g, suppose that L_1 and L_2 are lower-triangular - with positive diagonals and satisfy `A = L_1 @ L_1.T = L_2 @ L_2.T`. Then - `inv(L_1) @ A @ inv(L_1).T = [inv(L_1) @ L_2] @ [inv(L_1) @ L_2].T = I`. 
- Setting `L_3 := inv(L_1) @ L_2`, that L_3 is a positive-diagonal - lower-triangular matrix follows from `inv(L_1)` being positive-diagonal - lower-triangular (which follows from the diagonal of a triangular matrix being - its spectrum), and that the product of two positive-diagonal lower-triangular - matrices is another positive-diagonal lower-triangular matrix. - - A simple inductive argument (proceding one column of L_3 at a time) shows - that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- - diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. - - Examples: - - ```python - bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]]) - # Result: [[1., 2], [2, 5]], i.e., x @ x.T - - bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]]) - # Result: [[1., 0], [2, 1]], i.e., cholesky(y). - ``` - - """ - - def __init__(self, event_ndims=2, validate_args=False, - name="cholesky_outer_product"): - """Instantiates the `CholeskyOuterProduct` bijector. - - Args: - event_ndims: `constant` `int32` scalar `Tensor` indicating the number of - dimensions associated with a particular draw from the distribution. Must - be 0 or 2. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if event_ndims is neither 0 or 2. 
- """ - self._graph_parents = [] - self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 2]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 2") - self._static_event_ndims = event_ndims - super(CholeskyOuterProduct, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - if self._static_event_ndims == 0: - return math_ops.square(x) - if self.validate_args: - is_matrix = check_ops.assert_rank_at_least(x, 2) - shape = array_ops.shape(x) - is_square = check_ops.assert_equal(shape[-2], shape[-1]) - x = control_flow_ops.with_dependencies([is_matrix, is_square], x) - # For safety, explicitly zero-out the upper triangular part. - x = array_ops.matrix_band_part(x, -1, 0) - return math_ops.matmul(x, x, adjoint_b=True) - - def _inverse(self, y): - return (math_ops.sqrt(y) if self._static_event_ndims == 0 - else linalg_ops.cholesky(y)) - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(x=self._inverse(y)) - - def _forward_log_det_jacobian(self, x): - # Let Y be a symmetric, positive definite matrix and write: - # Y = X X.T - # where X is lower-triangular. - # - # Observe that, - # dY[i,j]/dX[a,b] - # = d/dX[a,b] { X[i,:] X[j,:] } - # = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] } - # - # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is - # symmetric and X is lower-triangular, we need vectors of dimension: - # d = p (p + 1) / 2 - # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e., - # k = { i (i + 1) / 2 + j i>=j - # { undef ij thus i,j!=a. 
- # - # Since the Jacobian is lower-triangular, we need only compute the product - # of diagonal elements: - # d vec[Y] / d vec[X] @[k(i,j), k(i,j)] - # = X[j,j] + I[i=j] X[i,j] - # = 2 X[j,j]. - # Since there is a 2 X[j,j] term for every lower-triangular element of X we - # conclude: - # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. - if self._static_event_ndims == 0: - if self.validate_args: - is_positive = check_ops.assert_positive( - x, message="All elements must be positive.") - x = control_flow_ops.with_dependencies([is_positive], x) - return np.log(2.) + math_ops.log(x) - - diag = array_ops.matrix_diag_part(x) - - # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output - # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the - # output is unchanged. - diag = self._make_columnar(diag) - - if self.validate_args: - is_matrix = check_ops.assert_rank_at_least( - x, 2, message="Input must be a (batch of) matrix.") - shape = array_ops.shape(x) - is_square = check_ops.assert_equal( - shape[-2], shape[-1], - message="Input must be a (batch of) square matrix.") - # Assuming lower-triangular means we only need check diag>0. - is_positive_definite = check_ops.assert_positive( - diag, message="Input must be positive definite.") - x = control_flow_ops.with_dependencies( - [is_matrix, is_square, is_positive_definite], x) - - # Create a vector equal to: [p, p-1, ..., 2, 1]. - if x.get_shape().ndims is None or x.get_shape()[-1].value is None: - p_int = array_ops.shape(x)[-1] - p_float = math_ops.cast(p_int, dtype=x.dtype) - else: - p_int = x.get_shape()[-1].value - p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype) - exponents = math_ops.linspace(p_float, 1., p_int) - - sum_weighted_log_diag = array_ops.squeeze( - math_ops.matmul(math_ops.log(diag), - exponents[..., array_ops.newaxis]), - squeeze_dims=-1) - fldj = p_float * np.log(2.) 
+ sum_weighted_log_diag - - return fldj - - def _make_columnar(self, x): - """Ensures non-scalar input has at least one column. - - Example: - If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`. - - If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged. - - If `x = 1` then the output is unchanged. - - Args: - x: `Tensor`. - - Returns: - columnar_x: `Tensor` with at least two dimensions. - """ - if x.get_shape().ndims is not None: - if x.get_shape().ndims == 1: - x = x[array_ops.newaxis, :] - return x - shape = array_ops.shape(x) - maybe_expanded_shape = array_ops.concat([ - shape[:-1], - distribution_util.pick_vector( - math_ops.equal(array_ops.rank(x), 1), - [1], np.array([], dtype=np.int32)), - shape[-1:], - ], 0) - return array_ops.reshape(x, maybe_expanded_shape) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py index d254b635d28099a09a2054536f04ffee3a355b2f..ccb1f029277bc07011df7be047a075274f2b3a27 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py @@ -18,12 +18,38 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -_allowed_symbols = ["ConditionalBijector"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = ["ConditionalBijector"] + + +class ConditionalBijector(bijector.Bijector): + """Conditional Bijector is a Bijector that allows intrinsic conditioning.""" 
+ + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def forward(self, x, name="forward", **condition_kwargs): + return self._call_forward(x, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def inverse(self, y, name="inverse", **condition_kwargs): + return self._call_inverse(y, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def inverse_log_det_jacobian( + self, y, name="inverse_log_det_jacobian", **condition_kwargs): + return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def forward_log_det_jacobian( + self, x, name="forward_log_det_jacobian", **condition_kwargs): + return self._call_forward_log_det_jacobian(x, name, **condition_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py deleted file mode 100644 index ccb1f029277bc07011df7be047a075274f2b3a27..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""ConditionalBijector base.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = ["ConditionalBijector"] - - -class ConditionalBijector(bijector.Bijector): - """Conditional Bijector is a Bijector that allows intrinsic conditioning.""" - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def forward(self, x, name="forward", **condition_kwargs): - return self._call_forward(x, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def inverse(self, y, name="inverse", **condition_kwargs): - return self._call_inverse(y, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def inverse_log_det_jacobian( - self, y, name="inverse_log_det_jacobian", **condition_kwargs): - return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def forward_log_det_jacobian( - self, x, name="forward_log_det_jacobian", **condition_kwargs): 
- return self._call_forward_log_det_jacobian(x, name, **condition_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py index 399d713098eb7223601beb9518dc51dd6160ad64..b1ff840d62a73c941a4d67dec73b5c9f4d5353f9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py @@ -18,12 +18,49 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.exp_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.bijectors import power_transform -_allowed_symbols = ["Exp"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Exp", +] + + +class Exp(power_transform.PowerTransform): + """Compute `Y = g(X) = exp(X)`. + + Example Use: + + ```python + # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 + # batch ndim and 2 event ndims (i.e., vector of matrices). + exp = Exp(event_ndims=2) + x = [[[1., 2], + [3, 4]], + [[5, 6], + [7, 8]]] + exp(x) == exp.forward(x) + log(x) == exp.inverse(x) + ``` + + Note: the exp(.) is applied element-wise but the Jacobian is a reduction + over the event space. + """ + + def __init__(self, + event_ndims=0, + validate_args=False, + name="exp"): + """Instantiates the `Exp` bijector. + + Args: + event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. 
+ """ + super(Exp, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py deleted file mode 100644 index b1ff840d62a73c941a4d67dec73b5c9f4d5353f9..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Exp bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.bijectors import power_transform - - -__all__ = [ - "Exp", -] - - -class Exp(power_transform.PowerTransform): - """Compute `Y = g(X) = exp(X)`. - - Example Use: - - ```python - # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 - # batch ndim and 2 event ndims (i.e., vector of matrices). - exp = Exp(event_ndims=2) - x = [[[1., 2], - [3, 4]], - [[5, 6], - [7, 8]]] - exp(x) == exp.forward(x) - log(x) == exp.inverse(x) - ``` - - Note: the exp(.) is applied element-wise but the Jacobian is a reduction - over the event space. 
- """ - - def __init__(self, - event_ndims=0, - validate_args=False, - name="exp"): - """Instantiates the `Exp` bijector. - - Args: - event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - super(Exp, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py index cf37aa51115ed98ab263bc03bcb297a03432a7ae..67f39785563255be0fe154aca3cbcf01c6a01e73 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -18,12 +18,107 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.gumbel_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Gumbel"] +__all__ = [ + "Gumbel", +] -remove_undocumented(__name__, _allowed_symbols) + +class Gumbel(bijector.Bijector): + """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + + This bijector maps inputs from `[-inf, inf]` to [0, 1]`. 
The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the + [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution): + + ```none + Y ~ Gumbel(loc, scale) + pdf(y; loc, scale) = exp( + -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale + ``` + """ + + def __init__(self, + loc=0., + scale=1., + event_ndims=0, + validate_args=False, + name="gumbel"): + """Instantiates the `Gumbel` bijector. + + Args: + loc: Float-like `Tensor` that is the same dtype and is + broadcastable with `scale`. + This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + scale: Positive Float-like `Tensor` that is the same dtype and is + broadcastable with `loc`. + This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. 
+ """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[loc, scale]): + self._loc = ops.convert_to_tensor(loc, name="loc") + self._scale = ops.convert_to_tensor(scale, name="scale") + check_ops.assert_same_float_dtype([self._loc, self._scale]) + if validate_args: + self._scale = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._scale, message="Argument scale was not positive") + ], self._scale) + + super(Gumbel, self).__init__( + event_ndims=event_ndims, validate_args=validate_args, name=name) + + @property + def loc(self): + """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._loc + + @property + def scale(self): + """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._scale + + def _forward(self, x): + z = (x - self.loc) / self.scale + return math_ops.exp(-math_ops.exp(-z)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return self.loc - self.scale * math_ops.log(-math_ops.log(y)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) + + def _forward_log_det_jacobian(self, x): + event_dims = self._event_dims_tensor(x) + z = (x - self.loc) / self.scale + return math_ops.reduce_sum( + -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_positive = check_ops.assert_non_negative( + y, message="Inverse transformation input must be greater than 0.") + less_than_one = check_ops.assert_less_equal( + y, + constant_op.constant(1., y.dtype), + message="Inverse transformation input must be less than or equal to 1.") + return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git 
a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py deleted file mode 100644 index 67f39785563255be0fe154aca3cbcf01c6a01e73..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Gumbel bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "Gumbel", -] - - -class Gumbel(bijector.Bijector): - """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - - This bijector maps inputs from `[-inf, inf]` to [0, 1]`. 
The inverse of the - bijector applied to a uniform random variable `X ~ U(0, 1) gives back a - random variable with the - [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution): - - ```none - Y ~ Gumbel(loc, scale) - pdf(y; loc, scale) = exp( - -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale - ``` - """ - - def __init__(self, - loc=0., - scale=1., - event_ndims=0, - validate_args=False, - name="gumbel"): - """Instantiates the `Gumbel` bijector. - - Args: - loc: Float-like `Tensor` that is the same dtype and is - broadcastable with `scale`. - This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - scale: Positive Float-like `Tensor` that is the same dtype and is - broadcastable with `loc`. - This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. 
- """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[loc, scale]): - self._loc = ops.convert_to_tensor(loc, name="loc") - self._scale = ops.convert_to_tensor(scale, name="scale") - check_ops.assert_same_float_dtype([self._loc, self._scale]) - if validate_args: - self._scale = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._scale, message="Argument scale was not positive") - ], self._scale) - - super(Gumbel, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) - - @property - def loc(self): - """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" - return self._loc - - @property - def scale(self): - """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" - return self._scale - - def _forward(self, x): - z = (x - self.loc) / self.scale - return math_ops.exp(-math_ops.exp(-z)) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - return self.loc - self.scale * math_ops.log(-math_ops.log(y)) - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) - - def _forward_log_det_jacobian(self, x): - event_dims = self._event_dims_tensor(x) - z = (x - self.loc) / self.scale - return math_ops.reduce_sum( - -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_positive = check_ops.assert_non_negative( - y, message="Inverse transformation input must be greater than 0.") - less_than_one = check_ops.assert_less_equal( - y, - constant_op.constant(1., y.dtype), - message="Inverse transformation input must be less than or equal to 1.") - return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git 
a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py index db10c3fc3a9135b4c408ada74622ba9b360f9ec1..fab1b22fbf92e7b92a5ec86ec62d66bec71a8c94 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py @@ -18,12 +18,124 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.inline_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Inline"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Inline", +] + + +class Inline(bijector.Bijector): + """Bijector constructed from custom callables. + + Example Use: + + ```python + exp = Inline( + forward_fn=tf.exp, + inverse_fn=tf.log, + inverse_log_det_jacobian_fn=( + lambda y: -tf.reduce_sum(tf.log(y), axis=-1)), + name="exp") + ``` + + The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. + """ + + def __init__(self, + forward_fn=None, + inverse_fn=None, + inverse_log_det_jacobian_fn=None, + forward_log_det_jacobian_fn=None, + forward_event_shape_fn=None, + forward_event_shape_tensor_fn=None, + inverse_event_shape_fn=None, + inverse_event_shape_tensor_fn=None, + is_constant_jacobian=False, + validate_args=False, + name="inline"): + """Creates a `Bijector` from callables. + + Args: + forward_fn: Python callable implementing the forward transformation. + inverse_fn: Python callable implementing the inverse transformation. + inverse_log_det_jacobian_fn: Python callable implementing the + log o det o jacobian of the inverse transformation. 
+ forward_log_det_jacobian_fn: Python callable implementing the + log o det o jacobian of the forward transformation. + forward_event_shape_fn: Python callable implementing non-identical + static event shape changes. Default: shape is assumed unchanged. + forward_event_shape_tensor_fn: Python callable implementing non-identical + event shape changes. Default: shape is assumed unchanged. + inverse_event_shape_fn: Python callable implementing non-identical + static event shape changes. Default: shape is assumed unchanged. + inverse_event_shape_tensor_fn: Python callable implementing non-identical + event shape changes. Default: shape is assumed unchanged. + is_constant_jacobian: Python `bool` indicating that the Jacobian is + constant for all input arguments. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + """ + super(Inline, self).__init__( + event_ndims=0, + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + self._forward_fn = forward_fn + self._inverse_fn = inverse_fn + self._inverse_log_det_jacobian_fn = inverse_log_det_jacobian_fn + self._forward_log_det_jacobian_fn = forward_log_det_jacobian_fn + self._forward_event_shape_fn = forward_event_shape_fn + self._forward_event_shape_tensor_fn = forward_event_shape_tensor_fn + self._inverse_event_shape_fn = inverse_event_shape_fn + self._inverse_event_shape_tensor_fn = inverse_event_shape_tensor_fn + + def _forward_event_shape(self, input_shape): + if self._forward_event_shape_fn is None: + # By default assume shape doesn't change. + return input_shape + return self._forward_event_shape_fn(input_shape) + + def _forward_event_shape_tensor(self, input_shape): + if self._forward_event_shape_tensor_fn is None: + # By default assume shape doesn't change. 
+ return input_shape + return self._forward_event_shape_tensor_fn(input_shape) + + def _inverse_event_shape(self, output_shape): + if self._inverse_event_shape_fn is None: + # By default assume shape doesn't change. + return output_shape + return self._inverse_event_shape_fn(output_shape) + + def _inverse_event_shape_tensor(self, output_shape): + if self._inverse_event_shape_tensor_fn is None: + # By default assume shape doesn't change. + return output_shape + return self._inverse_event_shape_tensor_fn(output_shape) + + def _forward(self, x, **kwargs): + if not callable(self._forward_fn): + raise NotImplementedError( + "forward_fn is not a callable function.") + return self._forward_fn(x, **kwargs) + + def _inverse(self, y, **kwargs): + if not callable(self._inverse_fn): + raise NotImplementedError( + "inverse_fn is not a callable function.") + return self._inverse_fn(y, **kwargs) + + def _inverse_log_det_jacobian(self, y, **kwargs): + if not callable(self._inverse_log_det_jacobian_fn): + raise NotImplementedError( + "inverse_log_det_jacobian_fn is not a callable function.") + return self._inverse_log_det_jacobian_fn(y, **kwargs) + + def _forward_log_det_jacobian(self, y, **kwargs): + if not callable(self._forward_log_det_jacobian_fn): + raise NotImplementedError( + "forward_log_det_jacobian_fn is not a callable function.") + return self._forward_log_det_jacobian_fn(y, **kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py deleted file mode 100644 index fab1b22fbf92e7b92a5ec86ec62d66bec71a8c94..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Inline bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Inline", -] - - -class Inline(bijector.Bijector): - """Bijector constructed from custom callables. - - Example Use: - - ```python - exp = Inline( - forward_fn=tf.exp, - inverse_fn=tf.log, - inverse_log_det_jacobian_fn=( - lambda y: -tf.reduce_sum(tf.log(y), axis=-1)), - name="exp") - ``` - - The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. - """ - - def __init__(self, - forward_fn=None, - inverse_fn=None, - inverse_log_det_jacobian_fn=None, - forward_log_det_jacobian_fn=None, - forward_event_shape_fn=None, - forward_event_shape_tensor_fn=None, - inverse_event_shape_fn=None, - inverse_event_shape_tensor_fn=None, - is_constant_jacobian=False, - validate_args=False, - name="inline"): - """Creates a `Bijector` from callables. - - Args: - forward_fn: Python callable implementing the forward transformation. - inverse_fn: Python callable implementing the inverse transformation. - inverse_log_det_jacobian_fn: Python callable implementing the - log o det o jacobian of the inverse transformation. - forward_log_det_jacobian_fn: Python callable implementing the - log o det o jacobian of the forward transformation. - forward_event_shape_fn: Python callable implementing non-identical - static event shape changes. Default: shape is assumed unchanged. 
- forward_event_shape_tensor_fn: Python callable implementing non-identical - event shape changes. Default: shape is assumed unchanged. - inverse_event_shape_fn: Python callable implementing non-identical - static event shape changes. Default: shape is assumed unchanged. - inverse_event_shape_tensor_fn: Python callable implementing non-identical - event shape changes. Default: shape is assumed unchanged. - is_constant_jacobian: Python `bool` indicating that the Jacobian is - constant for all input arguments. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - super(Inline, self).__init__( - event_ndims=0, - is_constant_jacobian=is_constant_jacobian, - validate_args=validate_args, - name=name) - self._forward_fn = forward_fn - self._inverse_fn = inverse_fn - self._inverse_log_det_jacobian_fn = inverse_log_det_jacobian_fn - self._forward_log_det_jacobian_fn = forward_log_det_jacobian_fn - self._forward_event_shape_fn = forward_event_shape_fn - self._forward_event_shape_tensor_fn = forward_event_shape_tensor_fn - self._inverse_event_shape_fn = inverse_event_shape_fn - self._inverse_event_shape_tensor_fn = inverse_event_shape_tensor_fn - - def _forward_event_shape(self, input_shape): - if self._forward_event_shape_fn is None: - # By default assume shape doesn't change. - return input_shape - return self._forward_event_shape_fn(input_shape) - - def _forward_event_shape_tensor(self, input_shape): - if self._forward_event_shape_tensor_fn is None: - # By default assume shape doesn't change. - return input_shape - return self._forward_event_shape_tensor_fn(input_shape) - - def _inverse_event_shape(self, output_shape): - if self._inverse_event_shape_fn is None: - # By default assume shape doesn't change. 
- return output_shape - return self._inverse_event_shape_fn(output_shape) - - def _inverse_event_shape_tensor(self, output_shape): - if self._inverse_event_shape_tensor_fn is None: - # By default assume shape doesn't change. - return output_shape - return self._inverse_event_shape_tensor_fn(output_shape) - - def _forward(self, x, **kwargs): - if not callable(self._forward_fn): - raise NotImplementedError( - "forward_fn is not a callable function.") - return self._forward_fn(x, **kwargs) - - def _inverse(self, y, **kwargs): - if not callable(self._inverse_fn): - raise NotImplementedError( - "inverse_fn is not a callable function.") - return self._inverse_fn(y, **kwargs) - - def _inverse_log_det_jacobian(self, y, **kwargs): - if not callable(self._inverse_log_det_jacobian_fn): - raise NotImplementedError( - "inverse_log_det_jacobian_fn is not a callable function.") - return self._inverse_log_det_jacobian_fn(y, **kwargs) - - def _forward_log_det_jacobian(self, y, **kwargs): - if not callable(self._forward_log_det_jacobian_fn): - raise NotImplementedError( - "forward_log_det_jacobian_fn is not a callable function.") - return self._forward_log_det_jacobian_fn(y, **kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py index c134e10109ce5065eb58de1d847e3c487258954c..2c603fe61f36dd27f4984fe6c13c11f2fb534321 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py @@ -18,12 +18,85 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.invert_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector as 
bijector_lib -_allowed_symbols = ["Invert"] +__all__ = [ + "Invert", +] -remove_undocumented(__name__, _allowed_symbols) + +class Invert(bijector_lib.Bijector): + """Bijector which inverts another Bijector. + + Example Use: [ExpGammaDistribution (see Background & Context)]( + https://reference.wolfram.com/language/ref/ExpGammaDistribution.html) + models `Y=log(X)` where `X ~ Gamma`. + + ```python + exp_gamma_distribution = TransformedDistribution( + distribution=Gamma(concentration=1., rate=2.), + bijector=bijector.Invert(bijector.Exp()) + ``` + + """ + + def __init__(self, bijector, validate_args=False, name=None): + """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. + + Note: An inverted bijector's `inverse_log_det_jacobian` is often more + efficient if the base bijector implements `_forward_log_det_jacobian`. If + `_forward_log_det_jacobian` is not implemented then the following code is + used: + + ```python + y = self.inverse(x, **kwargs) + return -self.inverse_log_det_jacobian(y, **kwargs) + ``` + + Args: + bijector: Bijector instance. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. 
+ """ + + if not bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijectors.") + + self._bijector = bijector + super(Invert, self).__init__( + event_ndims=bijector.event_ndims, + graph_parents=bijector.graph_parents, + is_constant_jacobian=bijector.is_constant_jacobian, + validate_args=validate_args, + dtype=bijector.dtype, + name=name or "_".join(["invert", bijector.name])) + + def _forward_event_shape(self, input_shape): + return self.bijector._inverse_event_shape(input_shape) # pylint: disable=protected-access + + def _forward_event_shape_tensor(self, input_shape): + return self.bijector._inverse_event_shape_tensor(input_shape) # pylint: disable=protected-access + + def _inverse_event_shape(self, output_shape): + return self.bijector._forward_event_shape(output_shape) # pylint: disable=protected-access + + def _inverse_event_shape_tensor(self, output_shape): + return self.bijector._forward_event_shape_tensor(output_shape) # pylint: disable=protected-access + + @property + def bijector(self): + return self._bijector + + def _forward(self, x, **kwargs): + return self.bijector._inverse(x, **kwargs) # pylint: disable=protected-access + + def _inverse(self, y, **kwargs): + return self.bijector._forward(y, **kwargs) # pylint: disable=protected-access + + def _inverse_log_det_jacobian(self, y, **kwargs): + return self.bijector._forward_log_det_jacobian(y, **kwargs) # pylint: disable=protected-access + + def _forward_log_det_jacobian(self, x, **kwargs): + return self.bijector._inverse_log_det_jacobian(x, **kwargs) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py deleted file mode 100644 index 2c603fe61f36dd27f4984fe6c13c11f2fb534321..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py 
+++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Invert bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector as bijector_lib - -__all__ = [ - "Invert", -] - - -class Invert(bijector_lib.Bijector): - """Bijector which inverts another Bijector. - - Example Use: [ExpGammaDistribution (see Background & Context)]( - https://reference.wolfram.com/language/ref/ExpGammaDistribution.html) - models `Y=log(X)` where `X ~ Gamma`. - - ```python - exp_gamma_distribution = TransformedDistribution( - distribution=Gamma(concentration=1., rate=2.), - bijector=bijector.Invert(bijector.Exp()) - ``` - - """ - - def __init__(self, bijector, validate_args=False, name=None): - """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. - - Note: An inverted bijector's `inverse_log_det_jacobian` is often more - efficient if the base bijector implements `_forward_log_det_jacobian`. If - `_forward_log_det_jacobian` is not implemented then the following code is - used: - - ```python - y = self.inverse(x, **kwargs) - return -self.inverse_log_det_jacobian(y, **kwargs) - ``` - - Args: - bijector: Bijector instance. 
- validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - - if not bijector._is_injective: # pylint: disable=protected-access - raise NotImplementedError( - "Invert is not implemented for non-injective bijectors.") - - self._bijector = bijector - super(Invert, self).__init__( - event_ndims=bijector.event_ndims, - graph_parents=bijector.graph_parents, - is_constant_jacobian=bijector.is_constant_jacobian, - validate_args=validate_args, - dtype=bijector.dtype, - name=name or "_".join(["invert", bijector.name])) - - def _forward_event_shape(self, input_shape): - return self.bijector._inverse_event_shape(input_shape) # pylint: disable=protected-access - - def _forward_event_shape_tensor(self, input_shape): - return self.bijector._inverse_event_shape_tensor(input_shape) # pylint: disable=protected-access - - def _inverse_event_shape(self, output_shape): - return self.bijector._forward_event_shape(output_shape) # pylint: disable=protected-access - - def _inverse_event_shape_tensor(self, output_shape): - return self.bijector._forward_event_shape_tensor(output_shape) # pylint: disable=protected-access - - @property - def bijector(self): - return self._bijector - - def _forward(self, x, **kwargs): - return self.bijector._inverse(x, **kwargs) # pylint: disable=protected-access - - def _inverse(self, y, **kwargs): - return self.bijector._forward(y, **kwargs) # pylint: disable=protected-access - - def _inverse_log_det_jacobian(self, y, **kwargs): - return self.bijector._forward_log_det_jacobian(y, **kwargs) # pylint: disable=protected-access - - def _forward_log_det_jacobian(self, x, **kwargs): - return self.bijector._inverse_log_det_jacobian(x, **kwargs) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 
132dc570f94719b6c71fb269866c943774481b7e..06c7c61ec3dc3980e0d12a984739dca5a925ac9f 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -18,16 +18,459 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = [ +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import template as template_ops +from tensorflow.python.ops import variable_scope as variable_scope_lib +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ "MaskedAutoregressiveFlow", - "masked_dense", "masked_autoregressive_default_template", + "masked_dense", ] -remove_undocumented(__name__, _allowed_symbols) + +class MaskedAutoregressiveFlow(bijector_lib.Bijector): + """Affine MaskedAutoregressiveFlow bijector for vector-valued events. + + The affine autoregressive flow [1] provides a relatively simple framework for + user-specified (deep) architectures to learn a distribution over vector-valued + events. Regarding terminology, + + "Autoregressive models decompose the joint density as a product of + conditionals, and model each conditional in turn. 
Normalizing flows + transform a base density (e.g. a standard Gaussian) into the target density + by an invertible transformation with tractable Jacobian." [1] + + In other words, the "autoregressive property" is equivalent to the + decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided + `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves + this property by zeroing out weights in its `masked_dense` layers. + + In the `tf.distributions` framework, a "normalizing flow" is implemented as a + `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" + is implemented using a `tf.while_loop` and a deep neural network (DNN) with + masked weights such that the autoregressive property is automatically met in + the `inverse`. + + A `TransformedDistribution` using `MaskedAutoregressiveFlow(...)` uses the + (expensive) forward-mode calculation to draw samples and the (cheap) + reverse-mode calculation to compute log-probabilities. Conversely, a + `TransformedDistribution` using `Invert(MaskedAutoregressiveFlow(...))` uses + the (expensive) forward-mode calculation to compute log-probabilities and the + (cheap) reverse-mode calculation to compute samples. See "Example Use" + [below] for more details. + + Given a `shift_and_log_scale_fn`, the forward and inverse transformations are + (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` + must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka + "alpha" [2]) such that each are broadcastable with the arguments to `forward` + and `inverse`, i.e., such that the calculations in `forward`, `inverse` + [below] are possible. + + For convenience, `masked_autoregressive_default_template` is offered as a + possible `shift_and_log_scale_fn` function. It implements the MADE + architecture [2]. MADE is a feed-forward network that computes a `shift` and + `log(scale)` using `masked_dense` layers in a deep neural network. 
Weights are + masked to ensure the autoregressive property. It is possible that this + architecture is suboptimal for your task. To build alternative networks, + either change the arguments to `masked_autoregressive_default_template`, use + the `masked_dense` function to roll-out your own, or use some other + architecture, e.g., using `tf.layers`. + + Warning: no attempt is made to validate that the `shift_and_log_scale_fn` + enforces the "autoregressive property". + + Assuming `shift_and_log_scale_fn` has valid shape and autoregressive + semantics, the forward transformation is, + + ```python + def forward(x): + y = zeros_like(x) + event_size = x.shape[-1] + for _ in range(event_size): + shift, log_scale = shift_and_log_scale_fn(y) + y = x * math_ops.exp(log_scale) + shift + return y + ``` + + and the inverse transformation is, + + ```python + def inverse(y): + shift, log_scale = shift_and_log_scale_fn(y) + return (y - shift) / math_ops.exp(log_scale) + ``` + + Notice that the `inverse` does not need a for-loop. This is because in the + forward pass each calculation of `shift` and `log_scale` is based on the `y` + calculated so far (not `x`). In the `inverse`, the `y` is fully known, thus is + equivalent to the scaling used in `forward` after `event_size` passes, i.e., + the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this + also proves the transform is bijective.) + + #### Example Use + + ```python + tfd = tf.contrib.distributions + tfb = tfd.bijectors + + dims = 5 + + # A common choice for a normalizing flow is to use a Gaussian for the base + # distribution. (However, any continuous distribution would work.) E.g., + maf = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=tfb.masked_autoregressive_default_template( + hidden_layers=[512, 512])), + event_shape=[dims]) + + x = maf.sample() # Expensive; uses `tf.while_loop`, no Bijector caching. 
+ maf.log_prob(x) # Almost free; uses Bijector caching. + maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. + + # [1] also describes an "Inverse Autoregressive Flow", e.g., + iaf = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.Invert(tfb.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=tfb.masked_autoregressive_default_template( + hidden_layers=[512, 512]))), + event_shape=[dims]) + + x = iaf.sample() # Cheap; no `tf.while_loop` despite no Bijector caching. + iaf.log_prob(x) # Almost free; uses Bijector caching. + iaf.log_prob(0.) # Expensive; uses `tf.while_loop`, no Bijector caching. + + # In many (if not most) cases the default `shift_and_log_scale_fn` will be a + # poor choice. Here's an example of using a "shift only" version and with a + # different number/depth of hidden layers. + shift_only = True + maf_no_scale_hidden2 = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.MaskedAutoregressiveFlow( + tfb.masked_autoregressive_default_template( + hidden_layers=[32], + shift_only=shift_only), + is_constant_jacobian=shift_only), + event_shape=[dims]) + ``` + + [1]: "Masked Autoregressive Flow for Density Estimation." + George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. + https://arxiv.org/abs/1705.07057 + + [2]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + """ + + def __init__(self, + shift_and_log_scale_fn, + is_constant_jacobian=False, + validate_args=False, + name=None): + """Creates the MaskedAutoregressiveFlow bijector. + + Args: + shift_and_log_scale_fn: Python `callable` which computes `shift` and + `log_scale` from both the forward domain (`x`) and the inverse domain + (`y`). Calculation must respect the "autoregressive property" (see class + docstring). 
Suggested default + `masked_autoregressive_default_template(hidden_layers=...)`. + Typically the function contains `tf.Variables` and is wrapped using + `tf.make_template`. Returning `None` for either (both) `shift`, + `log_scale` is equivalent to (but more efficient than) returning zero. + is_constant_jacobian: Python `bool`. Default: `False`. When `True` the + implementation assumes `log_scale` does not depend on the forward domain + (`x`) or inverse domain (`y`) values. (No validation is made; + `is_constant_jacobian=False` is always safe but possibly computationally + inefficient.) + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + """ + name = name or "masked_autoregressive_flow" + self._shift_and_log_scale_fn = shift_and_log_scale_fn + super(MaskedAutoregressiveFlow, self).__init__( + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + + def _forward(self, x): + event_size = array_ops.shape(x)[-1] + y0 = array_ops.zeros_like(x, name="y0") + # call the template once to ensure creation + _ = self._shift_and_log_scale_fn(y0) + def _loop_body(index, y0): + """While-loop body for autoregression calculation.""" + # Set caching device to avoid re-getting the tf.Variable for every while + # loop iteration. 
+ with variable_scope_lib.variable_scope( + variable_scope_lib.get_variable_scope()) as vs: + if vs.caching_device is None: + vs.set_caching_device(lambda op: op.device) + shift, log_scale = self._shift_and_log_scale_fn(y0) + y = x + if log_scale is not None: + y *= math_ops.exp(log_scale) + if shift is not None: + y += shift + return index + 1, y + _, y = control_flow_ops.while_loop( + cond=lambda index, _: index < event_size, + body=_loop_body, + loop_vars=[0, y0]) + return y + + def _inverse(self, y): + shift, log_scale = self._shift_and_log_scale_fn(y) + x = y + if shift is not None: + x -= shift + if log_scale is not None: + x *= math_ops.exp(-log_scale) + return x + + def _inverse_log_det_jacobian(self, y): + _, log_scale = self._shift_and_log_scale_fn(y) + if log_scale is None: + return constant_op.constant(0., dtype=y.dtype, name="ildj") + return -math_ops.reduce_sum(log_scale, axis=-1) + + +MASK_INCLUSIVE = "inclusive" +MASK_EXCLUSIVE = "exclusive" + + +def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): + """Generate the slices for building an autoregressive mask.""" + # TODO(b/67594795): Better support of dynamic shape. + slices = [] + col = 0 + d_in = n_in // num_blocks + d_out = n_out // num_blocks + row = d_out if mask_type == MASK_EXCLUSIVE else 0 + for _ in range(num_blocks): + row_slice = slice(row, None) + col_slice = slice(col, col + d_in) + slices.append([row_slice, col_slice]) + col += d_in + row += d_out + return slices + + +def _gen_mask(num_blocks, + n_in, + n_out, + mask_type=MASK_EXCLUSIVE, + dtype=dtypes.float32): + """Generate the mask for building an autoregressive dense layer.""" + # TODO(b/67594795): Better support of dynamic shape. 
+ mask = np.zeros([n_out, n_in], dtype=dtype.as_numpy_dtype()) + slices = _gen_slices(num_blocks, n_in, n_out, mask_type=mask_type) + for [row_slice, col_slice] in slices: + mask[row_slice, col_slice] = 1 + return mask + + +def masked_dense(inputs, + units, + num_blocks=None, + exclusive=False, + kernel_initializer=None, + reuse=None, + name=None, + *args, + **kwargs): + """An autoregressively masked dense layer. Analogous to `tf.layers.dense`. + + See [1] for detailed explanation. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + inputs: Tensor input. + units: Python `int` scalar representing the dimensionality of the output + space. + num_blocks: Python `int` scalar representing the number of blocks for the + MADE masks. + exclusive: Python `bool` scalar representing whether to zero the diagonal of + the mask, used for the first layer of a MADE. + kernel_initializer: Initializer function for the weight matrix. + If `None` (default), weights are initialized using the + `tf.glorot_normal_initializer`. + reuse: Python `bool` scalar representing whether to reuse the weights of a + previous layer by the same name. + name: Python `str` used to describe ops managed by this function. + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + Output tensor. + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + # TODO(b/67594795): Better support of dynamic shape.
+ input_depth = inputs.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + + mask = _gen_mask(num_blocks, input_depth, units, + MASK_EXCLUSIVE if exclusive else MASK_INCLUSIVE).T + + if kernel_initializer is None: + kernel_initializer = init_ops.glorot_normal_initializer() + + def masked_initializer(shape, dtype=None, partition_info=None): + return mask * kernel_initializer(shape, dtype, partition_info) + + with ops.name_scope(name, "masked_dense", [inputs, units, num_blocks]): + layer = layers.Dense( + units, + kernel_initializer=masked_initializer, + kernel_constraint=lambda x: mask * x, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse, + *args, + **kwargs) + return layer.apply(inputs) + + +def masked_autoregressive_default_template( + hidden_layers, + shift_only=False, + activation=nn_ops.relu, + log_scale_min_clip=-5., + log_scale_max_clip=3., + log_scale_clip_gradient=False, + name=None, + *args, + **kwargs): + """Build the MADE Model [1]. + + This will be wrapped in a make_template to ensure the variables are only + created once. It takes the input and returns the `loc` ("mu" [1]) and + `log_scale` ("alpha" [1]) from the MADE network. + + Warning: This function uses `masked_dense` to create randomly initialized + `tf.Variables`. It is presumed that these will be fit, just as you would any + other neural architecture which uses `tf.layers.dense`. + + #### About Hidden Layers: + + Each element of `hidden_layers` should be greater than the `input_depth` + (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the + neural network). This is necessary to ensure the autoregressivity property. + + #### About Clipping: + + This function also optionally clips the `log_scale` (but possibly not its + gradient). 
This is useful because if `log_scale` is too small/large it might + underflow/overflow making it impossible for the `MaskedAutoregressiveFlow` + bijector to implement a bijection. Additionally, the `log_scale_clip_gradient` + `bool` indicates whether the gradient should also be clipped. The default does + not clip the gradient; this is useful because it still provides gradient + information (for fitting) yet solves the numerical stability problem. I.e., + `log_scale_clip_gradient = False` means + `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual + `grad[clip(x)] exp(clip(x))`. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + hidden_layers: Python `list`-like of non-negative integer scalars + indicating the number of units in each hidden layer. Example: `[512, 512]`. + shift_only: Python `bool` indicating if only the `shift` term shall be + computed. Default: `False`. + activation: Activation function (callable). Explicitly setting to `None` + implies a linear activation. + log_scale_min_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The minimum value to clip by. Default: -5. + log_scale_max_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The maximum value to clip by. Default: 3. + log_scale_clip_gradient: Python `bool` indicating that the gradient of + `tf.clip_by_value` should be preserved. Default: `False`. + name: A name for ops managed by this function. Default: + "masked_autoregressive_default_template". + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + shift: `Float`-like `Tensor` of shift terms (the "mu" in [1]). + log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [1]).
+ + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + + with ops.name_scope(name, "masked_autoregressive_default_template", + values=[log_scale_min_clip, log_scale_max_clip]): + def _fn(x): + """MADE parameterized via `masked_autoregressive_default_template`.""" + # TODO(b/67594795): Better support of dynamic shape. + input_depth = x.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + input_shape = (np.int32(x.shape.as_list()) if x.shape.is_fully_defined() + else array_ops.shape(x)) + for i, units in enumerate(hidden_layers): + x = masked_dense( + inputs=x, + units=units, + num_blocks=input_depth, + exclusive=True if i == 0 else False, + activation=activation, + *args, + **kwargs) + x = masked_dense( + inputs=x, + units=(1 if shift_only else 2) * input_depth, + num_blocks=input_depth, + activation=None, + *args, + **kwargs) + if shift_only: + x = array_ops.reshape(x, shape=input_shape) + return x, None + x = array_ops.reshape( + x, shape=array_ops.concat([input_shape, [2]], axis=0)) + shift, log_scale = array_ops.unstack(x, num=2, axis=-1) + which_clip = (math_ops.clip_by_value if log_scale_clip_gradient + else _clip_by_value_preserve_grad) + log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) + return shift, log_scale + return template_ops.make_template( + "masked_autoregressive_default_template", _fn) + + +def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): + """Clips input while leaving gradient unaltered.""" + with ops.name_scope(name, "clip_by_value_preserve_grad", + [x, clip_value_min, clip_value_max]): + clip_x = clip_ops.clip_by_value(x, clip_value_min, clip_value_max) + return x + array_ops.stop_gradient(clip_x - x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py 
b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py deleted file mode 100644 index ae142883931274b594dbbafbe86bd71e75c621bc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""MaskedAutoregressiveFlow bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.layers import core as layers -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import template as template_ops -from tensorflow.python.ops import variable_scope as variable_scope_lib -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "MaskedAutoregressiveFlow", - "masked_autoregressive_default_template", - "masked_dense", -] - - -class 
MaskedAutoregressiveFlow(bijector_lib.Bijector): - """Affine MaskedAutoregressiveFlow bijector for vector-valued events. - - The affine autoregressive flow [1] provides a relatively simple framework for - user-specified (deep) architectures to learn a distribution over vector-valued - events. Regarding terminology, - - "Autoregressive models decompose the joint density as a product of - conditionals, and model each conditional in turn. Normalizing flows - transform a base density (e.g. a standard Gaussian) into the target density - by an invertible transformation with tractable Jacobian." [1] - - In other words, the "autoregressive property" is equivalent to the - decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided - `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves - this property by zeroing out weights in its `masked_dense` layers. - - In the `tf.distributions` framework, a "normalizing flow" is implemented as a - `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" - is implemented using a `tf.while_loop` and a deep neural network (DNN) with - masked weights such that the autoregressive property is automatically met in - the `inverse`. - - A `TransformedDistribution` using `MaskedAutoregressiveFlow(...)` uses the - (expensive) forward-mode calculation to draw samples and the (cheap) - reverse-mode calculation to compute log-probabilities. Conversely, a - `TransformedDistribution` using `Invert(MaskedAutoregressiveFlow(...))` uses - the (expensive) forward-mode calculation to compute log-probabilities and the - (cheap) reverse-mode calculation to compute samples. See "Example Use" - [below] for more details. - - Given a `shift_and_log_scale_fn`, the forward and inverse transformations are - (a sequence of) affine transformations. 
A "valid" `shift_and_log_scale_fn` - must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka - "alpha" [2]) such that each are broadcastable with the arguments to `forward` - and `inverse`, i.e., such that the calculations in `forward`, `inverse` - [below] are possible. - - For convenience, `masked_autoregressive_default_template` is offered as a - possible `shift_and_log_scale_fn` function. It implements the MADE - architecture [2]. MADE is a feed-forward network that computes a `shift` and - `log(scale)` using `masked_dense` layers in a deep neural network. Weights are - masked to ensure the autoregressive property. It is possible that this - architecture is suboptimal for your task. To build alternative networks, - either change the arguments to `masked_autoregressive_default_template`, use - the `masked_dense` function to roll-out your own, or use some other - architecture, e.g., using `tf.layers`. - - Warning: no attempt is made to validate that the `shift_and_log_scale_fn` - enforces the "autoregressive property". - - Assuming `shift_and_log_scale_fn` has valid shape and autoregressive - semantics, the forward transformation is, - - ```python - def forward(x): - y = zeros_like(x) - event_size = x.shape[-1] - for _ in range(event_size): - shift, log_scale = shift_and_log_scale_fn(y) - y = x * math_ops.exp(log_scale) + shift - return y - ``` - - and the inverse transformation is, - - ```python - def inverse(y): - shift, log_scale = shift_and_log_scale_fn(y) - return (y - shift) / math_ops.exp(log_scale) - ``` - - Notice that the `inverse` does not need a for-loop. This is because in the - forward pass each calculation of `shift` and `log_scale` is based on the `y` - calculated so far (not `x`). In the `inverse`, the `y` is fully known, thus is - equivalent to the scaling used in `forward` after `event_size` passes, i.e., - the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this - also proves the transform is bijective.) 
- - #### Example Use - - ```python - ds = tf.contrib.distributions - bs = tf.contrib.distributions.bijectors - - dims = 5 - - # A common choice for a normalizing flow is to use a Gaussian for the base - # distribution. (However, any continuous distribution would work.) E.g., - maf = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.MaskedAutoregressiveFlow( - shift_and_log_scale_fn=bs.masked_autoregressive_default_template( - hidden_layers=[512, 512])), - event_shape=[dims]) - - x = maf.sample() # Expensive; uses `tf.while_loop`, no Bijector caching. - maf.log_prob(x) # Almost free; uses Bijector caching. - maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. - - # [1] also describes an "Inverse Autoregressive Flow", e.g., - iaf = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.Invert(bs.MaskedAutoregressiveFlow( - shift_and_log_scale_fn=bs.masked_autoregressive_default_template( - hidden_layers=[512, 512]))), - event_shape=[dims]) - - x = iaf.sample() # Cheap; no `tf.while_loop` despite no Bijector caching. - iaf.log_prob(x) # Almost free; uses Bijector caching. - iaf.log_prob(0.) # Expensive; uses `tf.while_loop`, no Bijector caching. - - # In many (if not most) cases the default `shift_and_log_scale_fn` will be a - # poor choice. Here's an example of using a "shift only" version and with a - # different number/depth of hidden layers. - shift_only = True - maf_no_scale_hidden2 = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.MaskedAutoregressiveFlow( - bs.masked_autoregressive_default_template( - hidden_layers=[32], - shift_only=shift_only), - is_constant_jacobian=shift_only), - event_shape=[dims]) - ``` - - [1]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 - - [2]: "MADE: Masked Autoencoder for Distribution Estimation." 
- Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - """ - - def __init__(self, - shift_and_log_scale_fn, - is_constant_jacobian=False, - validate_args=False, - name=None): - """Creates the MaskedAutoregressiveFlow bijector. - - Args: - shift_and_log_scale_fn: Python `callable` which computes `shift` and - `log_scale` from both the forward domain (`x`) and the inverse domain - (`y`). Calculation must respect the "autoregressive property" (see class - docstring). Suggested default - `masked_autoregressive_default_template(hidden_layers=...)`. - Typically the function contains `tf.Variables` and is wrapped using - `tf.make_template`. Returning `None` for either (both) `shift`, - `log_scale` is equivalent to (but more efficient than) returning zero. - is_constant_jacobian: Python `bool`. Default: `False`. When `True` the - implementation assumes `log_scale` does not depend on the forward domain - (`x`) or inverse domain (`y`) values. (No validation is made; - `is_constant_jacobian=False` is always safe but possibly computationally - inefficient.) - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - name = name or "masked_autoregressive_flow" - self._shift_and_log_scale_fn = shift_and_log_scale_fn - super(MaskedAutoregressiveFlow, self).__init__( - is_constant_jacobian=is_constant_jacobian, - validate_args=validate_args, - name=name) - - def _forward(self, x): - event_size = array_ops.shape(x)[-1] - def _loop_body(index, y0): - """While-loop body for autoregression calculation.""" - # Set caching device to avoid re-getting the tf.Variable for every while - # loop iteration. 
- with variable_scope_lib.variable_scope( - variable_scope_lib.get_variable_scope()) as vs: - if vs.caching_device is None: - vs.set_caching_device(lambda op: op.device) - shift, log_scale = self._shift_and_log_scale_fn(y0) - y = x - if log_scale is not None: - y *= math_ops.exp(log_scale) - if shift is not None: - y += shift - return index + 1, y - _, y = control_flow_ops.while_loop( - cond=lambda index, _: index < event_size, - body=_loop_body, - loop_vars=[0, array_ops.zeros_like(x, name="y0")]) - return y - - def _inverse(self, y): - shift, log_scale = self._shift_and_log_scale_fn(y) - x = y - if shift is not None: - x -= shift - if log_scale is not None: - x *= math_ops.exp(-log_scale) - return x - - def _inverse_log_det_jacobian(self, y): - _, log_scale = self._shift_and_log_scale_fn(y) - if log_scale is None: - return constant_op.constant(0., dtype=y.dtype, name="ildj") - return -math_ops.reduce_sum(log_scale, axis=-1) - - -MASK_INCLUSIVE = "inclusive" -MASK_EXCLUSIVE = "exclusive" - - -def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): - """Generate the slices for building an autoregressive mask.""" - # TODO(b/67594795): Better support of dynamic shape. - slices = [] - col = 0 - d_in = n_in // num_blocks - d_out = n_out // num_blocks - row = d_out if mask_type == MASK_EXCLUSIVE else 0 - for _ in range(num_blocks): - row_slice = slice(row, None) - col_slice = slice(col, col + d_in) - slices.append([row_slice, col_slice]) - col += d_in - row += d_out - return slices - - -def _gen_mask(num_blocks, - n_in, - n_out, - mask_type=MASK_EXCLUSIVE, - dtype=dtypes.float32): - """Generate the mask for building an autoregressive dense layer.""" - # TODO(b/67594795): Better support of dynamic shape. 
- mask = np.zeros([n_out, n_in], dtype=dtype.as_numpy_dtype()) - slices = _gen_slices(num_blocks, n_in, n_out, mask_type=mask_type) - for [row_slice, col_slice] in slices: - mask[row_slice, col_slice] = 1 - return mask - - -def masked_dense(inputs, - units, - num_blocks=None, - exclusive=False, - kernel_initializer=None, - reuse=None, - name=None, - *args, - **kwargs): - """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. - - See [1] for detailed explanation. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: - inputs: Tensor input. - units: Python `int` scalar representing the dimensionality of the output - space. - num_blocks: Python `int` scalar representing the number of blocks for the - MADE masks. - exclusive: Python `bool` scalar representing whether to zero the diagonal of - the mask, used for the first layer of a MADE. - kernel_initializer: Initializer function for the weight matrix. - If `None` (default), weights are initialized using the - `tf.glorot_random_initializer`. - reuse: Python `bool` scalar representing whether to reuse the weights of a - previous layer by the same name. - name: Python `str` used to describe ops managed by this function. - *args: `tf.layers.dense` arguments. - **kwargs: `tf.layers.dense` keyword arguments. - - Returns: - Output tensor. - - Raises: - NotImplementedError: if rightmost dimension of `inputs` is unknown prior to - graph execution. - """ - # TODO(b/67594795): Better support of dynamic shape. 
- input_depth = inputs.shape.with_rank_at_least(1)[-1].value - if input_depth is None: - raise NotImplementedError( - "Rightmost dimension must be known prior to graph execution.") - - mask = _gen_mask(num_blocks, input_depth, units, - MASK_EXCLUSIVE if exclusive else MASK_INCLUSIVE).T - - if kernel_initializer is None: - kernel_initializer = init_ops.glorot_normal_initializer() - - def masked_initializer(shape, dtype=None, partition_info=None): - return mask * kernel_initializer(shape, dtype, partition_info) - - with ops.name_scope(name, "masked_dense", [inputs, units, num_blocks]): - layer = layers.Dense( - units, - kernel_initializer=masked_initializer, - kernel_constraint=lambda x: mask * x, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse, - *args, - **kwargs) - return layer.apply(inputs) - - -def masked_autoregressive_default_template( - hidden_layers, - shift_only=False, - activation=nn_ops.relu, - log_scale_min_clip=-5., - log_scale_max_clip=3., - log_scale_clip_gradient=False, - name=None, - *args, - **kwargs): - """Build the MADE Model [1]. - - This will be wrapped in a make_template to ensure the variables are only - created once. It takes the input and returns the `loc` ("mu" [1]) and - `log_scale` ("alpha" [1]) from the MADE network. - - Warning: This function uses `masked_dense` to create randomly initialized - `tf.Variables`. It is presumed that these will be fit, just as you would any - other neural architecture which uses `tf.layers.dense`. - - #### About Hidden Layers: - - Each element of `hidden_layers` should be greater than the `input_depth` - (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the - neural network). This is necessary to ensure the autoregressivity property. - - #### About Clipping: - - This function also optionally clips the `log_scale` (but possibly not its - gradient). 
This is useful because if `log_scale` is too small/large it might - underflow/overflow making it impossible for the `MaskedAutoregressiveFlow` - bijector to implement a bijection. Additionally, the `log_scale_clip_gradient` - `bool` indicates whether the gradient should also be clipped. The default does - not clip the gradient; this is useful because it still provides gradient - information (for fitting) yet solves the numerical stability problem. I.e., - `log_scale_clip_gradient = False` means - `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual - `grad[clip(x)] exp(clip(x))`. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: - hidden_layers: Python `list`-like of non-negative integer, scalars - indicating the number of units in each hidden layer. Default: `[512, 512]. - shift_only: Python `bool` indicating if only the `shift` term shall be - computed. Default: `False`. - activation: Activation function (callable). Explicitly setting to `None` - implies a linear activation. - log_scale_min_clip: `float`-like scalar `Tensor`, or a `Tensor` with the - same shape as `log_scale`. The minimum value to clip by. Default: -5. - log_scale_max_clip: `float`-like scalar `Tensor`, or a `Tensor` with the - same shape as `log_scale`. The maximum value to clip by. Default: 3. - log_scale_clip_gradient: Python `bool` indicating that the gradient of - `tf.clip_by_value` should be preserved. Default: `False`. - name: A name for ops managed by this function. Default: - "masked_autoregressive_default_template". - *args: `tf.layers.dense` arguments. - **kwargs: `tf.layers.dense` keyword arguments. - - Returns: - shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). - log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). 
- - Raises: - NotImplementedError: if rightmost dimension of `inputs` is unknown prior to - graph execution. - """ - - with ops.name_scope(name, "masked_autoregressive_default_template", - values=[log_scale_min_clip, log_scale_max_clip]): - def _fn(x): - """MADE parameterized via `masked_autoregressive_default_template`.""" - # TODO(b/67594795): Better support of dynamic shape. - input_depth = x.shape.with_rank_at_least(1)[-1].value - if input_depth is None: - raise NotImplementedError( - "Rightmost dimension must be known prior to graph execution.") - input_shape = (np.int32(x.shape.as_list()) if x.shape.is_fully_defined() - else array_ops.shape(x)) - for i, units in enumerate(hidden_layers): - x = masked_dense( - inputs=x, - units=units, - num_blocks=input_depth, - exclusive=True if i == 0 else False, - activation=activation, - *args, - **kwargs) - x = masked_dense( - inputs=x, - units=(1 if shift_only else 2) * input_depth, - num_blocks=input_depth, - activation=None, - *args, - **kwargs) - if shift_only: - x = array_ops.reshape(x, shape=input_shape) - return x, None - x = array_ops.reshape( - x, shape=array_ops.concat([input_shape, [2]], axis=0)) - shift, log_scale = array_ops.unstack(x, num=2, axis=-1) - which_clip = (math_ops.clip_by_value if log_scale_clip_gradient - else _clip_by_value_preserve_grad) - log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) - return shift, log_scale - return template_ops.make_template( - "masked_autoregressive_default_template", _fn) - - -def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): - """Clips input while leaving gradient unaltered.""" - with ops.name_scope(name, "clip_by_value_preserve_grad", - [x, clip_value_min, clip_value_max]): - clip_x = clip_ops.clip_by_value(x, clip_value_min, clip_value_max) - return x + array_ops.stop_gradient(clip_x - x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py 
b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py index a187ce22d686ee1203802ae2bfe64b0e1a3ea850..8654cc39d0c41ec4f1b85cd5fc4366ceaf4b224d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -12,18 +12,127 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Permute bijector.""" +"""Permutation bijectors.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.permute_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["Permute"] +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Permute", +] + + +class Permute(bijector_lib.Bijector): + """Permutes the rightmost dimension of a `Tensor`. + + ```python + tfd = tf.contrib.distributions + + reverse = tfd.bijectors.Permute(permutation=[2, 1, 0]) + + reverse.forward([-1., 0., 1.]) + # ==> [1., 0., -1] + + reverse.inverse([1., 0., -1]) + # ==> [-1., 0., 1.] + + reverse.forward_log_det_jacobian(any_value) + # ==> 0. + + reverse.inverse_log_det_jacobian(any_value) + # ==> 0. 
+ ``` + + Warning: `tf.estimator` may repeatedly build the graph, thus + `Permute(np.random.permutation(event_size).astype("int32"))` is not a + reliable parameterization (nor would it be even if using `tf.constant`). A + safe alternative is to use `tf.get_variable` to achieve "init once" behavior, + i.e., + + ```python + def init_once(x, name): + return tf.get_variable(name, initializer=x, trainable=False) + + Permute(permutation=init_once( + np.random.permutation(event_size).astype("int32"), + name="permutation")) + ``` + + """ + + def __init__(self, permutation, validate_args=False, name=None): + """Creates the `Permute` bijector. + + Args: + permutation: An `int`-like vector-shaped `Tensor` representing the + permutation to apply to the rightmost dimension of the transformed + `Tensor`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if `not permutation.dtype.is_integer`. + ValueError: if `permutation` does not contain exactly one of each of + `{0, 1, ..., d - 1}`, where `d` is the number of elements in `permutation`.
+ """ + with ops.name_scope(name, "permute", values=[permutation]): + permutation = ops.convert_to_tensor( + permutation, + name="permutation") + if not permutation.dtype.is_integer: + raise TypeError("permutation.dtype ({}) should be `int`-like.".format( + permutation.dtype.name)) + p = tensor_util.constant_value(permutation) + if p is not None: + if set(p) != set(np.arange(p.size)): + raise ValueError("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.") + elif validate_args: + p, _ = nn_ops.top_k(-permutation, + k=array_ops.shape(permutation)[-1], + sorted=True) + permutation = control_flow_ops.with_dependencies([ + check_ops.assert_equal( + -p, math_ops.range(array_ops.size(p)), + message=("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.")), + ], permutation) + self._permutation = permutation + super(Permute, self).__init__( + is_constant_jacobian=True, + validate_args=validate_args, + name=name or "permute") + + @property + def permutation(self): + return self._permutation + + def _forward(self, x): + return array_ops.gather(x, self.permutation, axis=-1) + + def _inverse(self, y): + return array_ops.gather( + y, + array_ops.invert_permutation(self.permutation), + axis=-1) + + def _inverse_log_det_jacobian(self, y): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + return constant_op.constant(0., dtype=x.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py deleted file mode 100644 index b1d8f2f41b28a88208a19824377f93882b767f03..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Permutation bijectors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "Permute", -] - - -class Permute(bijector_lib.Bijector): - """Permutes the rightmost dimension of a `Tensor`. - - ```python - bs = tf.contrib.distributions.bijectors - - reverse = bs.Permute(permutation=[2, 1, 0]) - - reverse.forward([-1., 0., 1.]) - # ==> [1., 0., -1] - - reverse.inverse([1., 0., -1]) - # ==> [-1., 0., 1.] - - reverse.forward_log_det_jacobian(any_value) - # ==> 0. - - reverse.inverse_log_det_jacobian(any_value) - # ==> 0. - ``` - - Warning: `tf.estimator` may repeatedly build the graph thus - `Permute(np.random.permutation(event_size)).astype("int32"))` is not a - reliable parameterization (nor would it be even if using `tf.constant`). 
A - safe alternative is to use `tf.get_variable` to achieve "init once" behavior, - i.e., - - ```python - def init_once(x, name): - return tf.get_variable(name, initializer=x, trainable=False) - - Permute(permutation=init_once( - np.random.permutation(event_size).astype("int32"), - name="permutation")) - ``` - - """ - - def __init__(self, permutation, validate_args=False, name=None): - """Creates the `Permute` bijector. - - Args: - permutation: An `int`-like vector-shaped `Tensor` representing the - permutation to apply to the rightmost dimension of the transformed - `Tensor`. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - - Raises: - TypeError: if `not permutation.dtype.is_integer`. - ValueError: if `permutation` does not contain exactly one of each of - `{0, 1, ..., d}`. - """ - with ops.name_scope(name, "permute", values=[permutation]): - permutation = ops.convert_to_tensor( - permutation, - name="permutation") - if not permutation.dtype.is_integer: - raise TypeError("permutation.dtype ({}) should be `int`-like.".format( - permutation.dtype.name)) - p = tensor_util.constant_value(permutation) - if p is not None: - if set(p) != set(np.arange(p.size)): - raise ValueError("Permutation over `d` must contain exactly one of " - "each of `{0, 1, ..., d}`.") - elif validate_args: - p, _ = nn_ops.top_k(-permutation, - k=array_ops.shape(permutation)[-1], - sorted=True) - permutation = control_flow_ops.with_dependencies([ - check_ops.assert_equal( - -p, math_ops.range(array_ops.size(p)), - message=("Permutation over `d` must contain exactly one of " - "each of `{0, 1, ..., d}`.")), - ], permutation) - self._permutation = permutation - super(Permute, self).__init__( - is_constant_jacobian=True, - validate_args=validate_args, - name=name or "permute") - - @property - def permutation(self): - return self._permutation - - def _forward(self, x): - return 
array_ops.gather(x, self.permutation, axis=-1) - - def _inverse(self, y): - return array_ops.gather( - y, - array_ops.invert_permutation(self.permutation), - axis=-1) - - def _inverse_log_det_jacobian(self, y): - return constant_op.constant(0., dtype=y.dtype) - - def _forward_log_det_jacobian(self, x): - return constant_op.constant(0., dtype=x.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py index a83199549cd16101ab7b39b43d19a17bc66f03df..c37db61720d10949f294ff7b2e9778ba6efa57f0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py @@ -18,12 +18,110 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.power_transform_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["PowerTransform"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "PowerTransform", +] + + +class PowerTransform(bijector.Bijector): + """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`. + + The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps + inputs from `[0, inf]` to `[-1/c, inf]`; this is equivalent to the `inverse` + of this bijector. + + This bijector is equivalent to the `Exp` bijector when `c=0`. 
+ """ + + def __init__(self, + power=0., + event_ndims=0, + validate_args=False, + name="power_transform"): + """Instantiates the `PowerTransform` bijector. + + Args: + power: Python `float` scalar indicating the transform power, i.e., + `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `power < 0` or is not known statically. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[power]): + power = tensor_util.constant_value( + ops.convert_to_tensor(power, name="power")) + if power is None or power < 0: + raise ValueError("`power` must be a non-negative TF constant.") + self._power = power + super(PowerTransform, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + @property + def power(self): + """The `c` in: `Y = g(X) = (1 + X * c)**(1 / c)`.""" + return self._power + + def _forward(self, x): + x = self._maybe_assert_valid_x(x) + if self.power == 0.: + return math_ops.exp(x) + # If large x accuracy is an issue, consider using: + # (1. + x * self.power)**(1. / self.power) when x >> 1. + return math_ops.exp(math_ops.log1p(x * self.power) / self.power) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + if self.power == 0.: + return math_ops.log(y) + # If large y accuracy is an issue, consider using: + # (y**self.power - 1.) / self.power when y >> 1. + return math_ops.expm1(math_ops.log(y) * self.power) / self.power + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return (self.power - 1.) 
* math_ops.reduce_sum( + math_ops.log(y), axis=event_dims) + + def _forward_log_det_jacobian(self, x): + x = self._maybe_assert_valid_x(x) + event_dims = self._event_dims_tensor(x) + if self.power == 0.: + return math_ops.reduce_sum(x, axis=event_dims) + return (1. / self.power - 1.) * math_ops.reduce_sum( + math_ops.log1p(x * self.power), + axis=event_dims) + + def _maybe_assert_valid_x(self, x): + if not self.validate_args or self.power == 0.: + return x + is_valid = check_ops.assert_non_negative( + 1. + self.power * x, + message="Forward transformation input must be at least {}.".format( + -1. / self.power)) + return control_flow_ops.with_dependencies([is_valid], x) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_valid = check_ops.assert_positive( + y, message="Inverse transformation input must be greater than 0.") + return control_flow_ops.with_dependencies([is_valid], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py deleted file mode 100644 index c37db61720d10949f294ff7b2e9778ba6efa57f0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""PowerTransform bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "PowerTransform", -] - - -class PowerTransform(bijector.Bijector): - """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`. - - The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps - inputs from `[0, inf]` to `[-1/c, inf]`; this is equivalent to the `inverse` - of this bijector. - - This bijector is equivalent to the `Exp` bijector when `c=0`. - """ - - def __init__(self, - power=0., - event_ndims=0, - validate_args=False, - name="power_transform"): - """Instantiates the `PowerTransform` bijector. - - Args: - power: Python `float` scalar indicating the transform power, i.e., - `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `power < 0` or is not known statically. 
- """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[power]): - power = tensor_util.constant_value( - ops.convert_to_tensor(power, name="power")) - if power is None or power < 0: - raise ValueError("`power` must be a non-negative TF constant.") - self._power = power - super(PowerTransform, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - @property - def power(self): - """The `c` in: `Y = g(X) = (1 + X * c)**(1 / c)`.""" - return self._power - - def _forward(self, x): - x = self._maybe_assert_valid_x(x) - if self.power == 0.: - return math_ops.exp(x) - # If large x accuracy is an issue, consider using: - # (1. + x * self.power)**(1. / self.power) when x >> 1. - return math_ops.exp(math_ops.log1p(x * self.power) / self.power) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - if self.power == 0.: - return math_ops.log(y) - # If large y accuracy is an issue, consider using: - # (y**self.power - 1.) / self.power when y >> 1. - return math_ops.expm1(math_ops.log(y) * self.power) / self.power - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return (self.power - 1.) * math_ops.reduce_sum( - math_ops.log(y), axis=event_dims) - - def _forward_log_det_jacobian(self, x): - x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) - if self.power == 0.: - return math_ops.reduce_sum(x, axis=event_dims) - return (1. / self.power - 1.) * math_ops.reduce_sum( - math_ops.log1p(x * self.power), - axis=event_dims) - - def _maybe_assert_valid_x(self, x): - if not self.validate_args or self.power == 0.: - return x - is_valid = check_ops.assert_non_negative( - 1. + self.power * x, - message="Forward transformation input must be at least {}.".format( - -1. 
/ self.power)) - return control_flow_ops.with_dependencies([is_valid], x) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_valid = check_ops.assert_positive( - y, message="Inverse transformation input must be greater than 0.") - return control_flow_ops.with_dependencies([is_valid], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index 8997f7ab6929745275edb38712a5bbb0a9b25ddb..55eca063126797d577653f0d6bcdfddf8192bdb5 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -12,18 +12,303 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Reshape bijector.""" +"""Reshape bijectors.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.reshape_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["Reshape"] +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Reshape", +] + + +def _static_ndims_from_shape(shape): + return shape.shape.with_rank_at_least(1)[0].value + + 
+def _ndims_from_shape(shape): + return array_ops.shape(shape)[0] + + +class Reshape(bijector_lib.Bijector): + """Reshapes the `event_shape` of a `Tensor`. + + The semantics generally follow that of `tf.reshape()`, with + a few differences: + + * The user must provide both the input and output shape, so that + the transformation can be inverted. If an input shape is not + specified, the default assumes a vector-shaped input, i.e., + event_shape_in = (-1,). + * The `Reshape` bijector automatically broadcasts over the leftmost + dimensions of its input (`sample_shape` and `batch_shape`); only + the rightmost `event_ndims_in` dimensions are reshaped. The + number of dimensions to reshape is inferred from the provided + `event_shape_in` (`event_ndims_in = len(event_shape_in)`). + + Example usage: + ```python + + tfd = tf.contrib.distributions + + r = tfd.bijectors.Reshape(event_shape_out=[1, -1]) + + r.forward([3., 4.]) # shape [2] + # ==> [[3., 4.]] # shape [1, 2] + + r.forward([[1., 2.], [3., 4.]]) # shape [2, 2] + # ==> [[[1., 2.]], + # [[3., 4.]]] # shape [2, 1, 2] + + r.inverse([[3., 4.]]) # shape [1,2] + # ==> [3., 4.] # shape [2] + + r.forward_log_det_jacobian(any_value) + # ==> 0. + + r.inverse_log_det_jacobian(any_value) + # ==> 0. + ``` + + """ + + def __init__(self, event_shape_out, event_shape_in=(-1,), + validate_args=False, name=None): + """Creates a `Reshape` bijector. + + Args: + event_shape_out: An `int`-like vector-shaped `Tensor` + representing the event shape of the transformed output. + event_shape_in: An optional `int`-like vector-shape `Tensor` + representing the event shape of the input. This is required in + order to define inverse operations; the default of (-1,) + assumes a vector-shaped input. + validate_args: Python `bool` indicating whether arguments should + be checked for correctness. + name: Python `str`, name given to ops managed by this object. 
+ + Raises: + TypeError: if either `event_shape_in` or `event_shape_out` has + non-integer `dtype`. + ValueError: if either of `event_shape_in` or `event_shape_out` + has non-vector shape (`rank > 1`), or if their sizes do not + match. + """ + with ops.name_scope(name, "reshape", + values=[event_shape_out, event_shape_in]): + + event_shape_out = ops.convert_to_tensor(event_shape_out, + name="event_shape_out", + preferred_dtype=dtypes.int32) + event_shape_in = ops.convert_to_tensor(event_shape_in, + name="event_shape_in", + preferred_dtype=dtypes.int32) + + assertions = [] + assertions.extend(self._maybe_check_valid_shape( + event_shape_out, validate_args)) + assertions.extend(self._maybe_check_valid_shape( + event_shape_in, validate_args)) + + self._assertions = assertions + self._event_shape_in = event_shape_in + self._event_shape_out = event_shape_out + + super(Reshape, self).__init__(is_constant_jacobian=True, + validate_args=validate_args, + name=name or "reshape") + + def _maybe_check_valid_shape(self, shape, validate_args): + """Check that a shape Tensor is int-type and otherwise sane.""" + if not shape.dtype.is_integer: + raise TypeError("{} dtype ({}) should be `int`-like.".format( + shape.op.name, shape.dtype.name)) + + assertions = [] + + ndims = array_ops.rank(shape) + ndims_ = tensor_util.constant_value(ndims) + if ndims_ is not None and ndims_ > 1: + raise ValueError("`{}` rank ({}) should be <= 1.".format( + shape.op.name, ndims_)) + elif validate_args: + assertions.append(check_ops.assert_less_equal( + ndims, 1, message="`{}` rank should be <= 1.".format(shape.op.name))) + + shape_ = tensor_util.constant_value_as_shape(shape) + if shape_.is_fully_defined(): + es = np.int32(shape_.as_list()) + if sum(es == -1) > 1: + raise ValueError( + "`{}` must have at most one `-1` (given {})" + .format(shape.op.name, es)) + if np.any(es < -1): + raise ValueError( + "`{}` elements must be either positive integers or `-1`" + "(given {})." 
+ .format(shape.op.name, es)) + elif validate_args: + assertions.extend([ + check_ops.assert_less_equal( + math_ops.reduce_sum( + math_ops.cast(math_ops.equal(shape, -1), dtypes.int32)), + 1, + message="`{}` elements must have at most one `-1`." + .format(shape.op.name)), + check_ops.assert_greater_equal( + shape, -1, + message="`{}` elements must be either positive integers or `-1`." + .format(shape.op.name)), + ]) + return assertions + + def _reshape_helper(self, x, event_shape_in, event_shape_out): + """Reshape only the event_shape of an input `Tensor`.""" + + event_ndims_in_ = _static_ndims_from_shape(event_shape_in) + event_ndims_in = _ndims_from_shape(event_shape_in) + x_ndims_, x_ndims = x.shape.ndims, array_ops.rank(x) + + assertions = [] + + # Ensure x.event_shape is compatible with event_shape_in. + if (event_ndims_in_ is not None + and x_ndims_ is not None + and x.shape.with_rank_at_least(event_ndims_in_)[ + x_ndims_-event_ndims_in_:].is_fully_defined()): + x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking + np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2 + else: + x_event_shape_, x_event_shape = ( + None, array_ops.shape(x)[x_ndims-event_ndims_in:]) + + event_shape_in_ = tensor_util.constant_value(event_shape_in) + + if x_event_shape_ is not None and event_shape_in_ is not None: + # Compare the shape dimensions that are fully specified in the + # input (i.e., for which event_shape_in is not -1). If x_event_shape + # matches along all of these dimensions, it is compatible with + # the desired input shape and any further mismatches (i.e., + # incompatibility with the desired *output* shape) will be + # caught inside of array_ops.reshape() below.
+ x_event_shape_specified_ = x_event_shape_[event_shape_in_ >= 0] + event_shape_in_specified_ = event_shape_in_[event_shape_in_ >= 0] + if not np.equal(x_event_shape_specified_, + event_shape_in_specified_).all(): + raise ValueError( + "Input `event_shape` does not match `event_shape_in` ({} vs {}).". + format(x_event_shape_, event_shape_in_)) + elif self.validate_args: + # Similarly to the static case, we compare the shape dimensions + # that are fully specified in the input. We extract these + # dimensions using boolean_mask(), which requires that the mask + # have known ndims. We can assume that shape Tensors always have + # ndims==1 (this assumption is verified inside of + # _maybe_check_valid_shape), so the reshape operation is just a + # no-op that formally encodes this fact to make boolean_mask() + # happy. + event_shape_mask = array_ops.reshape(event_shape_in >= 0, [-1]) + x_event_shape_specified = array_ops.boolean_mask(x_event_shape, + event_shape_mask) + event_shape_in_specified = array_ops.boolean_mask(event_shape_in, + event_shape_mask) + assertions.append(check_ops.assert_equal( + x_event_shape_specified, event_shape_in_specified, + message="Input `event_shape` does not match `event_shape_in`.")) + + if assertions: + x = control_flow_ops.with_dependencies(assertions, x) + + # get the parts of shape(x) that will not change + sample_and_batch_shape = array_ops.shape(x) + + ndims = (x.shape.ndims if x.shape.ndims is not None + else array_ops.rank(x)) + sample_and_batch_shape = sample_and_batch_shape[ + :(ndims - math_ops.abs(event_ndims_in))] + + if (event_ndims_in_ is not None + and x_ndims_ is not None + and event_ndims_in_ == x_ndims_): + # Hack to allow forward/inverse_event_shape to do shape + # inference by calling this helper method with a dummy Tensor of + # shape event_shape_in. 
In this special case, + # sample_and_batch_shape will be empty so we can preserve static + # shape information by avoiding the concat operation below + # (which would be a no-op). + new_shape = event_shape_out + else: + new_shape = array_ops.concat( + [sample_and_batch_shape, event_shape_out], axis=0) + + return array_ops.reshape(x, new_shape) + + def _forward(self, x): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(x, + self._event_shape_in, + self._event_shape_out) + + def _inverse(self, y): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(y, + self._event_shape_out, + self._event_shape_in) + + def _inverse_log_det_jacobian(self, y): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=x.dtype) + + def _forward_event_shape(self, input_shape): + # NOTE: this method and the other *_event_shape* methods + # compute shape by explicit transformation of a dummy + # variable. This approach is not generally recommended because it + # bloats the graph and could in general trigger side effects. + # + # In this particular case of the Reshape bijector, the + # forward and inverse transforms have no side effects, and we + # believe the reduction in code complexity from delegating the + # heavy lifting to tf.reshape() is worth the added graph ops. + # However, you should think hard before implementing this approach + # in other Bijectors; it is strongly preferred to compute + # shapes explicitly whenever it's feasible to do so. 
+ with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape) + dummy_reshaped = self.forward(dummy) + return dummy_reshaped.shape + + def _inverse_event_shape(self, output_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape) + dummy_reshaped = self.inverse(dummy) + return dummy_reshaped.shape + + def _forward_event_shape_tensor(self, input_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape) + dummy_reshaped = self.forward(dummy) + return array_ops.shape(dummy_reshaped) + + def _inverse_event_shape_tensor(self, output_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape) + dummy_reshaped = self.inverse(dummy) + return array_ops.shape(dummy_reshaped) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py deleted file mode 100644 index 93682639aa3be3b8f59a369dedb6ee773c468130..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py +++ /dev/null @@ -1,297 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Reshape bijectors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "Reshape", -] - - -class Reshape(bijector_lib.Bijector): - """Reshapes the `event_shape` of a `Tensor`. - - The semantics generally follow that of `tf.reshape()`, with - a few differences: - * The user must provide both the input and output shape, so that - the transformation can be inverted. - * The `Reshape` bijector automatically broadcasts over the leftmost - dimensions of its input (`sample_shape` and `batch_shape`); only - the rightmost `event_ndims_in` dimensions are reshaped. The - number of dimensions to reshape is inferred from the provided - `event_shape_in` (`event_ndims_in = len(event_shape_in)`). - * The `Reshape` bijector does not currently support - partially-specified shapes, i.e., those with a dimension - implicitly specified by `-1`. - - Example usage: - ```python - - bs = tf.contrib.distributions.bijectors - - reverse = bs.Reshape(event_shape_out=[1,2], - event_shape_in=[2,]) - - reverse.forward([1., 2.]) # shape [2,] - # ==> [[1., 2.]] # shape [1,2] - - reverse.forward([[1., 2.], [3., 4.]]) # shape [2, 2] - # ==> [[[1., 2.]], [[3., 4.]]] # shape [2, 1, 2] - - reverse.inverse([[1., 2.]]) # shape [1,2] - # ==> [1., 2.] # shape [2,] - - reverse.forward_log_det_jacobian(any_value) - # ==> 0. 
- - reverse.inverse_log_det_jacobian(any_value) - # ==> 0. - ``` - - """ - - def __init__(self, event_shape_out, event_shape_in, - validate_args=False, name=None): - """Creates a `Reshape` bijector. - - Args: - event_shape_out: An `int`-like vector-shaped `Tensor` - representing the fully specified (no -1's) event shape of the - transformed output. - event_shape_in: An `int`-like vector-shaped `Tensor` - representing the fully specified (no -1's) event shape of the - input. - validate_args: Python `bool` indicating whether arguments should - be checked for correctness. - name: Python `str`, name given to ops managed by this object. - - Raises: - TypeError: if either `event_shape_in` or `event_shape_out` has - non-vector shape (`rank > 1`), or non-integer `dtype`. - ValueError: if either `event_shape_in` or `event_shape_out` - contains non-positive entries, or if their sizes do not match - (`prod(event_shape_in)` != `prod(event_shape_out)`), or if - their dimensionality(s) cannot be statically inferred. 
- """ - with ops.name_scope(name, "reshape", - values=[event_shape_out, event_shape_in]): - - event_shape_out = ops.convert_to_tensor(event_shape_out, - name="event_shape_out", - preferred_dtype=dtypes.int32) - event_shape_in = ops.convert_to_tensor(event_shape_in, - name="event_shape_in", - preferred_dtype=dtypes.int32) - - # check that input shapes are positive integers - assertions = [] - assertions += self._maybe_check_valid_shape( - event_shape_out, "event_shape_out", - validate_args=validate_args) - assertions += self._maybe_check_valid_shape( - event_shape_in, "event_shape_in", validate_args=validate_args) - - # check that prod(event_shape_in) = prod(event_shape_out) - assertions += self._maybe_check_matching_sizes( - event_shape_in, event_shape_out, validate_args=validate_args) - - self._assertions = assertions - self._event_shape_in = event_shape_in - self._event_shape_out = event_shape_out - self._event_shape_in_static = tensor_util.constant_value_as_shape( - event_shape_in) - self._event_shape_out_static = tensor_util.constant_value_as_shape( - event_shape_out) - - super(Reshape, self).__init__(is_constant_jacobian=True, - validate_args=validate_args, - name=name or "reshape") - - def _maybe_check_valid_shape(self, shape_tensor, label, - validate_args=False): - """Check that a shape Tensor is int-type and positive.""" - - assertions = [] - - if not shape_tensor.dtype.is_integer: - raise TypeError("{} dtype ({}) should be `int`-like.".format( - label, shape_tensor.dtype.name)) - - shape_rank = tensor_util.constant_value(array_ops.rank(shape_tensor)) - if shape_rank is not None and shape_rank > 1: - raise ValueError("{} rank should be <= 1.".format(label)) - - s = tensor_util.constant_value(shape_tensor) - if s is not None: - if (s <= 0).any(): - raise ValueError("{} entries must be positive, but found {}".format( - label, s)) - elif validate_args: - assertions.append(check_ops.assert_positive( - shape_tensor, message="{} entries must be 
positive".format(label))) - - return assertions - - def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out, - validate_args=False): - """Check that prod(event_shape_in)==prod(event_shape_out).""" - - def _get_size_from_shape(shape): - """Computes size from a shape `Tensor`, statically if possible.""" - s = tensor_util.constant_value(shape) - if s is not None: - return [np.int32(np.prod(s))]*2 - return None, math_ops.reduce_prod(shape, name="size") - - # Ensure `event_shape_in` is compatible with `event_shape_out`. - event_size_in_, event_size_in = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking - event_shape_in) - event_size_out_, event_size_out = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking - event_shape_out) - - assertions = [] - if event_size_in_ is not None and event_size_out_ is not None: - if event_size_in_ != event_size_out_: - raise ValueError( - "Input `event_size` ({}) does not match output `event_size` ({}).". - format(event_size_in, event_size_out_)) - elif validate_args: - assertions.append(check_ops.assert_equal( - event_size_in, event_size_out, - message="Input/output `event_size`s do not match.")) - - return assertions - - def _reshape_helper(self, x, event_shape_in, event_shape_out): - """Reshape only the event_shape of an input `Tensor`.""" - - def _get_rank_from_shape(shape): - """Computes rank from a shape `Tensor`, statically if possible.""" - # Uses fact that rank is "shape of shape". - ndims = shape.shape.with_rank_at_least(1)[0].value - if ndims is not None: - return ndims, ndims - return None, array_ops.shape(shape)[0] - - event_ndims_in_, event_ndims_in = _get_rank_from_shape(event_shape_in) - - assertions = [] - # Ensure x.event_shape is compatible with event_shape_in. 
- if x.shape.ndims is not None: - x_ndims_, x_ndims = [x.shape.ndims]*2 - else: - x_ndims_, x_ndims = None, array_ops.rank(x) - - if (event_ndims_in_ is not None - and x_ndims_ is not None - and x.shape.with_rank_at_least(event_ndims_in_)[ - x_ndims_-event_ndims_in_:].is_fully_defined()): - x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking - np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2 - else: - x_event_shape_, x_event_shape = ( - None, array_ops.shape(x)[x_ndims-event_ndims_in:]) - - event_shape_in_ = tensor_util.constant_value(event_shape_in) - - if x_event_shape_ is not None and event_shape_in_ is not None: - if not np.equal(x_event_shape_, event_shape_in_).all(): - raise ValueError( - "Input `event_shape` ({}) does not match `event_shape_in` ({}).". - format(x_event_shape_, event_shape_in_)) - elif self.validate_args: - assertions.append(check_ops.assert_equal( - x_event_shape, event_shape_in, - message="Input `event_shape` does not match `event_shape_in`.")) - - if assertions: - x = control_flow_ops.with_dependencies(assertions, x) - - # get the parts of shape(x) that will not change - sample_and_batch_shape = array_ops.shape(x) - - ndims = (x.shape.ndims if x.shape.ndims is not None - else array_ops.rank(x)) - sample_and_batch_shape = sample_and_batch_shape[ - :(ndims - math_ops.abs(event_ndims_in))] - - new_shape = array_ops.concat( - [sample_and_batch_shape, event_shape_out], axis=0) - - return array_ops.reshape(x, new_shape) - - def _forward(self, x): - with ops.control_dependencies(self._assertions): - return self._reshape_helper(x, - self._event_shape_in, - self._event_shape_out) - - def _inverse(self, y): - with ops.control_dependencies(self._assertions): - return self._reshape_helper(y, - self._event_shape_out, - self._event_shape_in) - - def _inverse_log_det_jacobian(self, y): - with ops.control_dependencies(self._assertions): - return constant_op.constant(0., dtype=y.dtype) - - def _forward_log_det_jacobian(self, x): - 
with ops.control_dependencies(self._assertions): - return constant_op.constant(0., dtype=x.dtype) - - def _forward_event_shape(self, input_shape): - self._event_shape_in_static.assert_is_compatible_with(input_shape) - return self._event_shape_out_static - - def _inverse_event_shape(self, output_shape): - self._event_shape_out_static.assert_is_compatible_with(output_shape) - return self._event_shape_in_static - - def _forward_event_shape_tensor(self, input_shape): - input_assertions = self._maybe_check_valid_shape( - input_shape, "input event shape", validate_args=self.validate_args) - input_assertions += self._maybe_check_matching_sizes( - input_shape, self._event_shape_out, - validate_args=self.validate_args) - - return control_flow_ops.with_dependencies( - input_assertions + self._assertions, self._event_shape_out) - - def _inverse_event_shape_tensor(self, output_shape): - - output_assertions = self._maybe_check_valid_shape( - output_shape, "output event shape", validate_args=self.validate_args) - output_assertions += self._maybe_check_matching_sizes( - output_shape, self._event_shape_in, validate_args=self.validate_args) - - return control_flow_ops.with_dependencies( - output_assertions + self._assertions, self._event_shape_in) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py index c20e76c0b7367369865faf973377201c8b8b17e6..a640dfe7dfbcce96261589c7fc49107deaefdd54 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py @@ -18,12 +18,31 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import 
remove_undocumented +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Sigmoid"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Sigmoid", +] + + +class Sigmoid(bijector.Bijector): + """Bijector which computes `Y = g(X) = 1 / (1 + exp(-X))`.""" + + def __init__(self, validate_args=False, name="sigmoid"): + super(Sigmoid, self).__init__( + event_ndims=0, validate_args=validate_args, name=name) + + def _forward(self, x): + return math_ops.sigmoid(x) + + def _inverse(self, y): + return math_ops.log(y) - math_ops.log1p(-y) + + def _inverse_log_det_jacobian(self, y): + return -math_ops.log(y) - math_ops.log1p(-y) + + def _forward_log_det_jacobian(self, x): + return -nn_ops.softplus(-x) - nn_ops.softplus(x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py index 448125230d24066697624bce03fed71a2c2f00b1..223bc9d042c69be05b0e578835a31ed6e83c0c97 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py @@ -18,12 +18,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered -_allowed_symbols = ["SigmoidCentered"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "SigmoidCentered", +] + + +class SigmoidCentered(softmax_centered.SoftmaxCentered): + """Bijector which computes Y = g(X) = exp([X 0]) / (1 + exp(X)).
+ + Equivalent to: `bijector.SoftmaxCentered(event_ndims=0)`. + + See `bijector.SoftmaxCentered` for more details. + """ + + def __init__(self, validate_args=False, name="sigmoid_centered"): + super(SigmoidCentered, self).__init__( + event_ndims=0, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py index b3cf03c24612f5c618c71c0a8615f272acdf2d10..3a75e4ae9495793901b0da91a5aa3982aab35852 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py @@ -18,12 +18,162 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["SinhArcsinh"] +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "SinhArcsinh", +] + + +def _sqrtx2p1(x): + """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" + return array_ops.where( + math_ops.abs(x) * np.sqrt(np.finfo(x.dtype.as_numpy_dtype).eps) <= 1., + math_ops.sqrt(x**2. + 1.), + # For large x, calculating x**2 can overflow. 
This can be alleviated by + # considering: + # sqrt(1 + x**2) + # = exp(0.5 log(1 + x**2)) + # = exp(0.5 log(x**2 * (1 + x**-2))) + # = exp(log|x| + 0.5 * log(1 + x**-2)) + # = |x| * exp(0.5 log(1 + x**-2)) + # = |x| * sqrt(1 + x**-2) + # We omit the last term in this approximation. + # When |x| > 1 / sqrt(machineepsilon), the second term will be 1, + # due to sqrt(1 + x**-2) = 1. This is also true with the gradient term, + # and higher order gradients, since the first order derivative of + # sqrt(1 + x**-2) is -x**-3 / sqrt(1 + x**-2), which is O(x**-3), + # and all nth-order derivatives will be O(x**-(n + 2)). This makes any + # gradient terms that contain any derivatives of sqrt(1 + x**-2) vanish. + math_ops.abs(x)) + + +class SinhArcsinh(bijector.Bijector): + """Compute `Y = g(X) = Sinh( (Arcsinh(X) + skewness) * tailweight )`. + + For `skewness in (-inf, inf)` and `tailweight in (0, inf)`, this + transformation is a + diffeomorphism of the real line `(-inf, inf)`. The inverse transform is + `X = g^{-1}(Y) = Sinh( ArcSinh(Y) / tailweight - skewness )`. + + The `SinhArcsinh` transformation of the Normal is described in + [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865) + This Bijector allows a similar transformation of any distribution supported on + `(-inf, inf)`. + + #### Meaning of the parameters + + * If `skewness = 0` and `tailweight = 1`, this transform is the identity. + * Positive (negative) `skewness` leads to positive (negative) skew. + * positive skew means, for unimodal `X` centered at zero, the mode of `Y` is + "tilted" to the right. + * positive skew means positive values of `Y` become more likely, and + negative values become less likely. + * Larger (smaller) `tailweight` leads to fatter (thinner) tails. + * Fatter tails mean larger values of `|Y|` become more likely. + * If `X` is a unit Normal, `tailweight < 1` leads to a distribution that is + "flat" around `Y = 0`, and a very steep drop-off in the tails.
+ * If `X` is a unit Normal, `tailweight > 1` leads to a distribution more + peaked at the mode with heavier tails. + + To see the argument about the tails, note that for `|X| >> 1` and + `|X| >> (|skewness| * tailweight)**tailweight`, we have + `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. + """ + + def __init__(self, + skewness=None, + tailweight=None, + event_ndims=0, + validate_args=False, + name="SinhArcsinh"): + """Instantiates the `SinhArcsinh` bijector. + + Args: + skewness: Skewness parameter. Float-type `Tensor`. Default is `0` + of type `float32`. + tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as + `skewness` and broadcastable `shape`. Default is `1` of type `float32`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[skewness, tailweight]): + tailweight = 1. if tailweight is None else tailweight + skewness = 0. 
if skewness is None else skewness + self._skewness = ops.convert_to_tensor( + skewness, name="skewness") + self._tailweight = ops.convert_to_tensor( + tailweight, name="tailweight", dtype=self._skewness.dtype) + check_ops.assert_same_float_dtype([self._skewness, self._tailweight]) + if validate_args: + self._tailweight = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._tailweight, + message="Argument tailweight was not positive") + ], self._tailweight) + super(SinhArcsinh, self).__init__( + event_ndims=event_ndims, validate_args=validate_args, name=name) + + @property + def skewness(self): + """The `skewness` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" + return self._skewness + + @property + def tailweight(self): + """The `tailweight` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" + return self._tailweight + + def _forward(self, x): + return math_ops.sinh((math_ops.asinh(x) + self.skewness) * self.tailweight) + + def _inverse(self, y): + return math_ops.sinh(math_ops.asinh(y) / self.tailweight - self.skewness) + + def _inverse_log_det_jacobian(self, y): + # x = sinh(arcsinh(y) / tailweight - skewness) + # Using sinh' = cosh, arcsinh'(y) = 1 / sqrt(y**2 + 1), + # dx/dy + # = cosh(arcsinh(y) / tailweight - skewness) + # / (tailweight * sqrt(y**2 + 1)) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(y**2 + 1). + math_ops.log(math_ops.cosh( + math_ops.asinh(y) / self.tailweight - self.skewness) + # TODO(srvasude): Consider using cosh(arcsinh(y)) in cases + # where (arcsinh(y) / tailweight) - skewness ~= arcsinh(y).
+ / _sqrtx2p1(y)) + - math_ops.log(self.tailweight), + axis=event_dims) + + def _forward_log_det_jacobian(self, x): + # y = sinh((arcsinh(x) + skewness) * tailweight) + # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1), + # dy/dx + # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1) + event_dims = self._event_dims_tensor(x) + return math_ops.reduce_sum( + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). + math_ops.log(math_ops.cosh( + (math_ops.asinh(x) + self.skewness) * self.tailweight) + # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases + # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x). + / _sqrtx2p1(x)) + + math_ops.log(self.tailweight), + axis=event_dims) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py deleted file mode 100644 index 3a75e4ae9495793901b0da91a5aa3982aab35852..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""SinhArcsinh bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "SinhArcsinh", -] - - -def _sqrtx2p1(x): - """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" - return array_ops.where( - math_ops.abs(x) * np.sqrt(np.finfo(x.dtype.as_numpy_dtype).eps) <= 1., - math_ops.sqrt(x**2. + 1.), - # For large x, calculating x**2 can overflow. This can be alleviated by - # considering: - # sqrt(1 + x**2) - # = exp(0.5 log(1 + x**2)) - # = exp(0.5 log(x**2 * (1 + x**-2))) - # = exp(log(x) + 0.5 * log(1 + x**-2)) - # = |x| * exp(0.5 log(1 + x**-2)) - # = |x| * sqrt(1 + x**-2) - # We omit the last term in this approximation. - # When |x| > 1 / sqrt(machineepsilon), the second term will be 1, - # due to sqrt(1 + x**-2) = 1. This is also true with the gradient term, - # and higher order gradients, since the first order derivative of - # sqrt(1 + x**-2) is -2 * x**-3 / (1 + x**-2) = -2 / (x**3 + x), - # and all nth-order derivatives will be O(x**-(n + 2)). This makes any - # gradient terms that contain any derivatives of sqrt(1 + x**-2) vanish. - math_ops.abs(x)) - - -class SinhArcsinh(bijector.Bijector): - """Compute `Y = g(X) = Sinh( (Arcsinh(X) + skewness) * tailweight )`. - - For `skewness in (-inf, inf)` and `tailweight in (0, inf)`, this - transformation is a - diffeomorphism of the real line `(-inf, inf)`. The inverse transform is - `X = g^{-1}(Y) = Sinh( ArcSinh(Y) / tailweight - skewness )`. 
- - The `SinhArcsinh` transformation of the Normal is described in - [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865) - This Bijector allows a similar transformation of any distribution supported on - `(-inf, inf)`. - - #### Meaning of the parameters - - * If `skewness = 0` and `tailweight = 1`, this transform is the identity. - * Positive (negative) `skewness` leads to positive (negative) skew. - * positive skew means, for unimodal `X` centered at zero, the mode of `Y` is - "tilted" to the right. - * positive skew means positive values of `Y` become more likely, and - negative values become less likely. - * Larger (smaller) `tailweight` leads to fatter (thinner) tails. - * Fatter tails mean larger values of `|Y|` become more likely. - * If `X` is a unit Normal, `tailweight < 1` leads to a distribution that is - "flat" around `Y = 0`, and a very steep drop-off in the tails. - * If `X` is a unit Normal, `tailweight > 1` leads to a distribution more - peaked at the mode with heavier tails. - - To see the argument about the tails, note that for `|X| >> 1` and - `|X| >> (|skewness| * tailweight)**tailweight`, we have - `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. - """ - - def __init__(self, - skewness=None, - tailweight=None, - event_ndims=0, - validate_args=False, - name="SinhArcsinh"): - """Instantiates the `SinhArcsinh` bijector. - - Args: - skewness: Skewness parameter. Float-type `Tensor`. Default is `0` - of type `float32`. - tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as - `skewness` and broadcastable `shape`. Default is `1` of type `float32`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. 
- """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[skewness, tailweight]): - tailweight = 1. if tailweight is None else tailweight - skewness = 0. if skewness is None else skewness - self._skewness = ops.convert_to_tensor( - skewness, name="skewness") - self._tailweight = ops.convert_to_tensor( - tailweight, name="tailweight", dtype=self._skewness.dtype) - check_ops.assert_same_float_dtype([self._skewness, self._tailweight]) - if validate_args: - self._tailweight = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._tailweight, - message="Argument tailweight was not positive") - ], self._tailweight) - super(SinhArcsinh, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) - - @property - def skewness(self): - """The `skewness` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" - return self._skewness - - @property - def tailweight(self): - """The `tailweight` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" - return self._tailweight - - def _forward(self, x): - return math_ops.sinh((math_ops.asinh(x) + self.skewness) * self.tailweight) - - def _inverse(self, y): - return math_ops.sinh(math_ops.asinh(y) / self.tailweight - self.skewness) - - def _inverse_log_det_jacobian(self, y): - # x = sinh(arcsinh(y) / tailweight - skewness) - # Using sinh' = cosh, arcsinh'(y) = 1 / sqrt(y**2 + 1), - # dx/dy - # = cosh(arcsinh(y) / tailweight - skewness) - # / (tailweight * sqrt(y**2 + 1)) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). - math_ops.log(math_ops.cosh( - math_ops.asinh(y) / self.tailweight - self.skewness) - # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases - # where (arcsinh(x) / tailweight) - skewness ~= arcsinh(x). 
- / _sqrtx2p1(y)) - - math_ops.log(self.tailweight), - axis=event_dims) - - def _forward_log_det_jacobian(self, x): - # y = sinh((arcsinh(x) + skewness) * tailweight) - # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1), - # dy/dx - # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). - math_ops.log(math_ops.cosh( - (math_ops.asinh(x) + self.skewness) * self.tailweight) - # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases - # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x). - / _sqrtx2p1(x)) - + math_ops.log(self.tailweight), - axis=event_dims) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index be6608f97880ae68e10b17c815bf2d8438293261..e4a1d3dde230724e74d5076c5bba079590b94a70 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -18,12 +18,232 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["SoftmaxCentered"] +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import 
control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "SoftmaxCentered", +] + + +class SoftmaxCentered(bijector.Bijector): + """Bijector which computes `Y = g(X) = exp([X 0]) / sum(exp([X 0]))`. + + To implement [softmax](https://en.wikipedia.org/wiki/Softmax_function) as a + bijection, the forward transformation appends a value to the input and the + inverse removes this coordinate. The appended coordinate represents a pivot, + e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last + coordinate. + + Because we append a coordinate, this bijector only supports `event_ndim in [0, + 1]`, i.e., scalars and vectors. + + Example Use: + + ```python + bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) + # Result: [0.2, 0.3, 0.4, 0.1] + # Extra result: 0.1 + + bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) + # Result: tf.log([2, 3, 4]) + # Extra coordinate removed. + ``` + + At first blush it may seem like the [Invariance of domain]( + https://en.wikipedia.org/wiki/Invariance_of_domain) theorem implies this + implementation is not a bijection. However, the appended dimension + makes the (forward) image non-open and the theorem does not directly apply. 
+ """ + + def __init__(self, + event_ndims=0, + validate_args=False, + name="softmax_centered"): + self._graph_parents = [] + self._name = name + with self._name_scope("init", values=[event_ndims]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims is None or event_ndims not in [0, 1]: + raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") + self._static_event_ndims = event_ndims + super(SoftmaxCentered, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward_event_shape(self, input_shape): + if input_shape.ndims is None: + return input_shape + if input_shape.ndims != self._static_event_ndims: + raise ValueError("input_shape.dims = %d != %d" % + (input_shape.ndims, self._static_event_ndims)) + if input_shape.ndims == 0: + return tensor_shape.TensorShape([2]) + if input_shape.ndims == 1: + return tensor_shape.TensorShape(input_shape[0] + 1) + # Unreachable code: + raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) + + def _forward_event_shape_tensor(self, input_shape): + ndims = array_ops.shape(input_shape) + if self.validate_args: + # It is not possible for a negative shape so we need only check <= 1. 
+ is_zero_or_one = check_ops.assert_equal( + ndims, 0 if self._static_event_ndims == 0 else 1, + message="event_ndims must be 0 or 1") + ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) + if self._static_event_ndims == 0: + return ops.convert_to_tensor( + [2], dtype=dtypes.int32, name="output_shape") + return input_shape + 1 + + def _inverse_event_shape(self, output_shape): + if output_shape.ndims is None: + return output_shape + if output_shape.ndims != 1: + raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) + if self._static_event_ndims == 0: + return tensor_shape.TensorShape([]) + return tensor_shape.TensorShape(output_shape[0] - 1) + + def _inverse_event_shape_tensor(self, output_shape): + ndims = array_ops.shape(output_shape)[0] + if self.validate_args: + # It is not possible for a negative shape so we need only check <= 1. + is_one = check_ops.assert_equal( + ndims, 1, message="event_ndims must be 1") + ndims = control_flow_ops.with_dependencies([is_one], ndims) + if self._static_event_ndims == 0: + return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") + return array_ops.expand_dims(output_shape[0] - 1, dim=0) + + def _forward(self, x): + # Pad the last dim with a zeros vector. We need this because it lets us + # infer the scale in the inverse function. + y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x + ndims = _get_ndims(y) + y = array_ops.pad(y, paddings=array_ops.one_hot(indices=[-1, ndims - 1], + depth=ndims, + axis=0, + dtype=dtypes.int32)) + # Set shape hints. 
+ if x.shape.ndims is not None: + shape = x.shape.as_list() + if self._static_event_ndims == 0: + shape += [2] + elif shape[-1] is not None: + shape[-1] += 1 + shape = tensor_shape.TensorShape(shape) + y.shape.assert_is_compatible_with(shape) + y.set_shape(shape) + + # Since we only support event_ndims in [0, 1] and we do padding, we always + # reduce over the last dimension, i.e., dim=-1 (which is the default). + return nn_ops.softmax(y) + + def _inverse(self, y): + # To derive the inverse mapping note that: + # y[i] = exp(x[i]) / normalization + # and + # y[end] = 1 / normalization. + # Thus: + # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization) + # = log(exp(x[i])/normalization) - log(y[end]) + # = log(y[i]) - log(y[end]) + shape = (np.asarray(y.shape.as_list(), dtype=np.int32) + if y.shape.is_fully_defined() + else array_ops.shape(y, name="shape")) + ndims = _get_ndims(y) + + # Do this first to make sure CSE catches that it'll happen again in + # _inverse_log_det_jacobian. + x = math_ops.log(y) + + # We now extract the last coordinate of the rightmost dimension. + # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1]. + begin = array_ops.one_hot(indices=ndims-1, + depth=ndims, + on_value=shape[-1]-np.array(1, dtype=shape.dtype), + dtype=shape.dtype) + size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0) + log_normalization = -array_ops.strided_slice(x, begin, begin + size) + + # Here we slice out all but the last coordinate; see above for idea. + begin = array_ops.zeros_like(shape) + size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0) + x = array_ops.strided_slice(x, begin, begin + size) + + x += log_normalization + + if self._static_event_ndims == 0: + x = array_ops.squeeze(x, squeeze_dims=[ndims-1]) + + # Set shape hints. 
+ if y.shape.ndims is not None: + shape = y.shape.as_list() + if self._static_event_ndims == 0: + shape = shape[:-1] + elif shape[-1] is not None: + shape[-1] -= 1 + shape = tensor_shape.TensorShape(shape) + x.shape.assert_is_compatible_with(shape) + x.set_shape(shape) + + return x + + def _inverse_log_det_jacobian(self, y): + # WLOG, consider the vector case: + # x = log(y[:-1]) - log(y[-1]) + # where, + # y[-1] = 1 - sum(y[:-1]). + # We have: + # det{ dX/dY } = det{ diag(1 ./ y[:-1]) + 1 / y[-1] } + # = det{ inv{ diag(y[:-1]) - y[:-1]' y[:-1] } } (1) + # = 1 / det{ diag(y[:-1]) - y[:-1]' y[:-1] } + # = 1 / { (1 + y[:-1]' inv(diag(y[:-1])) y[:-1]) * + # det(diag(y[:-1])) } (2) + # = 1 / { y[-1] prod(y[:-1]) } + # = 1 / prod(y) + # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula + # or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector + # docstring "Tip". + # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma + return -math_ops.reduce_sum(math_ops.log(y), axis=-1) + + def _forward_log_det_jacobian(self, x): + if self._static_event_ndims == 0: + return x - 2. * nn_ops.softplus(x) + else: + # This code is similar to nn_ops.log_softmax but different because we have + # an implicit zero column to handle. 
I.e., instead of: + # reduce_sum(logits - reduce_sum(exp(logits), dim)) + # we must do: + # log_normalization = 1 + reduce_sum(exp(logits)) + # -log_normalization + reduce_sum(logits - log_normalization) + log_normalization = nn_ops.softplus( + math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) + fldj = (-log_normalization + + math_ops.reduce_sum(x - log_normalization, + axis=-1, + keep_dims=True)) + return array_ops.squeeze(fldj, squeeze_dims=-1) + + +def _get_ndims(x): + """Returns `ndims`, statically if possible.""" + if x.shape.ndims is not None: + return x.shape.ndims + return array_ops.rank(x, name="ndims") diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py deleted file mode 100644 index 8645cc1b6b04be75a419342591272f07a4a1711c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""SoftmaxCentered bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "SoftmaxCentered", -] - - -class SoftmaxCentered(bijector.Bijector): - """Bijector which computes `Y = g(X) = exp([X 0]) / sum(exp([X 0]))`. - - To implement [softmax](https://en.wikipedia.org/wiki/Softmax_function) as a - bijection, the forward transformation appends a value to the input and the - inverse removes this coordinate. The appended coordinate represents a pivot, - e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last - coordinate. - - Because we append a coordinate, this bijector only supports `event_ndim in [0, - 1]`, i.e., scalars and vectors. - - Example Use: - - ```python - bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) - # Result: [0.2, 0.3, 0.4, 0.1] - # Extra result: 0.1 - - bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) - # Result: tf.log([2, 3, 4]) - # Extra coordinate removed. - ``` - - At first blush it may seem like the [Invariance of domain]( - https://en.wikipedia.org/wiki/Invariance_of_domain) theorem implies this - implementation is not a bijection. However, the appended dimension - makes the (forward) image non-open and the theorem does not directly apply. 
- """ - - def __init__(self, - event_ndims=0, - validate_args=False, - name="softmax_centered"): - self._graph_parents = [] - self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 1]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") - self._static_event_ndims = event_ndims - super(SoftmaxCentered, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward_event_shape(self, input_shape): - if input_shape.ndims is None: - return input_shape - if input_shape.ndims != self._static_event_ndims: - raise ValueError("input_shape.dims = %d != %d" % - (input_shape.ndims, self._static_event_ndims)) - if input_shape.ndims == 0: - return tensor_shape.TensorShape([2]) - if input_shape.ndims == 1: - return tensor_shape.TensorShape(input_shape[0] + 1) - # Unreachable code: - raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) - - def _forward_event_shape_tensor(self, input_shape): - ndims = array_ops.shape(input_shape) - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. 
- is_zero_or_one = check_ops.assert_equal( - ndims, 0 if self._static_event_ndims == 0 else 1, - message="event_ndims must be 0 or 1") - ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor( - [2], dtype=dtypes.int32, name="output_shape") - return input_shape + 1 - - def _inverse_event_shape(self, output_shape): - if output_shape.ndims is None: - return output_shape - if output_shape.ndims != 1: - raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) - if self._static_event_ndims == 0: - return tensor_shape.TensorShape([]) - return tensor_shape.TensorShape(output_shape[0] - 1) - - def _inverse_event_shape_tensor(self, output_shape): - ndims = array_ops.shape(output_shape)[0] - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. - is_one = check_ops.assert_equal( - ndims, 1, message="event_ndims must be 1") - ndims = control_flow_ops.with_dependencies([is_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") - return array_ops.expand_dims(output_shape[0] - 1, dim=0) - - def _forward(self, x): - # Pad the last dim with a zeros vector. We need this because it lets us - # infer the scale in the inverse function. - y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x - ndims = (y.get_shape().ndims if y.get_shape().ndims is not None - else array_ops.rank(y)) - y = array_ops.pad(y, - paddings=array_ops.concat( - (array_ops.zeros( - (ndims - 1, 2), dtype=dtypes.int32), [[0, 1]]), - 0)) - - # Set shape hints. 
- if x.get_shape().ndims is not None: - shape = x.get_shape().as_list() - if self._static_event_ndims == 0: - shape += [2] - elif shape[-1] is not None: - shape[-1] += 1 - shape = tensor_shape.TensorShape(shape) - y.get_shape().assert_is_compatible_with(shape) - y.set_shape(shape) - - # Since we only support event_ndims in [0, 1] and we do padding, we always - # reduce over the last dimension, i.e., dim=-1 (which is the default). - return nn_ops.softmax(y) - - def _inverse(self, y): - # To derive the inverse mapping note that: - # y[i] = exp(x[i]) / normalization - # and - # y[end] = 1 / normalization. - # Thus: - # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization) - # = log(exp(x[i])/normalization) - log(y[end]) - # = log(y[i]) - log(y[end]) - shape = (np.asarray(y.get_shape().as_list(), dtype=np.int32) - if y.get_shape().is_fully_defined() - else array_ops.shape(y, name="shape")) - ndims = y.get_shape().ndims or math_ops.rank(y, name="ndims") - - # Do this first to make sure CSE catches that it'll happen again in - # _inverse_log_det_jacobian. - x = math_ops.log(y) - - # We now extract the last coordinate of the rightmost dimension. - # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1]. - begin = array_ops.one_hot(indices=ndims-1, - depth=ndims, - on_value=shape[-1]-np.array(1, dtype=shape.dtype), - dtype=shape.dtype) - size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0) - log_normalization = -array_ops.strided_slice(x, begin, begin + size) - - # Here we slice out all but the last coordinate; see above for idea. - begin = array_ops.zeros_like(shape) - size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0) - x = array_ops.strided_slice(x, begin, begin + size) - - x += log_normalization - - if self._static_event_ndims == 0: - x = array_ops.squeeze(x, squeeze_dims=[ndims-1]) - - # Set shape hints. 
- if y.get_shape().ndims is not None: - shape = y.get_shape().as_list() - if self._static_event_ndims == 0: - shape = shape[:-1] - elif shape[-1] is not None: - shape[-1] -= 1 - shape = tensor_shape.TensorShape(shape) - x.get_shape().assert_is_compatible_with(shape) - x.set_shape(shape) - - return x - - def _inverse_log_det_jacobian(self, y): - # WLOG, consider the vector case: - # x = log(y[:-1]) - log(y[-1]) - # where, - # y[-1] = 1 - sum(y[:-1]). - # We have: - # det{ dX/dY } = det{ diag(1 ./ y[:-1]) + 1 / y[-1] } - # = det{ inv{ diag(y[:-1]) - y[:-1]' y[:-1] } } (1) - # = 1 / det{ diag(y[:-1]) - y[:-1]' y[:-1] } - # = 1 / { (1 + y[:-1]' inv(diag(y[:-1])) y[:-1]) * - # det(diag(y[:-1])) } (2) - # = 1 / { y[-1] prod(y[:-1]) } - # = 1 / prod(y) - # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula - # or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector - # docstring "Tip". - # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma - return -math_ops.reduce_sum(math_ops.log(y), axis=-1) - - def _forward_log_det_jacobian(self, x): - if self._static_event_ndims == 0: - return x - 2. * nn_ops.softplus(x) - else: - # This code is similar to nn_ops.log_softmax but different because we have - # an implicit zero column to handle. 
I.e., instead of: - # reduce_sum(logits - reduce_sum(exp(logits), dim)) - # we must do: - # log_normalization = 1 + reduce_sum(exp(logits)) - # -log_normalization + reduce_sum(logits - log_normalization) - log_normalization = nn_ops.softplus( - math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) - fldj = (-log_normalization + - math_ops.reduce_sum(x - log_normalization, - axis=-1, - keep_dims=True)) - return array_ops.squeeze(fldj, squeeze_dims=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py index 250a1144b53bb43271ff7ee494604d9bae6feda8..81957fcf78922fa15fd20a25d144071f431161ae 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py @@ -18,12 +18,127 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.softplus_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -_allowed_symbols = ["Softplus"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Softplus", +] + + +class Softplus(bijector.Bijector): + """Bijector which computes `Y = g(X) = Log[1 + exp(X)]`. 
+ + The softplus `Bijector` has the following two useful properties: + + * The domain is the positive real numbers + * `softplus(x) approx x`, for large `x`, so it does not overflow as easily as + the `Exp` `Bijector`. + + The optional nonzero `hinge_softness` parameter changes the transition at + zero. With `hinge_softness = c`, the bijector is: + + ```f_c(x) := c * g(x / c) = c * Log[1 + exp(x / c)].``` + + For large `x >> 1`, `c * Log[1 + exp(x / c)] approx c * Log[exp(x / c)] = x`, + so the behavior for large `x` is the same as the standard softplus. + + As `c > 0` approaches 0 from the right, `f_c(x)` becomes less and less soft, + approaching `max(0, x)`. + + * `c = 1` is the default. + * `c > 0` but small means `f(x) approx ReLu(x) = max(0, x)`. + * `c < 0` flips sign and reflects around the `y-axis`: `f_{-c}(x) = -f_c(-x)`. + * `c = 0` results in a non-bijective transformation and triggers an exception. + + Example Use: + + ```python + # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 + # batch ndim and 2 event ndims (i.e., vector of matrices). + softplus = Softplus(event_ndims=2) + x = [[[1., 2], + [3, 4]], + [[5, 6], + [7, 8]]] + log(1 + exp(x)) == softplus.forward(x) + log(exp(x) - 1) == softplus.inverse(x) + ``` + + Note: log(.) and exp(.) are applied element-wise but the Jacobian is a + reduction over the event space. + """ + + @distribution_util.AppendDocstring( + kwargs_dict={ + "hinge_softness": ( + "Nonzero floating point `Tensor`. Controls the softness of what " + "would otherwise be a kink at the origin. 
Default is 1.0")}) + def __init__(self, + event_ndims=0, + hinge_softness=None, + validate_args=False, + name="softplus"): + with ops.name_scope(name, values=[hinge_softness]): + if hinge_softness is not None: + self._hinge_softness = ops.convert_to_tensor( + hinge_softness, name="hinge_softness") + else: + self._hinge_softness = None + if validate_args: + nonzero_check = check_ops.assert_none_equal( + ops.convert_to_tensor( + 0, dtype=self.hinge_softness.dtype), + self.hinge_softness, + message="hinge_softness must be non-zero") + self._hinge_softness = control_flow_ops.with_dependencies( + [nonzero_check], self.hinge_softness) + + super(Softplus, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + if self.hinge_softness is None: + return nn_ops.softplus(x) + hinge_softness = math_ops.cast(self.hinge_softness, x.dtype) + return hinge_softness * nn_ops.softplus(x / hinge_softness) + + def _inverse(self, y): + if self.hinge_softness is None: + return distribution_util.softplus_inverse(y) + hinge_softness = math_ops.cast(self.hinge_softness, y.dtype) + return hinge_softness * distribution_util.softplus_inverse( + y / hinge_softness) + + def _inverse_log_det_jacobian(self, y): + # Could also do: + # ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y), + # axis=event_dims) + # but the following is more numerically stable. Ie, + # Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1] + # ==> dX/dY = exp{Y} / (exp{Y} - 1) + # = 1 / (1 - exp{-Y}), + # which is the most stable for large Y > 0. For small Y, we use + # 1 - exp{-Y} approx Y. 
+ if self.hinge_softness is not None: + y /= math_ops.cast(self.hinge_softness, y.dtype) + return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), + axis=self._event_dims_tensor(y)) + + def _forward_log_det_jacobian(self, x): + if self.hinge_softness is not None: + x /= math_ops.cast(self.hinge_softness, x.dtype) + return -math_ops.reduce_sum(nn_ops.softplus(-x), + axis=self._event_dims_tensor(x)) + + @property + def hinge_softness(self): + return self._hinge_softness diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py deleted file mode 100644 index 81957fcf78922fa15fd20a25d144071f431161ae..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Softplus bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = [ - "Softplus", -] - - -class Softplus(bijector.Bijector): - """Bijector which computes `Y = g(X) = Log[1 + exp(X)]`. - - The softplus `Bijector` has the following two useful properties: - - * The domain is the positive real numbers - * `softplus(x) approx x`, for large `x`, so it does not overflow as easily as - the `Exp` `Bijector`. - - The optional nonzero `hinge_softness` parameter changes the transition at - zero. With `hinge_softness = c`, the bijector is: - - ```f_c(x) := c * g(x / c) = c * Log[1 + exp(x / c)].``` - - For large `x >> 1`, `c * Log[1 + exp(x / c)] approx c * Log[exp(x / c)] = x`, - so the behavior for large `x` is the same as the standard softplus. - - As `c > 0` approaches 0 from the right, `f_c(x)` becomes less and less soft, - approaching `max(0, x)`. - - * `c = 1` is the default. - * `c > 0` but small means `f(x) approx ReLu(x) = max(0, x)`. - * `c < 0` flips sign and reflects around the `y-axis`: `f_{-c}(x) = -f_c(-x)`. - * `c = 0` results in a non-bijective transformation and triggers an exception. - - Example Use: - - ```python - # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 - # batch ndim and 2 event ndims (i.e., vector of matrices). 
- softplus = Softplus(event_ndims=2) - x = [[[1., 2], - [3, 4]], - [[5, 6], - [7, 8]]] - log(1 + exp(x)) == softplus.forward(x) - log(exp(x) - 1) == softplus.inverse(x) - ``` - - Note: log(.) and exp(.) are applied element-wise but the Jacobian is a - reduction over the event space. - """ - - @distribution_util.AppendDocstring( - kwargs_dict={ - "hinge_softness": ( - "Nonzero floating point `Tensor`. Controls the softness of what " - "would otherwise be a kink at the origin. Default is 1.0")}) - def __init__(self, - event_ndims=0, - hinge_softness=None, - validate_args=False, - name="softplus"): - with ops.name_scope(name, values=[hinge_softness]): - if hinge_softness is not None: - self._hinge_softness = ops.convert_to_tensor( - hinge_softness, name="hinge_softness") - else: - self._hinge_softness = None - if validate_args: - nonzero_check = check_ops.assert_none_equal( - ops.convert_to_tensor( - 0, dtype=self.hinge_softness.dtype), - self.hinge_softness, - message="hinge_softness must be non-zero") - self._hinge_softness = control_flow_ops.with_dependencies( - [nonzero_check], self.hinge_softness) - - super(Softplus, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - if self.hinge_softness is None: - return nn_ops.softplus(x) - hinge_softness = math_ops.cast(self.hinge_softness, x.dtype) - return hinge_softness * nn_ops.softplus(x / hinge_softness) - - def _inverse(self, y): - if self.hinge_softness is None: - return distribution_util.softplus_inverse(y) - hinge_softness = math_ops.cast(self.hinge_softness, y.dtype) - return hinge_softness * distribution_util.softplus_inverse( - y / hinge_softness) - - def _inverse_log_det_jacobian(self, y): - # Could also do: - # ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y), - # axis=event_dims) - # but the following is more numerically stable. 
Ie, - # Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1] - # ==> dX/dY = exp{Y} / (exp{Y} - 1) - # = 1 / (1 - exp{-Y}), - # which is the most stable for large Y > 0. For small Y, we use - # 1 - exp{-Y} approx Y. - if self.hinge_softness is not None: - y /= math_ops.cast(self.hinge_softness, y.dtype) - return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), - axis=self._event_dims_tensor(y)) - - def _forward_log_det_jacobian(self, x): - if self.hinge_softness is not None: - x /= math_ops.cast(self.hinge_softness, x.dtype) - return -math_ops.reduce_sum(nn_ops.softplus(-x), - axis=self._event_dims_tensor(x)) - - @property - def hinge_softness(self): - return self._hinge_softness diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py index d439f28884d8bd7f2b808317e10c5b5e44bfcfa2..00520bcda85e9527767e6342bf75f10667c264a8 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py @@ -18,12 +18,132 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.weibull_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Weibull"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Weibull", +] + + +class Weibull(bijector.Bijector): + """Compute `Y = g(X) = 1 - exp((-X / scale) ** concentration), X >= 0`. 
+ + This bijector maps inputs from `[0, inf]` to `[0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1)` gives back a + random variable with the + [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution): + + ```none + Y ~ Weibull(scale, concentration) + pdf(y; scale, concentration, y >= 0) = (concentration / scale) * ( + y / scale) ** (concentration - 1) * exp( + -(y / scale) ** concentration) + ``` + """ + + def __init__(self, + scale=1., + concentration=1., + event_ndims=0, + validate_args=False, + name="weibull"): + """Instantiates the `Weibull` bijector. + + Args: + scale: Positive Float-type `Tensor` that is the same dtype and is + broadcastable with `concentration`. + This is `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`. + concentration: Positive Float-type `Tensor` that is the same dtype and is + broadcastable with `scale`. + This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. 
+ """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[scale, concentration]): + self._scale = ops.convert_to_tensor(scale, name="scale") + self._concentration = ops.convert_to_tensor( + concentration, name="concentration") + check_ops.assert_same_float_dtype([self._scale, self._concentration]) + if validate_args: + self._scale = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._scale, + message="Argument scale was not positive") + ], self._scale) + self._concentration = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._concentration, + message="Argument concentration was not positive") + ], self._concentration) + + super(Weibull, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + @property + def scale(self): + """The `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" + return self._scale + + @property + def concentration(self): + """The `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" + return self._concentration + + def _forward(self, x): + x = self._maybe_assert_valid_x(x) + return -math_ops.expm1(-((x / self.scale) ** self.concentration)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return self.scale * (-math_ops.log1p(-y)) ** (1 / self.concentration) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + -math_ops.log1p(-y) + + (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) + + math_ops.log(self.scale / self.concentration), + axis=event_dims) + + def _forward_log_det_jacobian(self, x): + x = self._maybe_assert_valid_x(x) + event_dims = self._event_dims_tensor(x) + return math_ops.reduce_sum( + -(x / self.scale) ** self.concentration + + (self.concentration - 1) * math_ops.log(x) + + math_ops.log(self.concentration) + + -self.concentration * math_ops.log(self.scale), + 
axis=event_dims) + + def _maybe_assert_valid_x(self, x): + if not self.validate_args: + return x + is_valid = check_ops.assert_non_negative( + x, + message="Forward transformation input must be at least {}.".format(0)) + return control_flow_ops.with_dependencies([is_valid], x) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_positive = check_ops.assert_non_negative( + y, message="Inverse transformation input must be greater than 0.") + less_than_one = check_ops.assert_less_equal( + y, constant_op.constant(1., y.dtype), + message="Inverse transformation input must be less than or equal to 1.") + return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py deleted file mode 100644 index 00520bcda85e9527767e6342bf75f10667c264a8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Weibull bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Weibull", -] - - -class Weibull(bijector.Bijector): - """Compute `Y = g(X) = 1 - exp((-X / scale) ** concentration), X >= 0`. - - This bijector maps inputs from `[0, inf]` to [0, 1]`. The inverse of the - bijector applied to a uniform random variable `X ~ U(0, 1) gives back a - random variable with the - [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution): - - ```none - Y ~ Weibull(scale, concentration) - pdf(y; scale, concentration, y >= 0) = (scale / concentration) * ( - scale / concentration) ** (concentration - 1) * exp( - -(y / scale) ** concentration) - ``` - """ - - def __init__(self, - scale=1., - concentration=1., - event_ndims=0, - validate_args=False, - name="weibull"): - """Instantiates the `Weibull` bijector. - - Args: - scale: Positive Float-type `Tensor` that is the same dtype and is - broadcastable with `concentration`. - This is `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`. - concentration: Positive Float-type `Tensor` that is the same dtype and is - broadcastable with `scale`. - This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. 
- """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[scale, concentration]): - self._scale = ops.convert_to_tensor(scale, name="scale") - self._concentration = ops.convert_to_tensor( - concentration, name="concentration") - check_ops.assert_same_float_dtype([self._scale, self._concentration]) - if validate_args: - self._scale = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._scale, - message="Argument scale was not positive") - ], self._scale) - self._concentration = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._concentration, - message="Argument concentration was not positive") - ], self._concentration) - - super(Weibull, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - @property - def scale(self): - """The `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" - return self._scale - - @property - def concentration(self): - """The `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" - return self._concentration - - def _forward(self, x): - x = self._maybe_assert_valid_x(x) - return -math_ops.expm1(-((x / self.scale) ** self.concentration)) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - return self.scale * (-math_ops.log1p(-y)) ** (1 / self.concentration) - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - -math_ops.log1p(-y) + - (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) + - math_ops.log(self.scale / self.concentration), - axis=event_dims) - - def _forward_log_det_jacobian(self, x): - x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( - -(x / self.scale) ** self.concentration + - (self.concentration - 1) * math_ops.log(x) + - math_ops.log(self.concentration) + - -self.concentration * math_ops.log(self.scale), - 
axis=event_dims) - - def _maybe_assert_valid_x(self, x): - if not self.validate_args: - return x - is_valid = check_ops.assert_non_negative( - x, - message="Forward transformation input must be at least {}.".format(0)) - return control_flow_ops.with_dependencies([is_valid], x) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_positive = check_ops.assert_non_negative( - y, message="Inverse transformation input must be greater than 0.") - less_than_one = check_ops.assert_less_equal( - y, constant_op.constant(1., y.dtype), - message="Inverse transformation input must be less than or equal to 1.") - return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index a17bb091f69b651d21f70a25c5aab61b203e62de..6f5d724a2a945ed8f9c159d8314327c6f994d1db 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -30,7 +30,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution - __all__ = [ "Cauchy", ] @@ -44,16 +43,17 @@ class Cauchy(distribution.Distribution): The probability density function (pdf) is, ```none - pdf(x; loc, scale) = 1 / (pi * scale * (1 + ((x - loc) / scale)**2)) + pdf(x; loc, scale) = 1 / (pi scale (1 + z**2)) + z = (x - loc) / scale ``` where `loc` is the location, and `scale` is the scale. The Cauchy distribution is a member of the [location-scale family]( https://en.wikipedia.org/wiki/Location-scale_family), i.e. + `Y ~ Cauchy(loc, scale)` is equivalent to, ```none X ~ Cauchy(loc=0, scale=1) - Y ~ Cauchy(loc=loc, scale=scale) Y = loc + scale * X ``` @@ -62,14 +62,16 @@ class Cauchy(distribution.Distribution): Examples of initialization of one or a batch of distributions. 
```python + tfd = tf.contrib.distributions + # Define a single scalar Cauchy distribution. - dist = Cauchy(loc=0., scale=3.) + dist = tfd.Cauchy(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Cauchy distributions. - dist = Cauchy(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Cauchy(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -77,18 +79,17 @@ class Cauchy(distribution.Distribution): # Get 3 samples, returning a 3 x 2 tensor. dist.sample([3]) - ``` - - Arguments are broadcast when possible. - ```python + # Arguments are broadcast when possible. # Define a batch of two scalar valued Cauchy distributions. # Both have median 1, but different scales. - dist = tf.contrib.distributions.Cauchy(loc=1., scale=[11, 22.]) + dist = tfd.Cauchy(loc=1., scale=[11, 22.]) + # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. - dist.prob(3.0) + dist.prob(3.) ``` + """ def __init__(self, @@ -97,7 +98,7 @@ class Cauchy(distribution.Distribution): validate_args=False, allow_nan_stats=True, name="Cauchy"): - """Construct Cauchy distributions with loc and and scale `loc` and `scale`. + """Construct Cauchy distributions. The parameters `loc` and `scale` must be shaped in a way that supports broadcasting (e.g. `loc + scale` is a valid operation). 
@@ -121,8 +122,8 @@ class Cauchy(distribution.Distribution): """ parameters = locals() with ops.name_scope(name, values=[loc, scale]): - with ops.control_dependencies([check_ops.assert_positive(scale)] if - validate_args else []): + with ops.control_dependencies([check_ops.assert_positive(scale)] + if validate_args else []): self._loc = array_ops.identity(loc, name="loc") self._scale = array_ops.identity(scale, name="scale") check_ops.assert_same_float_dtype([self._loc, self._scale]) @@ -138,8 +139,8 @@ class Cauchy(distribution.Distribution): @staticmethod def _param_shapes(sample_shape): return dict( - zip(("loc", "scale"), ([ops.convert_to_tensor( - sample_shape, dtype=dtypes.int32)] * 2))) + zip(("loc", "scale"), + ([ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))) @property def loc(self): @@ -153,13 +154,10 @@ class Cauchy(distribution.Distribution): def _batch_shape_tensor(self): return array_ops.broadcast_dynamic_shape( - array_ops.shape(self.loc), - array_ops.shape(self.scale)) + array_ops.shape(self.loc), array_ops.shape(self.scale)) def _batch_shape(self): - return array_ops.broadcast_static_shape( - self.loc.shape, - self.scale.shape) + return array_ops.broadcast_static_shape(self.loc.shape, self.scale.shape) def _event_shape_tensor(self): return constant_op.constant([], dtype=dtypes.int32) diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 599c855cda434d9249187d5d154d50a8a8c49a6c..1d4c5660d8d73b7b6a7e758fc834ccfddeb5c8ea 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -121,7 +121,7 @@ class ConditionalTransformedDistribution( log_prob = self.distribution.log_prob(x, **distribution_kwargs) if self._is_maybe_event_override: log_prob = math_ops.reduce_sum(log_prob, 
self._reduce_event_indices) - return ildj + log_prob + return math_ops.cast(ildj, log_prob.dtype) + log_prob @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) def _prob(self, y, bijector_kwargs=None, distribution_kwargs=None): @@ -143,7 +143,7 @@ class ConditionalTransformedDistribution( prob = self.distribution.prob(x, **distribution_kwargs) if self._is_maybe_event_override: prob = math_ops.reduce_prod(prob, self._reduce_event_indices) - return math_ops.exp(ildj) * prob + return math_ops.exp(math_ops.cast(ildj, prob.dtype)) * prob @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) def _log_cdf(self, y, bijector_kwargs=None, distribution_kwargs=None): diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index 850d08d1bd69ebc7661557d648e2bffe77e6a908..8049522e9f5dc26b244b7e710a9ae8b981efd6b6 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -290,8 +290,10 @@ class VectorDeterministic(_BaseDeterministic): #### Examples ```python + tfd = tf.contrib.distributions + # Initialize a single VectorDeterministic supported at [0., 2.] in R^2. - constant = tf.contrib.distributions.Deterministic([0., 2.]) + constant = tfd.Deterministic([0., 2.]) constant.prob([0., 2.]) ==> 1. constant.prob([0., 3.]) @@ -299,7 +301,7 @@ class VectorDeterministic(_BaseDeterministic): # Initialize a [3] batch of constants on R^2. loc = [[0., 1.], [2., 3.], [4., 5.]] - constant = constant_lib.VectorDeterministic(loc) + constant = tfd.VectorDeterministic(loc) constant.prob([[0., 1.], [1.9, 3.], [3.99, 5.]]) ==> [1., 0., 0.] 
``` diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 869b5698e57d199755ce1686a74a1eafe3b73e7d..a4d249d41ec9733721a3583d3708e0da56db1733 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -19,9 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import linalg -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -330,54 +328,14 @@ def shapes_from_loc_and_scale(loc, scale, name="shapes_from_loc_and_scale"): else: loc_batch_shape = ops.convert_to_tensor(loc_batch_shape, name="loc_batch_shape") + # This is defined in the core util module. + # pylint: disable=undefined-variable batch_shape = prefer_static_broadcast_shape(batch_shape, loc_batch_shape) + # pylint: enable=undefined-variable return batch_shape, event_shape -def prefer_static_broadcast_shape( - shape1, shape2, name="prefer_static_broadcast_shape"): - """Convenience function which statically broadcasts shape when possible. - - Args: - shape1: `1-D` integer `Tensor`. Already converted to tensor! - shape2: `1-D` integer `Tensor`. Already converted to tensor! - name: A string name to prepend to created ops. - - Returns: - The broadcast shape, either as `TensorShape` (if broadcast can be done - statically), or as a `Tensor`. 
- """ - with ops.name_scope(name, values=[shape1, shape2]): - def make_shape_tensor(x): - return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32) - - def get_tensor_shape(s): - if isinstance(s, tensor_shape.TensorShape): - return s - s_ = tensor_util.constant_value(make_shape_tensor(s)) - if s_ is not None: - return tensor_shape.TensorShape(s_) - return None - - def get_shape_tensor(s): - if not isinstance(s, tensor_shape.TensorShape): - return make_shape_tensor(s) - if s.is_fully_defined(): - return make_shape_tensor(s.as_list()) - raise ValueError("Cannot broadcast from partially " - "defined `TensorShape`.") - - shape1_ = get_tensor_shape(shape1) - shape2_ = get_tensor_shape(shape2) - if shape1_ is not None and shape2_ is not None: - return array_ops.broadcast_static_shape(shape1_, shape2_) - - shape1_ = get_shape_tensor(shape1) - shape2_ = get_shape_tensor(shape2) - return array_ops.broadcast_dynamic_shape(shape1_, shape2_) - - def get_broadcast_shape(*tensors): """Get broadcast shape as a Python list of integers (preferred) or `Tensor`. diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index ba8d3c639b397422f0f6210ba9f48650f0da1e3e..d0efaefb8e78ddf4436e9e5a112d2c1cdddaf3b5 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -62,15 +62,17 @@ class _Gumbel(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Gumbel distribution. - dist = tf.contrib.distributions.Gumbel(loc=0., scale=3.) + dist = tfd.Gumbel(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Gumbels. # The first has mean 1 and scale 11, the second 2 and 22. 
- dist = tf.contrib.distributions.Gumbel(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Gumbel(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -85,7 +87,7 @@ class _Gumbel(distribution.Distribution): ```python # Define a batch of two scalar valued Logistics. # Both have mean 1, but different scales. - dist = tf.contrib.distributions.Gumbel(loc=1., scale=[11, 22.]) + dist = tfd.Gumbel(loc=1., scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0751a6e0b78cb3d79bd3478e740bb05cd26428 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -0,0 +1,171 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""The Half Normal distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import special_math + + +__all__ = [ + "HalfNormal", +] + + +class HalfNormal(distribution.Distribution): + """The Half Normal distribution with scale `scale`. + + #### Mathematical details + + The half normal is a transformation of a centered normal distribution. + If some random variable `X` has normal distribution, + ```none + X ~ Normal(0.0, scale) + Y = |X| + ``` + Then `Y` will have half normal distribution. The probability density + function (pdf) is: + + ```none + pdf(x; scale, x > 0) = sqrt(2) / (scale * sqrt(pi)) * + exp(- 1/2 * (x / scale) ** 2) + ``` + + Where `scale = sigma` is the standard deviation of the underlying normal + distribution. + + #### Examples + + Examples of initialization of one or a batch of distributions. + + ```python + # Define a single scalar HalfNormal distribution. + dist = tf.contrib.distributions.HalfNormal(scale=3.0) + + # Evaluate the cdf at 1, returning a scalar. + dist.cdf(1.) + + # Define a batch of two scalar valued HalfNormals. 
+ # The first has scale 11.0, the second 22.0 + dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0]) + + # Evaluate the pdf of the first distribution on 1.0, and the second on 1.5, + # returning a length two tensor. + dist.prob([1.0, 1.5]) + + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample([3]) + ``` + + """ + + def __init__(self, + scale, + validate_args=False, + allow_nan_stats=True, + name="HalfNormal"): + """Construct HalfNormals with scale `scale`. + + Args: + scale: Floating point tensor; the scales of the distribution(s). + Must contain only positive values. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. 
+ """ + parameters = locals() + with ops.name_scope(name, values=[scale]): + with ops.control_dependencies([check_ops.assert_positive(scale)] if + validate_args else []): + self._scale = array_ops.identity(scale, name="scale") + super(HalfNormal, self).__init__( + dtype=self._scale.dtype, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._scale], + name=name) + + @staticmethod + def _param_shapes(sample_shape): + return {"scale": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)} + + @property + def scale(self): + """Distribution parameter for the scale.""" + return self._scale + + def _batch_shape_tensor(self): + return array_ops.shape(self.scale) + + def _batch_shape(self): + return self.scale.shape + + def _event_shape_tensor(self): + return constant_op.constant([], dtype=dtypes.int32) + + def _event_shape(self): + return tensor_shape.scalar() + + def _sample_n(self, n, seed=None): + shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) + sampled = random_ops.random_normal( + shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed) + return math_ops.abs(sampled * self.scale) + + def _prob(self, x): + coeff = np.sqrt(2) / self.scale / np.sqrt(np.pi) + pdf = coeff * math_ops.exp(- 0.5 * (x / self.scale) ** 2) + return pdf * math_ops.cast(x >= 0, self.dtype) + + def _cdf(self, x): + truncated_x = nn.relu(x) + return math_ops.erf(truncated_x / self.scale / np.sqrt(2.0)) + + def _entropy(self): + return 0.5 * math_ops.log(np.pi * self.scale ** 2.0 / 2.0) + 0.5 + + def _mean(self): + return self.scale * np.sqrt(2.0) / np.sqrt(np.pi) + + def _quantile(self, p): + return np.sqrt(2.0) * self.scale * special_math.erfinv(p) + + def _mode(self): + return array_ops.zeros(self.batch_shape_tensor()) + + def _variance(self): + return self.scale ** 2.0 * (1.0 - 2.0 / np.pi) diff --git 
a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index 6a74ca9a0ae1ad30081d21cc15a65be052a99e2a..cbce005013281ff3c58c94d525d5ce7a865d725a 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -68,11 +68,11 @@ class Independent(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Make independent distribution from a 2-batch Normal. - ind = ds.Independent( - distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]), + ind = tfd.Independent( + distribution=tfd.Normal(loc=[-1., 1], scale=[0.1, 0.5]), reinterpreted_batch_ndims=1) # All batch dims have been "absorbed" into event dims. @@ -80,8 +80,8 @@ class Independent(distribution_lib.Distribution): ind.event_shape # ==> [2] # Make independent distribution from a 2-batch bivariate Normal. - ind = ds.Independent( - distribution=ds.MultivariateNormalDiag( + ind = tfd.Independent( + distribution=tfd.MultivariateNormalDiag( loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]), reinterpreted_batch_ndims=1) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 956dee38a378813434656a28a69c89b6ec1e8b72..ee4d86867d48b20e97757bcec57d452085814b80 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -88,8 +88,9 @@ class InverseGamma(distribution.Distribution): #### Examples ```python - dist = InverseGamma(concentration=3.0, rate=2.0) - dist2 = InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) + tfd = tf.contrib.distributions + dist = tfd.InverseGamma(concentration=3.0, rate=2.0) + dist2 = tfd.InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) ``` """ diff --git 
a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 48794a48828fe796e233e968d8c755136ce166ad..473677f8d91b184e029f345bb05f5c5d63df7a40 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -60,15 +60,17 @@ class Logistic(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Logistic distribution. - dist = tf.contrib.distributions.Logistic(loc=0., scale=3.) + dist = tfd.Logistic(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Logistics. # The first has mean 1 and scale 11, the second 2 and 22. - dist = tf.contrib.distributions.Logistic(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Logistic(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -76,14 +78,11 @@ class Logistic(distribution.Distribution): # Get 3 samples, returning a 3 x 2 tensor. dist.sample([3]) - ``` - Arguments are broadcast when possible. - - ```python + # Arguments are broadcast when possible. # Define a batch of two scalar valued Logistics. # Both have mean 1, but different scales. - dist = tf.contrib.distributions.Logistic(loc=1., scale=[11, 22.]) + dist = tfd.Logistic(loc=1., scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. 
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index e676931d9145e72907d990148ee2d180e0da0258..f2d492f5489a197157558ae727416b51db04793e 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -49,13 +49,13 @@ class Mixture(distribution.Distribution): ```python # Create a mixture of two Gaussians: - ds = tf.contrib.distributions + tfd = tf.contrib.distributions mix = 0.3 - bimix_gauss = ds.Mixture( - cat=ds.Categorical(probs=[mix, 1.-mix]), + bimix_gauss = tfd.Mixture( + cat=tfd.Categorical(probs=[mix, 1.-mix]), components=[ - ds.Normal(loc=-1., scale=0.1), - ds.Normal(loc=+1., scale=0.5), + tfd.Normal(loc=-1., scale=0.1), + tfd.Normal(loc=+1., scale=0.5), ]) # Plot the PDF. diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 5558ef0f255db684b229d129666634e50c625887..0ca236c3761f9d3a0fcc79ff9db792319108db0d 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -43,15 +43,14 @@ class MixtureSameFamily(distribution.Distribution): #### Examples ```python - import matplotlib.pyplot as plt - ds = tf.contrib.distributions + tfd = tf.contrib.distributions ### Create a mixture of two scalar Gaussians: - gm = ds.MixtureSameFamily( - mixture_distribution=ds.Categorical( + gm = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical( probs=[0.3, 0.7]), - components_distribution=ds.Normal( + components_distribution=tfd.Normal( loc=[-1., 1], # One for each component. scale=[0.1, 0.5])) # And same here. @@ -63,14 +62,15 @@ class MixtureSameFamily(distribution.Distribution): # Plot PDF. 
x = np.linspace(-2., 3., int(1e4), dtype=np.float32) + import matplotlib.pyplot as plt plt.plot(x, gm.prob(x).eval()); ### Create a mixture of two Bivariate Gaussians: - gm = ds.MixtureSameFamily( - mixture_distribution=ds.Categorical( + gm = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical( probs=[0.3, 0.7]), - components_distribution=ds.MultivariateNormalDiag( + components_distribution=tfd.MultivariateNormalDiag( loc=[[-1., 1], # component 1 [1, -1]], # component 2 scale_identity_multiplier=[.3, .6])) @@ -320,13 +320,14 @@ class MixtureSameFamily(distribution.Distribution): return array_ops.shape(d.batch_shape_tensor())[0] dist_batch_ndims = _get_ndims(self) cat_batch_ndims = _get_ndims(self.mixture_distribution) - bnd = distribution_util.pick_vector( + pad_ndims = array_ops.where( self.mixture_distribution.is_scalar_batch(), - [dist_batch_ndims], [cat_batch_ndims])[0] + dist_batch_ndims, + dist_batch_ndims - cat_batch_ndims) s = array_ops.shape(x) x = array_ops.reshape(x, shape=array_ops.concat([ s[:-1], - array_ops.ones([bnd], dtype=dtypes.int32), + array_ops.ones([pad_ndims], dtype=dtypes.int32), s[-1:], array_ops.ones([self._event_ndims], dtype=dtypes.int32), ], axis=0)) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index 163cf75d990d5fe7ec1e3aaf0040fc71f61774a7..e862552880f4073c8fa8e90134d0633e7484b0bf 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -84,10 +84,10 @@ class MultivariateNormalDiag( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 2-variate Gaussian. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[1., -1], scale_diag=[1, 2.]) @@ -101,7 +101,7 @@ class MultivariateNormalDiag( mvn.prob([-1., 0]).eval() # shape: [] # Initialize a 3-batch, 2-variate scaled-identity Gaussian. 
- mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[1., -1], scale_identity_multiplier=[1, 2., 3]) @@ -119,7 +119,7 @@ class MultivariateNormalDiag( mvn.prob([-1., 0]).eval() # shape: [3] # Initialize a 2-batch of 3-variate Gaussians. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[[1., 2, 3], [11, 22, 33]] # shape: [2, 3] scale_diag=[[1., 2, 3], diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 040bc230722194316b8a74627344e315a2578281..413e88f03ae0286c294f3404549a73e1a47dcff7 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -86,7 +86,7 @@ class MultivariateNormalDiagPlusLowRank( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`, # `S = diag(d) + U @ diag(m) @ U.T`. 
The perturbation, `U @ diag(m) @ U.T`, is @@ -97,7 +97,7 @@ class MultivariateNormalDiagPlusLowRank( [-1, 1], [2, -0.5]] # shape: [3, 2] m = [4., 5] # shape: [2] - mvn = ds.MultivariateNormalDiagPlusLowRank( + mvn = tfd.MultivariateNormalDiagPlusLowRank( loc=mu scale_diag=d scale_perturb_factor=U, @@ -118,7 +118,7 @@ class MultivariateNormalDiagPlusLowRank( m = [[0.1, 0.2], [0.4, 0.5]] # shape: [b, r] = [2, 2] - mvn = ds.MultivariateNormalDiagPlusLowRank( + mvn = tfd.MultivariateNormalDiagPlusLowRank( loc=mu, scale_perturb_factor=U, scale_perturb_diag=m) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index f9952b2069d6dfd2593e6bd71ede0badf44cdf98..00a18569fce0175ee39e433dfad796e5f21fe8a4 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -18,12 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.distributions.python.ops import mvn_tril from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops __all__ = [ @@ -73,14 +76,14 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. 
mu = [1., 2, 3] cov = [[ 0.36, 0.12, 0.06], [ 0.12, 0.29, -0.13], [ 0.06, -0.13, 0.26]] - mvn = ds.MultivariateNormalFullCovariance( + mvn = tfd.MultivariateNormalFullCovariance( loc=mu, covariance_matrix=cov) @@ -100,7 +103,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): mu = [[1., 2, 3], [11, 22, 33]] # shape: [2, 3] covariance_matrix = ... # shape: [2, 3, 3], symmetric, positive definite. - mvn = ds.MultivariateNormalFullCovariance( + mvn = tfd.MultivariateNormalFullCovariance( loc=mu, covariance=covariance_matrix) @@ -167,9 +170,12 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): covariance_matrix = ops.convert_to_tensor( covariance_matrix, name="covariance_matrix") if validate_args: - assert_symmetric = check_ops.assert_equal( - covariance_matrix, - array_ops.matrix_transpose(covariance_matrix), + tol = np.finfo(covariance_matrix.dtype.as_numpy_dtype).eps * 10 + diff = math_ops.abs( + covariance_matrix + - array_ops.matrix_transpose(covariance_matrix)) + assert_symmetric = check_ops.assert_less( + diff, tol + tol * math_ops.abs(covariance_matrix), message="Matrix was not symmetric.") covariance_matrix = control_flow_ops.with_dependencies( [assert_symmetric], covariance_matrix) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 300bdd5f6064a1cc9c336689ac4fae04338edb30..a7399792892f4c179c05168184d76ec95c168b51 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -90,8 +90,7 @@ class MultivariateNormalLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -103,9 +102,9 @@ class MultivariateNormalLinearOperator( # [ 0.2, 0.5, 0. 
], # [ 0.1, -0.3, 0.4]]) - mvn = ds.MultivariateNormalLinearOperator( + mvn = tfd.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorLowerTriangular(scale)) + scale=tf.linalg.LinearOperatorLowerTriangular(scale)) # Covariance agrees with cholesky(cov) parameterization. mvn.covariance().eval() @@ -122,9 +121,9 @@ class MultivariateNormalLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - mvn = ds.MultivariateNormalLinearOperator( + mvn = tfd.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[-0.9, 0, 0.1], diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index 260dcc18f513d5440d3d39368539274c03faa72a..6c7dc4ca7aaf5b3a20b072e9360d15528ad10556 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -76,12 +76,13 @@ class MultivariateNormalTriL( ``` Trainable (batch) lower-triangular matrices can be created with - `ds.matrix_diag_transform()` and/or `ds.fill_triangular()` + `tf.contrib.distributions.matrix_diag_transform()` and/or + `tf.contrib.distributions.fill_triangular()` #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -92,7 +93,7 @@ class MultivariateNormalTriL( # ==> [[ 0.6, 0. , 0. ], # [ 0.2, 0.5, 0. ], # [ 0.1, -0.3, 0.4]]) - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=mu, scale_tril=scale) @@ -112,7 +113,7 @@ class MultivariateNormalTriL( mu = [[1., 2, 3], [11, 22, 33]] # shape: [2, 3] tril = ... # shape: [2, 3, 3], lower triangular, non-zero diagonal. 
- mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=mu, scale_tril=tril) @@ -124,9 +125,9 @@ class MultivariateNormalTriL( # Instantiate a "learnable" MVN. dims = 4 with tf.variable_scope("model"): - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=tf.get_variable(shape=[dims], dtype=tf.float32, name="mu"), - scale_tril=ds.fill_triangular( + scale_tril=tfd.fill_triangular( tf.get_variable(shape=[dims * (dims + 1) / 2], dtype=tf.float32, name="chol_Sigma"))) ``` diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 8a95038a3c8eccf8a75fea79d0a62f9883b4f13a..2701c36fb53b1ae3fd736be3b1288e3dd40c739a 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -107,10 +107,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions + # Create two batches of PoissonLogNormalQuadratureCompounds, one with # prior `loc = 0.` and another with `loc = 1.` In both cases `scale = 1.` - pln = ds.PoissonLogNormalQuadratureCompound( + pln = tfd.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., quadrature_grid_and_probs=( @@ -292,7 +293,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): # where, # # Z|v ~ interpolate_affine[v](distribution) - # V ~ mixture_distrubution + # V ~ mixture_distribution # # thus, # diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index b05f15771a3a94779ffddea8f16ad2fa4ea2fdd1..c4b8f055b7fbc3f0835b503eddd7617610326d8c 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -115,7 +115,7 @@ class 
SinhArcsinh(transformed_distribution.TransformedDistribution): tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight) distribution: `tf.Distribution`-like instance. Distribution that is transformed to produce this distribution. - Default is `ds.Normal(0., 1.)`. + Default is `tf.distributions.Normal(0., 1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 92043d6a08833888c36009261addca0d14949ea8..904724af429f3cb5835f6e05abcb574467ef6918 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -188,8 +188,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and # another with mix_loc=[1]. In both cases, `K=2` and the affine @@ -197,20 +196,20 @@ class VectorDiffeomixture(distribution_lib.Distribution): # k=0: loc=zeros(dims) scale=LinearOperatorScaledIdentity # k=1: loc=[2.]*dims scale=LinOpDiag dims = 5 - vdm = ds.VectorDiffeomixture( + vdm = tfd.VectorDiffeomixture( mix_loc=[[0.], [1]], mix_scale=[1.], - distribution=ds.Normal(loc=0., scale=1.), + distribution=tfd.Normal(loc=0., scale=1.), loc=[ None, # Equivalent to `np.zeros(dims, dtype=np.float32)`. 
np.float32([2.]*dims), ], scale=[ - la.LinearOperatorScaledIdentity( + tf.linalg.LinearOperatorScaledIdentity( num_rows=dims, multiplier=np.float32(1.1), is_positive_definite=True), - la.LinearOperatorDiag( + tf.linalg.LinearOperatorDiag( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index 356d78b67a8107750f68f7f84d73d1231f5b2b03..526fe2d39aef9aed833b889de80e849c469435e7 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -89,14 +89,13 @@ class VectorExponentialDiag( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. # The first component has pdf exp{-x}, the second 0.5 exp{-x / 2} - vex = ds.VectorExponentialDiag(scale_diag=[1., 2.]) + vex = tfd.VectorExponentialDiag(scale_diag=[1., 2.]) # Compute the pdf of an`R^2` observation; return a scalar. vex.prob([3., 4.]).eval() # shape: [] @@ -107,7 +106,7 @@ class VectorExponentialDiag( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vex = ds.VectorExponentialDiag(loc, scale_diag) + vex = tfd.VectorExponentialDiag(loc, scale_diag) # Compute the pdf of two `R^3` observations; return a length-2 vector. 
x = [[1.9, 2.2, 3.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index b313a851b381e5b3a057fd17e6c2ef4eb0fc34f1..9d5fd9ac4178a1ae29b1ce32f304b22fd3d234dc 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -107,16 +107,15 @@ class VectorExponentialLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. mat = [[1.0, 0.1], [0.1, 1.0]] - vex = ds.VectorExponentialLinearOperator( - scale=la.LinearOperatorFullMatrix(mat)) + vex = tfd.VectorExponentialLinearOperator( + scale=tf.linalg.LinearOperatorFullMatrix(mat)) # Compute the pdf of an`R^2` observation; return a scalar. vex.prob([1., 2.]).eval() # shape: [] @@ -127,9 +126,9 @@ class VectorExponentialLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vex = ds.VectorExponentialLinearOperator( + vex = tfd.VectorExponentialLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. 
x = [[1.9, 2.2, 3.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 0e3867809a820f49cfa7f5282c47f786626481a6..8dd983b750d9b39775e570800006011f4968f7f3 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -101,10 +101,10 @@ class VectorLaplaceDiag( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorLaplace. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[1., -1], scale_diag=[1, 2.]) @@ -118,7 +118,7 @@ class VectorLaplaceDiag( vla.prob([-1., 0]).eval() # shape: [] # Initialize a 3-batch, 2-variate scaled-identity VectorLaplace. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[1., -1], scale_identity_multiplier=[1, 2., 3]) @@ -136,7 +136,7 @@ class VectorLaplaceDiag( vla.prob([-1., 0]).eval() # shape: [3] # Initialize a 2-batch of 3-variate VectorLaplace's. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[[1., 2, 3], [11, 22, 33]] # shape: [2, 3] scale_diag=[[1., 2, 3], diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index c7abdbb4caf9bee4cbd5991eb5d652f20dd0f8d1..ec485c95c15da2794b67d2699d2bdd9db97bb6c4 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -109,8 +109,7 @@ class VectorLaplaceLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 3-variate VectorLaplace with some desired covariance. 
mu = [1., 2, 3] @@ -124,9 +123,9 @@ class VectorLaplaceLinearOperator( # [ 0.1, -0.3, 0.4]]) # Divide scale by sqrt(2) so that the final covariance will be what we want. - vla = ds.VectorLaplaceLinearOperator( + vla = tfd.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorLowerTriangular(scale / tf.sqrt(2))) + scale=tf.linalg.LinearOperatorLowerTriangular(scale / tf.sqrt(2.))) # Covariance agrees with cholesky(cov) parameterization. vla.covariance().eval() @@ -143,9 +142,9 @@ class VectorLaplaceLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vla = ds.VectorLaplaceLinearOperator( + vla = tfd.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[-0.9, 0, 0.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index 544a8710709a0afb56c6ae6f36d35de892e8e420..e1ccf116457a97261b9ce3965552764771d3bdd2 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -143,7 +143,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): broadcastable with `event_shape`. distribution: `tf.Distribution`-like instance. Distribution from which `k` iid samples are used as input to transformation `F`. Default is - `ds.Normal(0., 1.)`. + `tf.distributions.Normal(loc=0., scale=1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. 
WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 29d41ab81c62d621c3c3533e1449341e9a085645..8c67647a618d22a58428d78865c4ebf7d98bdf9e 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -91,14 +91,14 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): Extra leading dimensions, if provided, allow for batches. ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate vector Student's t-distribution. mu = [1., 2, 3] chol = [[1., 0, 0.], [1, 3, 0], [1, 2, 3]] - vt = ds.VectorStudentT(df=2, loc=mu, scale_tril=chol) + vt = tfd.VectorStudentT(df=2, loc=mu, scale_tril=chol) # Evaluate this on an observation in R^3, returning a scalar. vt.prob([-1., 0, 1]) @@ -107,7 +107,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): mu = [[1., 2, 3], [11, 22, 33]] chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal. - vt = ds.VectorStudentT(loc=mu, scale_tril=chol) + vt = tfd.VectorStudentT(loc=mu, scale_tril=chol) # Evaluate this on a two observations, each in R^3, returning a length two # tensor. diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index ae4b07799f5c123b68529443a1765fbfbac05492..09242ee47ddd044dfc99e22d5b7751a989c86485 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -1,4 +1,4 @@ -# TensorFlow Eager Execution +# Eager Execution > *WARNING*: This is a preview/pre-alpha version. The API and performance > characteristics are subject to change. @@ -76,3 +76,6 @@ For an introduction to eager execution in TensorFlow, see: ## Changelog - 2017/10/31: Initial preview release. 
+- 2017/12/01: Example of dynamic neural network: + [SPINN: Stack-augmented Parser-Interpreter Neural Network](https://arxiv.org/abs/1603.06021). + See [README.md](python/examples/spinn/README.md) for details. diff --git a/tensorflow/contrib/eager/proto/BUILD b/tensorflow/contrib/eager/proto/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..aedfec8924e7314addd22349c0576a84a58d9aa3 --- /dev/null +++ b/tensorflow/contrib/eager/proto/BUILD @@ -0,0 +1,24 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +tf_proto_library( + name = "checkpointable_object_graph_proto", + srcs = [ + "checkpointable_object_graph.proto", + ], + visibility = ["//tensorflow/contrib/eager/python:__subpackages__"], +) diff --git a/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto new file mode 100644 index 0000000000000000000000000000000000000000..c962638aa11c06dcd5be6a794314e029ae84e572 --- /dev/null +++ b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +option cc_enable_arenas = true; + +package tensorflow.contrib.eager; + +// Prototype for an addition to BundleHeaderProto which saves extra information +// about the objects which own variables, allowing for more robust checkpoint +// loading into modified programs. + +message CheckpointableObjectGraph { + message Object { + message ObjectReference { + // An index into `CheckpointableObjectGraph.nodes`, indicating the object + // being referenced. + int32 node_id = 1; + // A numeric identifier for this object within its parent. + int32 local_uid = 2; + // A user-provided name for the edge. 
May be blank/omitted, in which case + // there is no explicitly provided local name; fall back on local_uid. + string local_name = 3; + } + + message VariableReference { + // A name for the variable which is unique within the object which owns + // it. Does not include a name_scope or variable_scope prefix. + string local_name = 1; + // The full name of the variable. Used to allow name-based loading of + // checkpoints which were saved using an object-based API. + string full_name = 2; + } + + message SlotVariableReference { + // An index into `CheckpointableObjectGraph.nodes`, indicating the object + // which created the variable that this variable is slotting for. + int32 original_variable_node_id = 1; + // The local name of the variable being slotted for within the object that + // owns it. + string original_variable_local_name = 2; + // The name of the slot (e.g. "m"/"v"). + string slot_name = 3; + // The full name of the slot variable. Used to allow name-based loading of + // checkpoints which were saved using an object-based API. + string full_name = 4; + } + + // Objects which this object depends on. + repeated ObjectReference children = 1; + // Non-slot variables owned by this object. + repeated VariableReference variables = 2; + // Slot variables owned by this object. 
+ repeated SlotVariableReference slot_variables = 3; + } + + repeated Object nodes = 1; +} diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 2b84bc2e9b7453fac99ea2becc328ca854cf555d..086315464c99811371d836aed290b5068729adb0 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -12,16 +12,16 @@ py_library( visibility = ["//visibility:public"], deps = [ ":datasets", - ":evaluator", ":metrics", ":network", ":saver", - ":summary_writer", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:numerics", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:script_ops", "//tensorflow/python:util", + "//tensorflow/python:variable_scope", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:core", @@ -51,21 +51,22 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/data/python/ops:prefetching_py", "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/util:nest", "//tensorflow/python/eager:context", ], ) -py_test( +cuda_py_test( name = "datasets_test", srcs = ["datasets_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":datasets", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -103,37 +104,6 @@ cuda_py_test( ], ) -py_library( - name = "summary_writer", - srcs = ["summary_writer.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/summary:gen_summary_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:resource_variable_ops", - 
"//tensorflow/python:state_ops", - "//tensorflow/python:summary_op_util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - ], -) - -cuda_py_test( - name = "summary_writer_test", - srcs = ["summary_writer_test.py"], - additional_deps = [ - ":summary_writer", - "//third_party/py/numpy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:constant_op", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - ], -) - py_library( name = "metrics", srcs = [ @@ -165,11 +135,9 @@ py_test( ":metrics", "//tensorflow/contrib/summary:summary_ops", "//tensorflow/contrib/summary:summary_test_util", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", - "//tensorflow/python:lib", - "//tensorflow/python:platform", + "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", @@ -219,8 +187,11 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/python:framework_ops", "//tensorflow/python:layers_base", + "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", "//tensorflow/python/estimator:util", ], ) @@ -231,17 +202,54 @@ py_test( srcs_version = "PY2AND3", deps = [ ":network", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:constant_op", + "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", "//tensorflow/python:layers", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:function", "//tensorflow/python/eager:test", ], ) +py_library( + name = "checkpointable", + srcs = ["checkpointable.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + 
"//tensorflow/contrib/eager/proto:checkpointable_object_graph_proto_py", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + ], +) + +py_test( + name = "checkpointable_test", + srcs = ["checkpointable_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":checkpointable", + ":network", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:layers", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "@six_archive//:six", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/eager/python/checkpointable.py b/tensorflow/contrib/eager/python/checkpointable.py new file mode 100644 index 0000000000000000000000000000000000000000..b141ffb2bc03b8e38f8481bc044c3aae7e156c15 --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable.py @@ -0,0 +1,392 @@ +"""An object-local variable management scheme.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re + +from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.training import saver as saver_lib + +_CheckpointableReference = collections.namedtuple( + "_CheckpointableReference", + [ + "name", # The local name if explicitly specified, else None. + "local_uid", # 0 for the first dependency, 1 for the next, ... Used for + # routing checkpointed variables to their correct + # Checkpointables when "name" is not set (see docstring of + # `track_checkpointable`). + "ref" # The Checkpointable object being referenced. + ]) + +_OwnedVariable = collections.namedtuple( + "_OwnedVariable", + [ + "name", # The variable's (local) name. + "variable" # The owned variable object. + ]) + +# Validation regular expression for the local names of Checkpointable +# objects. In particular, disallows "/" in names, and reserves +# underscore-prefixed names. +_VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.-]*$") + +# Keyword for identifying that the next bit of a checkpoint variable name is a +# slot name. May not be the local name of a checkpointable. Checkpoint names for +# slot variables look like: +# +# <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name> +# +# Where <path to variable> is a full path from the checkpoint root to the +# variable being slotted for. +_OPTIMIZER_SLOTS_NAME = "_OPTIMIZER_SLOT" + + +class Checkpointable(object): + """Manages variables and dependencies on other objects. + + To make reliable checkpoints, all `Checkpointable`s on which this object + depends must be registered in the constructor using `track_checkpointable` in + a deterministic order, and if possible they should be named.
Variables may be + created using `add_variable` outside of the constructor and in any order, but + only these variables will be saved. + """ + + def __init__(self): + # Basically less useful OrderedDicts but without the reference cycles. + # TODO(allenl): Switch these to OrderedDict once TensorFlow supports only + # Python 3.6+. + self._checkpoint_dependencies = [] # A list of _CheckpointableReference + # objects. + self._dependency_names = set() + self._owned_variables = [] # A list of _OwnedVariable objects. + self._owned_variable_names = set() + + def add_variable(self, name, shape, dtype=None, initializer=None, **kwargs): + """Create a new variable object to be saved with this `Checkpointable`. + + If the user has requested that this object or another `Checkpointable` which + depends on this object be restored from a checkpoint (deferred loading + before variable object creation), `initializer` may be ignored and the value + from the checkpoint used instead. + + Args: + name: A name for the variable. Must be unique within this object. + shape: The shape of the variable. + dtype: The data type of the variable. + initializer: The initializer to use. Ignored if deferred loading has been + requested. + **kwargs: Passed to get_variable. + + Returns: + The new variable object. + + Raises: + ValueError: If the variable name is not unique. + """ + if name in self._owned_variable_names: + raise ValueError( + ("A variable named '%s' already exists in this Checkpointable, but " + "Checkpointable.add_variable called to create another with " + "that name. Variable names must be unique within a Checkpointable " + "object.") % (name,)) + if "getter" in kwargs: + # Allow the getter to be overridden, typically because there is a need for + # compatibility with some other variable creation mechanism. This should + # be relatively uncommon in user code. 
+ getter = kwargs.pop("getter") + else: + getter = variable_scope.get_variable + # TODO(allenl): handle deferred loading + new_variable = getter( + name=name, shape=shape, dtype=dtype, initializer=initializer, **kwargs) + self._owned_variables.append( + _OwnedVariable(name=name, variable=new_variable)) + self._owned_variable_names.add(name) + return new_variable + + def track_checkpointable(self, checkpointable, name=None): + """Declare a dependency on another `Checkpointable` object. + + Indicates that checkpoints for this object should include variables from + `checkpointable`. + + Variables in a checkpoint are mapped to `Checkpointable`s based on names if + provided when the checkpoint was written, but otherwise use the order those + `Checkpointable`s were declared as dependencies. Both `name` arguments and + the dependency declaration order should be deterministic. + + There are two sufficient conditions to avoid breaking existing checkpoints + when modifying a class: (1) New dependencies must be declared after existing + dependencies, and (2) dependencies which were previously declared may never + be removed (a trivial placeholder with the same name may be used instead). + + Args: + checkpointable: A `Checkpointable` which this object depends on. + name: A local name for `checkpointable`, used for loading checkpoints into + the correct objects. If provided, it must be unique within this + `Checkpointable`. If None, dependency declaration order is used instead. + + Returns: + `checkpointable`, for convenience when declaring a dependency and + assigning to a member variable in one statement. + + Raises: + RuntimeError: If __init__ was not called. + TypeError: If `checkpointable` does not inherit from `Checkpointable`. + ValueError: For invalid names. 
+ """ + if not hasattr(self, "_checkpoint_dependencies"): + raise RuntimeError("Need to call Checkpointable.__init__ before calling " + "Checkpointable.track_checkpointable().") + if not isinstance(checkpointable, Checkpointable): + raise TypeError( + ("Checkpointable.track_checkpointable() passed type %s, not a " + "Checkpointable.") % (type(checkpointable),)) + if name is not None: + if not _VALID_LOCAL_NAME.match(name): + raise ValueError( + ("Checkpointable names must match the regular expression '%s', but " + "got an invalid name '%s' instead.") % (_VALID_LOCAL_NAME.pattern, + name)) + if name in self._dependency_names: + raise ValueError( + ("Called Checkpointable.track_checkpointable() with name='%s', but " + "a Checkpointable with this name is already declared as a " + "dependency. If provided, names must be unique.") % (name,)) + self._dependency_names.add(name) + self._checkpoint_dependencies.append( + _CheckpointableReference( + name=name, + ref=checkpointable, + # TODO(allenl): Should this be exposed to allow users to stop + # depending on things and still load checkpoints when not using + # names? 
+ local_uid=len(self._checkpoint_dependencies))) + return checkpointable + + @property + def checkpoint_dependencies(self): + """Other `Checkpointable` objects on which this object depends.""" + return self._checkpoint_dependencies + + +def _breadth_first_checkpointable_traversal(root_checkpointable): + """Find shortest paths to all variables owned by dependencies of root.""" + bfs_sorted = [] + root_checkpointable_reference = _CheckpointableReference( + name=None, local_uid=0, ref=root_checkpointable) + to_visit = collections.deque([root_checkpointable_reference]) + path_to_root = {root_checkpointable_reference: ()} + while to_visit: + current_checkpointable = to_visit.popleft() + bfs_sorted.append(current_checkpointable) + for child_checkpointable in ( + current_checkpointable.ref.checkpoint_dependencies): + if child_checkpointable not in path_to_root: + path_to_root[child_checkpointable] = ( + path_to_root[current_checkpointable] + (child_checkpointable,)) + to_visit.append(child_checkpointable) + return bfs_sorted, path_to_root + + +def _object_prefix_from_path(path_to_root): + return "/".join((checkpointable.name if checkpointable.name else "_%d" % ( + checkpointable.local_uid,)) for checkpointable in path_to_root) + + +def _escape_variable_name(variable_name): + # We need to support slashes in variable names for compatibility, since this + # naming scheme is being patched in to things like Layer.add_variable where + # slashes were previously accepted. We also want to use slashes to indicate + # edges traversed to reach the variable, so we escape forward slashes in + # variable names. + return variable_name.replace("_S_", "_S_.").replace(r"/", r"_S__") + + +def _variable_naming_for_object(path_to_root): + """Make a function for naming variables in an object.""" + # Name non-slot variables: + # + # <path to node>/<local variable name> + # + # <path to node> is not necessarily unique, but this is fine since we also + # save the graph of `Checkpointable`s with the checkpoint.
Even if this path + # no longer exists because of a change in the Python program, we can look up + # the `Checkpointable` which owns the variable in the checkpoint's graph and + # use another path if one still exists. + + object_prefix = _object_prefix_from_path(path_to_root) + if object_prefix: + object_prefix += "/" + + def _name_single_variable(owned_variable): + """Names a variable within an object.""" + return object_prefix + _escape_variable_name(owned_variable.name) + + return _name_single_variable + + +def _slot_variable_naming_for_optimizer(optimizer, path_to_root): + """Make a function for naming slot variables in an optimizer.""" + # Name slot variables: + # + # <variable name>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name> + # + # where <variable name> is exactly the checkpoint name used for the original + # variable, including the path from the checkpoint root and the local name in + # the object which owns it. Note that we only save slot variables if the + # variable it's slotting for is also being saved. + + optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, + _object_prefix_from_path(path_to_root)) + + def _name_slot_variable(variable_path, slot_name): + """With an optimizer specified, name a slot variable.""" + + if not _VALID_LOCAL_NAME.match(slot_name): + # Slot variable names include the name of the slot. We need to + # validate that part of the name to be sure that the checkpoint name + # is a valid name scope name.
+ raise ValueError( + ("Could not save slot variables for optimizer %s, because its " + "slot name has invalid characters (got '%s', was expecting it " + "to match the regular expression '%s').") % + (optimizer, slot_name, _VALID_LOCAL_NAME.pattern)) + + return variable_path + optimizer_identifier + slot_name + + return _name_slot_variable + + +def _serialize_non_slot_variables(checkpointable_objects, path_to_root, + object_graph_proto): + """Name non-slot variables and add them to `object_graph_proto`.""" + named_variables = {} + non_slot_variables = [] + checkpoint_node_ids = {} + + for checkpoint_id, checkpointable in enumerate(checkpointable_objects): + checkpoint_node_ids[checkpointable] = checkpoint_id + + for checkpoint_id, checkpointable in enumerate(checkpointable_objects): + naming_scheme = _variable_naming_for_object(path_to_root[checkpointable]) + object_proto = object_graph_proto.nodes.add() + for owned_variable in checkpointable.ref._owned_variables: # pylint: disable=protected-access + variable_name = naming_scheme(owned_variable) + named_variables[variable_name] = owned_variable.variable + non_slot_variables.append(( + variable_name, # The variable's full checkpoint name + owned_variable, # The variable's _OwnedVariable object + checkpoint_id)) # The checkpoint ID of the node which owns this + # variable. + variable_proto = object_proto.variables.add() + variable_proto.local_name = owned_variable.name + # Figure out the name-based Saver's name for this variable. 
+ saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( + [owned_variable.variable], convert_variable_to_tensor=False) + variable_full_name, = saver_dict.keys() + variable_proto.full_name = variable_full_name + + for child in checkpointable.ref.checkpoint_dependencies: + child_proto = object_proto.children.add() + child_proto.node_id = checkpoint_node_ids[child] + child_proto.local_uid = child.local_uid + if child.name is not None: + child_proto.local_name = child.name + return named_variables, non_slot_variables + + +def _serialize_slot_variables(checkpointable_objects, path_to_root, + non_slot_variables, object_graph_proto): + """Name slot variables and add them to `object_graph_proto`.""" + named_slot_variables = {} + for optimizer_checkpoint_id, checkpointable_ref in enumerate( + checkpointable_objects): + if isinstance(checkpointable_ref.ref, optimizer_lib.Optimizer): + optimizer_object_proto = object_graph_proto.nodes[optimizer_checkpoint_id] + naming_scheme = _slot_variable_naming_for_optimizer( + optimizer=checkpointable_ref.ref, + path_to_root=path_to_root[checkpointable_ref]) + slot_names = checkpointable_ref.ref.get_slot_names() + for (variable_path, owned_variable, + original_node_checkpoint_id) in non_slot_variables: + for slot_name in slot_names: + slot_variable = checkpointable_ref.ref.get_slot( + owned_variable.variable, slot_name) + if slot_variable is not None: + checkpoint_name = naming_scheme( + variable_path=variable_path, slot_name=slot_name) + named_slot_variables[checkpoint_name] = slot_variable + slot_variable_proto = optimizer_object_proto.slot_variables.add() + slot_variable_proto.slot_name = slot_name + # Figure out the name-based Saver's name for this variable. 
+ saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( + [slot_variable], convert_variable_to_tensor=False) + slot_variable_full_name, = saver_dict.keys() + slot_variable_proto.full_name = slot_variable_full_name + slot_variable_proto.original_variable_local_name = ( + owned_variable.name) + slot_variable_proto.original_variable_node_id = ( + original_node_checkpoint_id) + return named_slot_variables + + +# TODO(allenl): Convenience utility for saving multiple objects (i.e. construct +# a root Checkpointable if passed a list of Checkpointables). +def _serialize_object_graph(root_checkpointable): + """Determine checkpoint keys for variables and build a serialized graph. + + Non-slot variables are keyed based on a shortest path from the root saveable + to the object which owns the variable (i.e. the one which called + `Checkpointable.add_variable` to create it). + + Slot variables are keyed based on a shortest path to the variable being + slotted for, a shortest path to their optimizer, and the slot name. + + Args: + root_checkpointable: A `Checkpointable` object whose variables (including + the variables of dependencies, recursively) should be saved. + + Returns: + A tuple of (named_variables, object_graph_proto): + named_variables: A dictionary mapping names to variable objects. + object_graph_proto: A CheckpointableObjectGraph protocol buffer containing + the serialized object graph and variable references. + + Raises: + ValueError: If there are invalid characters in an optimizer's slot names. + """ + checkpointable_objects, path_to_root = ( + _breadth_first_checkpointable_traversal(root_checkpointable)) + object_graph_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph()) + + # Gather non-slot variables. + named_variables, non_slot_variables = _serialize_non_slot_variables( + checkpointable_objects, path_to_root, object_graph_proto) + + # Gather slot variables which are associated with variables gathered above. 
+ named_slot_variables = _serialize_slot_variables( + checkpointable_objects, path_to_root, non_slot_variables, + object_graph_proto) + + named_variables.update(named_slot_variables) + return named_variables, object_graph_proto diff --git a/tensorflow/contrib/eager/python/checkpointable_test.py b/tensorflow/contrib/eager/python/checkpointable_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f820990bbe5fe6c9b4cdf890680aaad0847010c0 --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable_test.py @@ -0,0 +1,277 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import six + +from tensorflow.contrib.eager.python import checkpointable +from tensorflow.contrib.eager.python import network as network_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.layers import core +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import adam +from tensorflow.python.training import training_util + + +class CheckpointableDenseLayer(core.Dense, checkpointable.Checkpointable): + + def __init__(self, *args, **kwargs): + checkpointable.Checkpointable.__init__(self) + core.Dense.__init__(self, *args, **kwargs) + + def add_variable(self, name, shape, **kwargs): + # Calls both Checkpointable.add_variable and Layer.add_variable. Eventually + # Layer.add_variable should inherit from Checkpointable and simply call + # super and then do post-processing. 
+ return checkpointable.Checkpointable.add_variable( + self, + name=name, + shape=shape, + getter=functools.partial(core.Dense.add_variable, self), + **kwargs) + + +# pylint: disable=not-callable +class CheckpointableNetwork(network_lib.Network, checkpointable.Checkpointable): + + def __init__(self): + network_lib.Network.__init__(self) + checkpointable.Checkpointable.__init__(self) + + def track_layer(self, layer, name=None): + self.track_checkpointable(layer, name=name) + return super(CheckpointableNetwork, self).track_layer(layer) + + +class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): + + def __init__(self, *args, **kwargs): + checkpointable.Checkpointable.__init__(self) + adam.AdamOptimizer.__init__(self, *args, **kwargs) + + # NOTE: Copied from AdamOptimizer with modifications to use add_variable + # for non-slot variables. These contortions are necessary to maintain + # checkpoint compatibility with variable.name based saving. + def _create_slots(self, var_list): + # Create the beta1 and beta2 accumulators on the same device as the first + # variable. Sort the var_list to make sure this device is consistent across + # workers (these need to go on the same PS, otherwise some updates are + # silently ignored). 
+ first_var = min(var_list, key=lambda x: x.name) + + create_new = self._beta1_power is None + if not create_new and context.in_graph_mode(): + create_new = (self._beta1_power.graph is not first_var.graph) + + if create_new: + with ops.colocate_with(first_var): + + def _variable_getter(name, shape, dtype, initializer): + del shape, dtype # not used, but there for compatibility + return variable_scope.variable( + name=name, initial_value=initializer, trainable=False) + + self._beta1_power = self.add_variable( + name="beta1_power", + shape=[], + initializer=self._beta1, + getter=_variable_getter) + self._beta2_power = self.add_variable( + name="beta2_power", + shape=[], + initializer=self._beta2, + getter=_variable_getter) + # Create slots for the first and second moments. + for v in var_list: + self._zeros_slot(v, "m", self._name) + self._zeros_slot(v, "v", self._name) + + # TODO(allenl): Override slot variable creation (_get_or_make_slot, + # _get_or_make_slot_with_initializer, _zeros_slot) to allow deferred + # loading. Likely no need to run this through add_variable, since gathering + # slot variables is special cased anyway. 
+ + +class MyNetwork(CheckpointableNetwork): + """A concrete Network for testing.""" + + def __init__(self): + super(MyNetwork, self).__init__() + self._named = self.track_layer( + CheckpointableDenseLayer(1, use_bias=True), name="named_dense") + self._unnamed = self.track_layer( + CheckpointableDenseLayer(1, use_bias=False)) + + def call(self, values): + return self._unnamed(self._named(values)) + + +class Root(checkpointable.Checkpointable): + """A stand-in for a Trainer class.""" + + def __init__(self, optimizer, network): + super(Root, self).__init__() + self.track_checkpointable(optimizer, name="optimizer") + self.track_checkpointable(network, name="network") + self._global_step = None + + @property + def global_step(self): + if self._global_step is None: + # Get the default create_global_step utility to actually call + # self.add_variable, by setting a custom getter. + def _owned_variable_as_custom_getter(getter, *args, **kwargs): + return self.add_variable(*args, getter=getter, **kwargs) + + with variable_scope.variable_scope( + "", custom_getter=_owned_variable_as_custom_getter): + self._global_step = training_util.create_global_step() + return self._global_step + + +class CheckpointNamingTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testNamingWithOptimizer(self): + input_value = constant_op.constant([[3.]]) + network = MyNetwork() + # A nuisance Network using the same optimizer. Its slot variables should not + # go in the checkpoint, since it is never depended on. 
+ other_network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root_checkpointable = Root(optimizer=optimizer, network=network) + if context.in_eager_mode(): + optimizer.minimize( + lambda: network(input_value), + global_step=root_checkpointable.global_step) + optimizer.minimize( + lambda: other_network(input_value), + global_step=root_checkpointable.global_step) + else: + train_op = optimizer.minimize( + network(input_value), global_step=root_checkpointable.global_step) + optimizer.minimize( + other_network(input_value), + global_step=root_checkpointable.global_step) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(train_op) + named_variables, serialized_graph = checkpointable._serialize_object_graph( + root_checkpointable) + expected_checkpoint_names = ( + # Created in the root node, so no prefix. + "global_step", + # No name provided to track_checkpointable(), so the position (1, after + # the named track_checkpointable() which is 0) is used instead. 
+ "network/_1/kernel", + # track_checkpointable() with a name provided, so that's used + "network/named_dense/kernel", + "network/named_dense/bias", + # The optimizer creates two non-slot variables + "optimizer/beta1_power", + "optimizer/beta2_power", + # Slot variables + "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/m", + "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/v", + "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/m", + "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/v", + "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/m", + "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/v", + ) + six.assertCountEqual(self, expected_checkpoint_names, + named_variables.keys()) + # Check that we've mapped to the right variable objects (not exhaustive) + self.assertEqual("global_step:0", named_variables["global_step"].name) + self.assertEqual("my_network/checkpointable_dense_layer_1/kernel:0", + named_variables["network/_1/kernel"].name) + self.assertEqual("my_network/checkpointable_dense_layer/kernel:0", + named_variables["network/named_dense/kernel"].name) + self.assertEqual("beta1_power:0", + named_variables["optimizer/beta1_power"].name) + self.assertEqual("beta2_power:0", + named_variables["optimizer/beta2_power"].name) + # Spot check the generated protocol buffers. 
+ self.assertEqual(0, serialized_graph.nodes[0].children[0].local_uid) + self.assertEqual("optimizer", + serialized_graph.nodes[0].children[0].local_name) + optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[ + 0].node_id] + self.assertEqual("beta1_power", optimizer_node.variables[0].local_name) + self.assertEqual("beta1_power", optimizer_node.variables[0].full_name) + self.assertEqual( + "kernel", optimizer_node.slot_variables[0].original_variable_local_name) + original_variable_owner = serialized_graph.nodes[ + optimizer_node.slot_variables[0].original_variable_node_id] + self.assertEqual("kernel", original_variable_owner.variables[0].local_name) + self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) + # We strip off the :0 suffix, as variable.name-based saving does. + self.assertEqual("my_network/checkpointable_dense_layer/kernel/Adam", + optimizer_node.slot_variables[0].full_name) + self.assertEqual("my_network/checkpointable_dense_layer/kernel/Adam:0", + optimizer.get_slot( + var=named_variables["network/named_dense/kernel"], + name="m").name) + + def _get_checkpoint_name(self, name): + root = checkpointable.Checkpointable() + with variable_scope.variable_scope("get_checkpoint_name"): + # Create the variable in a variable scope so that we get more relaxed + # naming rules (variables outside a scope may not start with "_", "/" or + # "-"). Since we don't use the scope part of the name, these cases are + # somewhat annoying. + root.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64) + named_variables, _ = checkpointable._serialize_object_graph(root) + checkpoint_name, = named_variables.keys() + with ops.name_scope("root/" + checkpoint_name): + pass # Make sure we can use this as an op name if we prefix it. 
+ return checkpoint_name + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testVariableNameEscaping(self): + self.assertEqual(r"a_S__b_S__c", self._get_checkpoint_name(r"a/b/c")) + self.assertEqual(r"", self._get_checkpoint_name(r"")) + self.assertEqual(r"_S__", self._get_checkpoint_name(r"/")) + self.assertEqual(r"_S___S_._", self._get_checkpoint_name(r"/_S__")) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testNumberedPath(self): + root = checkpointable.Checkpointable() + leaf = checkpointable.Checkpointable() + root.track_checkpointable(leaf) + leaf.add_variable(name="v", shape=[]) + named_variables, _ = checkpointable._serialize_object_graph(root) + variable_name, = named_variables.keys() + self.assertEqual(r"_0/v", variable_name) + + @test_util.run_in_graph_and_eager_modes() + def testLocalNameValidation(self): + root = checkpointable.Checkpointable() + leaf = checkpointable.Checkpointable() + with self.assertRaisesRegexp(ValueError, "invalid name"): + # Leading underscores are reserved, which avoids conflicts with + # un-named edges in paths and the optimizer slots identifier. 
+ root.track_checkpointable(leaf, name="_12") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 98e6983658aed77277d87915ff26a8c676224503..b559cce6b12a809d671ce7855680063f02a4ac22 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -20,11 +20,15 @@ from __future__ import print_function import threading +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops @@ -32,12 +36,12 @@ _uid_counter = 0 _uid_lock = threading.Lock() -def _iterator_shared_name(): +def _generate_shared_name(prefix): with _uid_lock: global _uid_counter uid = _uid_counter _uid_counter += 1 - return "eager_iterator_{}".format(uid) + return "{}_{}".format(prefix, uid) class Iterator(object): @@ -72,11 +76,12 @@ class Iterator(object): with ops.device("/device:CPU:0"): ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access self._output_types = dataset.output_types + self._output_shapes = dataset.output_shapes self._flat_output_types = nest.flatten(dataset.output_types) self._flat_output_shapes = nest.flatten(dataset.output_shapes) self._resource = gen_dataset_ops.iterator( container="", - shared_name=_iterator_shared_name(), + shared_name=_generate_shared_name("eager_iterator"), output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) 
gen_dataset_ops.make_iterator(ds_variant, self._resource) @@ -84,6 +89,35 @@ class Iterator(object): self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device="/device:CPU:0") self._device = context.context().device_name + self._buffer_resource_handle = None + if not context.context().device_spec.device_type: + is_remote_device = False + else: + is_remote_device = context.context().device_spec.device_type != "CPU" + if is_remote_device: + with ops.device("/device:CPU:0"): + iter_string_handle = gen_dataset_ops.iterator_to_string_handle( + self._resource) + + @function.Defun(dtypes.string) + def remote_fn(h): + remote_iterator = iterator_ops.Iterator.from_string_handle( + h, self._output_types, self._output_shapes) + return remote_iterator.get_next() + + remote_fn.add_to_graph(None) + target = constant_op.constant("/device:CPU:0") + with ops.device(self._device): + self._buffer_resource_handle = prefetching_ops.function_buffering_resource( + string_arg=iter_string_handle, + f=remote_fn, + target_device=target, + buffer_size=10, + thread_pool_size=1, + container="", + shared_name=_generate_shared_name("function_buffer_resource")) + self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._buffer_resource_handle, handle_device=self._device) def __iter__(self): return self @@ -93,20 +127,20 @@ class Iterator(object): def next(self): """Return the next tf.Tensor from the dataset.""" - try: - # TODO(ashankar): Consider removing this ops.device() contextmanager - # and instead mimic ops placement in graphs: Operations on resource - # handles execute on the same device as where the resource is placed. - with ops.device("/device:CPU:0"): - ret = gen_dataset_ops.iterator_get_next( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - except errors.OutOfRangeError: - raise StopIteration - # Copies tensors from CPU to the current device if necessary. 
- # TODO(rohanj): This should be replaced by the mechanism to have the - # runtime's threads copy tensors to the destination device. with ops.device(self._device): - ret = [array_ops.identity(x) for x in ret] + try: + if self._buffer_resource_handle is not None: + ret = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=self._buffer_resource_handle, + output_types=self._flat_output_types) + else: + # TODO(ashankar): Consider removing this ops.device() contextmanager + # and instead mimic ops placement in graphs: Operations on resource + # handles execute on the same device as where the resource is placed. + ret = gen_dataset_ops.iterator_get_next( + self._resource, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + except errors.OutOfRangeError: + raise StopIteration return nest.pack_sequence_as(self._output_types, ret) diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index bd0ab02ecf7ae6025e08dde1c3ddc634db9255c1..3faaeef5903615ea122800a6690117dde682e830 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -110,7 +110,7 @@ class Evaluator(object): return self._all_metric_results() else: def f(): - with summary_ops.create_summary_file_writer( + with summary_ops.create_file_writer( summary_logdir).as_default(), summary_ops.always_record_summaries(): return self._all_metric_results() if context.in_eager_mode(): diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py index 02f82cb216983accc7bc2dfa20cbb1ee0b8d8d26..7d2274db9b051e604266074651f4cbd331f20f48 100644 --- a/tensorflow/contrib/eager/python/evaluator_test.py +++ b/tensorflow/contrib/eager/python/evaluator_test.py @@ -87,7 +87,7 @@ class EvaluatorTest(test.TestCase): e.all_metric_results(logdir) - events = summary_test_util.events_from_file(logdir) + events = 
summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 6.0) @@ -136,7 +136,7 @@ class EvaluatorTest(test.TestCase): variables.global_variables_initializer().run() e.run_evaluation(init_op, call_op, results_op) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 6.0) diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index aa21a6ab994acf929890ecebc07a86cf7ebf97db..6aef010a2139c4cd2ae19c008aa21d4e3592ca98 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -11,5 +11,6 @@ py_library( "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", + "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index d0130ebd118dbaff4f0161c8b2528764c6103e02..7bc5007c5655bed81b5600ee283c35bd332a1ebe 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -85,7 +85,7 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): if logdir: # Support for TensorBoard summaries. Once training has started, use: # tensorboard --logdir= - summary_writer = tf.contrib.summary.create_summary_file_writer(logdir) + summary_writer = tf.contrib.summary.create_file_writer(logdir) # Training loop. 
for i, (xs, ys) in enumerate(tfe.Iterator(dataset)): diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py index bfb7d5a9002787f6544d383de58150661ac2bde3..bb121c7704b4772dde520ddc928a13c50ec8bb18 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist.py @@ -190,9 +190,9 @@ def main(_): else: train_dir = None test_dir = None - summary_writer = tf.contrib.summary.create_summary_file_writer( + summary_writer = tf.contrib.summary.create_file_writer( train_dir, flush_millis=10000) - test_summary_writer = tf.contrib.summary.create_summary_file_writer( + test_summary_writer = tf.contrib.summary.create_file_writer( test_dir, flush_millis=10000, name='test') checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 736a75332ff6403ea1b21387211df6b8fb6034f3..23317886e712323f4b520000e0fd372734fc53a1 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -73,7 +73,7 @@ class ResNet50GraphTest(tf.test.TestCase): tf.train.get_or_create_global_step() logdir = tempfile.mkdtemp() with tf.contrib.summary.always_record_summaries(): - with tf.contrib.summary.create_summary_file_writer( + with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(): model = resnet50.ResNet50(data_format()) @@ -95,7 +95,7 @@ class ResNet50GraphTest(tf.test.TestCase): sess.run([train_op, tf.contrib.summary.all_summary_ops()], feed_dict={images: np_images, labels: np_labels}) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) 
self.assertEqual(events[1].summary.value[0].tag, 'loss') diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index d6389f2e385b3637b178d49fc56e8baf913eccaa..d8d8644dde10498e5fd480f92b69656fca1558dd 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -95,7 +95,7 @@ class ResNet50Test(tf.test.TestCase): model = resnet50.ResNet50(data_format) tf.train.get_or_create_global_step() logdir = tempfile.mkdtemp() - with tf.contrib.summary.create_summary_file_writer( + with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(), tf.contrib.summary.always_record_summaries(): with tf.device(device): @@ -103,7 +103,7 @@ class ResNet50Test(tf.test.TestCase): images, labels = random_batch(2) train_one_step(model, images, labels, optimizer) self.assertEqual(320, len(model.variables)) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD index b657d31f35bafd6624ac7e4d6a6f6b2db362649d..f83eb5c476ed9f45d70849a0de6c0f20973682a5 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD @@ -11,6 +11,7 @@ py_binary( deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", + "//tensorflow/python/eager:context", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 609cbd28772c3ae8da70648ca5b1b264a8a255e2..40919f2d4cf511eb35fac954719286366aef6c7c 100644 
--- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -247,9 +247,9 @@ def main(_): log_dir = os.path.join(FLAGS.dir, "summaries") tf.gfile.MakeDirs(log_dir) - train_summary_writer = tf.contrib.summary.create_summary_file_writer( + train_summary_writer = tf.contrib.summary.create_file_writer( os.path.join(log_dir, "train"), flush_millis=10000) - test_summary_writer = tf.contrib.summary.create_summary_file_writer( + test_summary_writer = tf.contrib.summary.create_file_writer( os.path.join(log_dir, "eval"), flush_millis=10000, name="eval") with tf.device(device): diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD index db2587bf2cb548ae37e58597691e96ae2c2e8177..4b4792cd49bf8bd4ad46a0371ef0d2f8a07ddd1c 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD @@ -10,7 +10,9 @@ py_binary( srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", + "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", "//tensorflow/contrib/eager/python:tfe", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a1f8a759e2a556bc219f0aa13942f293c4f34cfa --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -0,0 +1,42 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "data", + srcs = ["data.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = ["//third_party/py/numpy"], +) + +py_test( + name = "data_test", + size = "small", + srcs = ["data_test.py"], + 
srcs_version = "PY2AND3", + deps = [ + ":data", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "spinn_test", + size = "medium", + srcs = ["spinn_test.py"], + additional_deps = [ + ":data", + "//third_party/examples/eager/spinn", + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/summary:summary_test_util", + "//tensorflow/python/eager:test", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], + tags = ["no_pip"], # because spinn.py is under third_party/. +) diff --git a/tensorflow/contrib/eager/python/examples/spinn/README.md b/tensorflow/contrib/eager/python/examples/spinn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eb0637df473e22e5d39ca1b0816464cb2b7c6435 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/README.md @@ -0,0 +1,13 @@ +# SPINN: Dynamic neural network with TensorFlow eager execution + +This directory contains files supporting the +[spinn.py model in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/spinn.py), +including + +- `data.py`: Utility library for loading and preprocessing the SNLI and GloVe + data. +- `data_test.py` and `spinn_test.py`: Unit tests for the data and model modules. + +See the [README.md in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/README.md) +for detailed background, license and usage information regarding the SPINN code. + diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py new file mode 100644 index 0000000000000000000000000000000000000000..a6e046320f78541bef4e091e97f08fd51857af83 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/data.py @@ -0,0 +1,350 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities of SNLI data and GloVe word vectors for SPINN model. + +See more details about the SNLI data set at: + https://nlp.stanford.edu/projects/snli/ + +See more details about the GloVe pretrained word embeddings at: + https://nlp.stanford.edu/projects/glove/ +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import math +import os +import random + +import numpy as np + +POSSIBLE_LABELS = ("entailment", "contradiction", "neutral") + +UNK_CODE = 0 # Code for unknown word tokens. +PAD_CODE = 1 # Code for padding tokens. + +SHIFT_CODE = 3 +REDUCE_CODE = 2 + +WORD_VECTOR_LEN = 300 # Embedding dimensions. + +LEFT_PAREN = "(" +RIGHT_PAREN = ")" +PARENTHESES = (LEFT_PAREN, RIGHT_PAREN) + + +def get_non_parenthesis_words(items): + """Get the non-parenthesis items from a SNLI parsed sentence. + + Args: + items: Data items from a parsed SNLI sentence, with parentheses. E.g., + ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ... + + Returns: + A list of non-parenthesis word items, all converted to lower case. E.g., + ["man", "wearing", "pass", ... + """ + return [x.lower() for x in items if x not in PARENTHESES and x] + + +def get_shift_reduce(items): + """Obtain shift-reduce vector from a list of items from the SNLI data.
+ + Args: + items: Data items as a list of str, e.g., + ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ... + + Returns: + A list of shift-reduce transitions, encoded as `SHIFT_CODE` for shift and + `REDUCE_CODE` for reduce. See code above for the values of `SHIFT_CODE` + and `REDUCE_CODE`. + """ + trans = [] + for item in items: + if item == LEFT_PAREN: + continue + elif item == RIGHT_PAREN: + trans.append(REDUCE_CODE) + else: + trans.append(SHIFT_CODE) + return trans + + +def pad_and_reverse_word_ids(sentences): + """Pad a list of sentences to the common maximum length + 1. + + Args: + sentences: A list of sentences as a list of list of integers. Each integer + is a word ID. Each list of integer corresponds to one sentence. + + Returns: + A numpy.ndarray of shape (num_sentences, max_length + 1), wherein max_length + is the maximum sentence length (in # of words). Each sentence is reversed + and then padded with an extra one at head, as required by the model. + """ + max_len = max(len(sent) for sent in sentences) + for sent in sentences: + if len(sent) < max_len: + sent.extend([PAD_CODE] * (max_len - len(sent))) + # Reverse in time order and pad an extra one. + sentences = np.fliplr(np.array(sentences, dtype=np.int64)) + sentences = np.concatenate( + [np.ones([sentences.shape[0], 1], dtype=np.int64), sentences], axis=1) + return sentences + + +def pad_transitions(sentences_transitions): + """Pad a list of shift-reduce transitions to the maximum length.""" + max_len = max(len(transitions) for transitions in sentences_transitions) + for transitions in sentences_transitions: + if len(transitions) < max_len: + transitions.extend([PAD_CODE] * (max_len - len(transitions))) + return np.array(sentences_transitions, dtype=np.int64) + + +def load_vocabulary(data_root): + """Load vocabulary from SNLI data files. + + Args: + data_root: Root directory of the data. 
It is assumed that the SNLI data + files have been downloaded and extracted to the "snli/snli_1.0" + subdirectory of it. + + Returns: + Vocabulary as a set of strings. + + Raises: + ValueError: If SNLI data files cannot be found. + """ + snli_path = os.path.join(data_root, "snli") + snli_glob_pattern = os.path.join(snli_path, "snli_1.0/snli_1.0_*.txt") + file_names = glob.glob(snli_glob_pattern) + if not file_names: + raise ValueError( + "Cannot find SNLI data files at %s. " + "Please download and extract SNLI data first." % snli_glob_pattern) + + print("Loading vocabulary...") + vocab = set() + for file_name in file_names: + with open(os.path.join(snli_path, file_name), "rt") as f: + for i, line in enumerate(f): + if i == 0: + continue + items = line.split("\t") + premise_words = get_non_parenthesis_words(items[1].split(" ")) + hypothesis_words = get_non_parenthesis_words(items[2].split(" ")) + vocab.update(premise_words) + vocab.update(hypothesis_words) + return vocab + + +def load_word_vectors(data_root, vocab): + """Load GloVe word vectors for words present in the vocabulary. + + Args: + data_root: Data root directory. It is assumed that the GloVe file + has been downloaded and extracted at the "glove/" subdirectory of it. + vocab: A `set` of words, representing the vocabulary. + + Returns: + 1. word2index: A dict from lower-case word to row index in the embedding + matrix, i.e, `embed` below. + 2. embed: The embedding matrix as a float32 numpy array. Its shape is + [vocabulary_size, WORD_VECTOR_LEN]. vocabulary_size is len(vocab). + WORD_VECTOR_LEN is the embedding dimension (300). + + Raises: + ValueError: If GloVe embedding file cannot be found. + """ + glove_path = os.path.join(data_root, "glove/glove.42B.300d.txt") + if not os.path.isfile(glove_path): + raise ValueError( + "Cannot find GloVe embedding file at %s. " + "Please download and extract GloVe embeddings first." 
% glove_path) + + print("Loading word vectors...") + + word2index = dict() + embed = [] + + embed.append([0] * WORD_VECTOR_LEN) # + embed.append([0] * WORD_VECTOR_LEN) # + word2index[""] = UNK_CODE + word2index[""] = PAD_CODE + + with open(glove_path, "rt") as f: + for line in f: + items = line.split(" ") + word = items[0] + if word in vocab and word not in word2index: + word2index[word] = len(embed) + vector = np.array([float(item) for item in items[1:]]) + assert (WORD_VECTOR_LEN,) == vector.shape + embed.append(vector) + embed = np.array(embed, dtype=np.float32) + return word2index, embed + + +def calculate_bins(length2count, min_bin_size): + """Cacluate bin boundaries given a histogram of lengths and mininum bin size. + + Args: + length2count: A `dict` mapping length to sentence count. + min_bin_size: Minimum bin size in terms of total number of sentence pairs + in the bin. + + Returns: + A `list` representing the right bin boundaries, starting from the inclusive + right boundary of the first bin. For example, if the output is + [10, 20, 35], + it means there are three bins: [1, 10], [11, 20] and [21, 35]. + """ + bounds = [] + lengths = sorted(length2count.keys()) + cum_count = 0 + for length in lengths: + cum_count += length2count[length] + if cum_count >= min_bin_size: + bounds.append(length) + cum_count = 0 + if bounds[-1] != lengths[-1]: + bounds.append(lengths[-1]) + return bounds + + +class SnliData(object): + """A split of SNLI data.""" + + def __init__(self, data_file, word2index, sentence_len_limit=-1): + """SnliData constructor. + + Args: + data_file: Full path to the data file, e.g., + "/tmp/spinn-data/snli/snli_1.0/snli_1.0.train.txt" + word2index: A dict from lower-case word to row index in the embedding + matrix (see `load_word_vectors()` for details). + sentence_len_limit: Maximum allowed sentence length (# of words). + A value of <= 0 means unlimited. Sentences longer than this limit + are currently discarded, not truncated. 
+ """ + + self._labels = [] + self._premises = [] + self._premise_transitions = [] + self._hypotheses = [] + self._hypothesis_transitions = [] + + with open(data_file, "rt") as f: + for i, line in enumerate(f): + if i == 0: + # Skip header line. + continue + items = line.split("\t") + if items[0] not in POSSIBLE_LABELS: + continue + + premise_items = items[1].split(" ") + hypothesis_items = items[2].split(" ") + premise_words = get_non_parenthesis_words(premise_items) + hypothesis_words = get_non_parenthesis_words(hypothesis_items) + + if (sentence_len_limit > 0 and + (len(premise_words) > sentence_len_limit or + len(hypothesis_words) > sentence_len_limit)): + # TODO(cais): Maybe truncate; do not discard. + continue + + premise_ids = [ + word2index.get(word, UNK_CODE) for word in premise_words] + hypothesis_ids = [ + word2index.get(word, UNK_CODE) for word in hypothesis_words] + + self._premises.append(premise_ids) + self._hypotheses.append(hypothesis_ids) + self._premise_transitions.append(get_shift_reduce(premise_items)) + self._hypothesis_transitions.append(get_shift_reduce(hypothesis_items)) + assert (len(self._premise_transitions[-1]) == + 2 * len(premise_words) - 1) + assert (len(self._hypothesis_transitions[-1]) == + 2 * len(hypothesis_words) - 1) + + self._labels.append(POSSIBLE_LABELS.index(items[0]) + 1) + + assert len(self._labels) == len(self._premises) + assert len(self._labels) == len(self._hypotheses) + assert len(self._labels) == len(self._premise_transitions) + assert len(self._labels) == len(self._hypothesis_transitions) + + def num_batches(self, batch_size): + """Calculate number of batches given batch size.""" + return int(math.ceil(len(self._labels) / batch_size)) + + def get_generator(self, batch_size): + """Obtain a generator for batched data. + + All examples of this SnliData object are randomly shuffled, sorted + according to the maximum sentence length of the premise and hypothesis + sentences in the pair, and batched. 
+ + Args: + batch_size: Desired batch size. + + Returns: + A generator for data batches. The generator yields a 5-tuple: + label: An array of the shape (batch_size,). + premise: An array of the shape (max_premise_len, batch_size), wherein + max_premise_len is the maximum length of the (padded) premise + sentence in the batch. + premise_transitions: An array of the shape (2 * max_premise_len -3, + batch_size). + hypothesis: Same as `premise`, but for hypothesis sentences. + hypothesis_transitions: Same as `premise_transitions`, but for + hypothesis sentences. + All the elements of the 5-tuple have dtype `int64`. + """ + # Randomly shuffle examples. + zipped = list(zip( + self._labels, self._premises, self._premise_transitions, + self._hypotheses, self._hypothesis_transitions)) + random.shuffle(zipped) + # Then sort the examples by maximum of the premise and hypothesis sentence + # lengths in the pair. During training, the batches are expected to be + # shuffled. So it is okay to leave them sorted by max length here. + (labels, premises, premise_transitions, hypotheses, + hypothesis_transitions) = zip( + *sorted(zipped, key=lambda x: max(len(x[1]), len(x[3])))) + + def _generator(): + begin = 0 + while begin < len(labels): + # The sorting above and the batching here makes sure that sentences of + # similar max lengths are batched together, minimizing the inefficiency + # due to uneven max lengths. The sentences are batched differently in + # each call to get_generator() due to the shuffling before sotring + # above. The pad_and_reverse_word_ids() and pad_transitions() functions + # take care of any remaning unevenness of the max sentence lengths. + end = min(begin + batch_size, len(labels)) + # Transpose, because the SPINN model requires time-major, instead of + # batch-major. 
+ yield (labels[begin:end], + pad_and_reverse_word_ids(premises[begin:end]).T, + pad_transitions(premise_transitions[begin:end]).T, + pad_and_reverse_word_ids(hypotheses[begin:end]).T, + pad_transitions(hypothesis_transitions[begin:end]).T) + begin = end + return _generator diff --git a/tensorflow/contrib/eager/python/examples/spinn/data_test.py b/tensorflow/contrib/eager/python/examples/spinn/data_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f0b37c5099e45b7e3b258b258c0a203c36b3b7 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/data_test.py @@ -0,0 +1,243 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================
"""Unit tests for SPINN data module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import tempfile

import tensorflow as tf

from tensorflow.contrib.eager.python.examples.spinn import data

# Header line shared by all fake SNLI files written by the tests below.
_FAKE_SNLI_HEADER = (
    "gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
    "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
    "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")


class DataTest(tf.test.TestCase):
  """Tests for the SNLI / GloVe helpers in the SPINN `data` module."""

  def setUp(self):
    super(DataTest, self).setUp()
    self._temp_data_dir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self._temp_data_dir)
    super(DataTest, self).tearDown()

  def _write_fake_glove_file(self, words):
    """Write a fake GloVe file under the temp data dir.

    The i-th word of `words` gets a vector whose elements all equal i * 0.1.

    Args:
      words: A list of words to write, in file order.
    """
    glove_dir = os.path.join(self._temp_data_dir, "glove")
    os.makedirs(glove_dir)
    glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")
    with open(glove_file, "wt") as f:
      for i, word in enumerate(words):
        vector = " ".join(["%.5f" % (i * 0.1)] * data.WORD_VECTOR_LEN)
        f.write("%s %s\n" % (word, vector))

  def testGenNonParenthesisWords(self):
    seq_with_parse = (
        "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and "
        ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )")
    self.assertEqual(
        ["man", "wearing", "pass", "on", "a", "lanyard", "and", "standing",
         "in", "a", "crowd", "of", "people", "."],
        data.get_non_parenthesis_words(seq_with_parse.split(" ")))

  def testGetShiftReduce(self):
    seq_with_parse = (
        "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and "
        ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )")
    self.assertEqual(
        [3, 3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 2, 2,
         3, 2, 2], data.get_shift_reduce(seq_with_parse.split(" ")))

  def testPadAndReverseWordIds(self):
    id_sequences = [[0, 2, 3, 4, 5],
                    [6, 7, 8],
                    [9, 10, 11, 12, 13, 14, 15, 16]]
    self.assertAllClose(
        [[1, 1, 1, 1, 5, 4, 3, 2, 0],
         [1, 1, 1, 1, 1, 1, 8, 7, 6],
         [1, 16, 15, 14, 13, 12, 11, 10, 9]],
        data.pad_and_reverse_word_ids(id_sequences))

  def testPadTransitions(self):
    unpadded = [[3, 3, 3, 2, 2, 2, 2],
                [3, 3, 2, 2, 2]]
    self.assertAllClose(
        [[3, 3, 3, 2, 2, 2, 2],
         [3, 3, 2, 2, 2, 1, 1]],
        data.pad_transitions(unpadded))

  def testCalculateBins(self):
    length2count = {
        1: 10,
        2: 15,
        3: 25,
        4: 40,
        5: 35,
        6: 10}
    self.assertEqual([2, 3, 4, 5, 6],
                     data.calculate_bins(length2count, 20))
    self.assertEqual([3, 4, 6], data.calculate_bins(length2count, 40))
    self.assertEqual([4, 6], data.calculate_bins(length2count, 60))

  def testLoadVoacbulary(self):
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    fake_dev_file = os.path.join(snli_1_0_dir, "snli_1.0_dev.txt")
    os.makedirs(snli_1_0_dir)

    with open(fake_train_file, "wt") as f:
      f.write(_FAKE_SNLI_HEADER)
      f.write("neutral\t( ( Foo bar ) . )\t( ( foo baz ) . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
    with open(fake_dev_file, "wt") as f:
      f.write(_FAKE_SNLI_HEADER)
      f.write("neutral\t( ( Quux quuz ) ? )\t( ( Corge grault ) ! )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Quux quuz?\t.Corge grault!\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")

    vocab = data.load_vocabulary(self._temp_data_dir)
    self.assertSetEqual(
        {".", "?", "!", "foo", "bar", "baz", "quux", "quuz", "corge", "grault"},
        vocab)

  def testLoadVoacbularyWithoutFileRaisesError(self):
    with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
      data.load_vocabulary(self._temp_data_dir)

    os.makedirs(os.path.join(self._temp_data_dir, "snli"))
    with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
      data.load_vocabulary(self._temp_data_dir)

    os.makedirs(os.path.join(self._temp_data_dir, "snli/snli_1.0"))
    with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"):
      data.load_vocabulary(self._temp_data_dir)

  def testLoadWordVectors(self):
    self._write_fake_glove_file([".", ",", "foo", "bar", "baz"])

    vocab = {"foo", "bar", "baz", "qux", "."}
    # Notice that "qux" is not present in `words`.
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    # Two special tokens plus the four vocabulary words found in the file.
    self.assertEqual(6, len(word2index))
    self.assertEqual(0, word2index["<unk>"])
    self.assertEqual(1, word2index["<pad>"])
    self.assertEqual(2, word2index["."])
    self.assertEqual(3, word2index["foo"])
    self.assertEqual(4, word2index["bar"])
    self.assertEqual(5, word2index["baz"])
    self.assertEqual((6, data.WORD_VECTOR_LEN), embed.shape)
    self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[0, :])
    self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[1, :])
    self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[2, :])
    self.assertAllClose([0.2] * data.WORD_VECTOR_LEN, embed[3, :])
    self.assertAllClose([0.3] * data.WORD_VECTOR_LEN, embed[4, :])
    self.assertAllClose([0.4] * data.WORD_VECTOR_LEN, embed[5, :])

  def testLoadWordVectorsWithoutFileRaisesError(self):
    vocab = {"foo", "bar", "baz", "qux", "."}
    with self.assertRaisesRegexp(
        ValueError, "Cannot find GloVe embedding file at"):
      data.load_word_vectors(self._temp_data_dir, vocab)

    os.makedirs(os.path.join(self._temp_data_dir, "glove"))
    with self.assertRaisesRegexp(
        ValueError, "Cannot find GloVe embedding file at"):
      data.load_word_vectors(self._temp_data_dir, vocab)

  def testSnliData(self):
    """Unit test for SnliData objects."""
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    os.makedirs(snli_1_0_dir)

    # Four sentences in total.
    with open(fake_train_file, "wt") as f:
      f.write(_FAKE_SNLI_HEADER)
      f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")

    self._write_fake_glove_file(
        [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"])

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab)

    train_data = data.SnliData(fake_train_file, word2index)
    self.assertEqual(4, train_data.num_batches(1))
    self.assertEqual(2, train_data.num_batches(2))
    self.assertEqual(2, train_data.num_batches(3))
    self.assertEqual(1, train_data.num_batches(4))

    generator = train_data.get_generator(2)()
    for _ in range(2):
      label, prem, prem_trans, hypo, hypo_trans = next(generator)
      self.assertEqual(2, len(label))
      self.assertEqual((4, 2), prem.shape)
      self.assertEqual((5, 2), prem_trans.shape)
      self.assertEqual((3, 2), hypo.shape)
      self.assertEqual((3, 2), hypo_trans.shape)


if __name__ == "__main__":
  tf.test.main()


# ------------------------------------------------------------------------------
# Patch metadata and license header of the next file in this dump
# (spinn/spinn_test.py), carried over as comments from the original lines.
# ------------------------------------------------------------------------------
# diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..84e25cf81a2223800c47994b26d000caddee6b01
# --- /dev/null
# +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py
# @@ -0,0 +1,409 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests and a benchmark for the eager SPINN model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import gc
import glob
import os
import shutil
import tempfile
import time

import numpy as np
import tensorflow as tf

# pylint: disable=g-bad-import-order
import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.spinn import data
from third_party.examples.eager.spinn import spinn
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
# pylint: enable=g-bad-import-order


def _generate_synthetic_snli_data_batch(sequence_length,
                                        batch_size,
                                        vocab_size):
  """Generate a fake batch of SNLI data for testing.

  Args:
    sequence_length: Number of rows (time steps) of the fake premise and
      hypothesis word-ID tensors.
    batch_size: Number of examples in the batch.
    vocab_size: Exclusive upper bound of the random word IDs.

  Returns:
    A 5-tuple (labels, prem, prem_trans, hypo, hypo_trans) of int64 tensors,
    moved to the GPU if one is available.
  """
  with tf.device("cpu:0"):
    labels = tf.random_uniform([batch_size], minval=1, maxval=4, dtype=tf.int64)
    # The same fixed sequence of 27 transitions is used for every example and
    # for both the premise and the hypothesis.
    transitions = np.array(
        [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3,
          2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2,
          3, 2, 2]] * batch_size, dtype=np.int64).T
    prem = tf.random_uniform(
        (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64)
    prem_trans = tf.constant(transitions)
    hypo = tf.random_uniform(
        (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64)
    hypo_trans = tf.constant(transitions)
  if tfe.num_gpus():
    labels = labels.gpu()
    prem = prem.gpu()
    prem_trans = prem_trans.gpu()
    hypo = hypo.gpu()
    hypo_trans = hypo_trans.gpu()
  return labels, prem, prem_trans, hypo, hypo_trans


def _test_spinn_config(d_embed, d_out, logdir=None):
  """Create a small SPINN config namedtuple for tests.

  Args:
    d_embed: Embedding size; used as `d_hidden`, and doubled for `d_proj`.
    d_out: Size of the classifier output.
    logdir: Optional log directory stored in the config.

  Returns:
    A "Config" namedtuple with the fields listed below.
  """
  config_tuple = collections.namedtuple(
      "Config", ["d_hidden", "d_proj", "d_tracker", "predict",
                 "embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp",
                 "d_out", "projection", "lr", "batch_size", "epochs",
                 "force_cpu", "logdir", "log_every", "dev_every", "save_every",
                 "lr_decay_every", "lr_decay_by"])
  return config_tuple(
      d_hidden=d_embed,
      d_proj=d_embed * 2,
      d_tracker=8,
      predict=False,
      embed_dropout=0.1,
      mlp_dropout=0.1,
      n_mlp_layers=2,
      d_mlp=32,
      d_out=d_out,
      projection=True,
      lr=2e-2,
      batch_size=2,
      epochs=10,
      force_cpu=False,
      logdir=logdir,
      log_every=1,
      dev_every=2,
      save_every=2,
      lr_decay_every=1,
      lr_decay_by=0.75)


class SpinnTest(test_util.TensorFlowTestCase):

  def setUp(self):
    super(SpinnTest, self).setUp()
    self._test_device = "gpu:0" if tfe.num_gpus() else "cpu:0"
    self._temp_data_dir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self._temp_data_dir)
    super(SpinnTest, self).tearDown()

  def testBundle(self):
    with tf.device(self._test_device):
      lstm_iter = [np.array([[0, 1], [2, 3]], dtype=np.float32),
                   np.array([[0, -1], [-2, -3]], dtype=np.float32),
                   np.array([[0, 2], [4, 6]], dtype=np.float32),
                   np.array([[0, -2], [-4, -6]], dtype=np.float32)]
      out = spinn._bundle(lstm_iter)

      self.assertEqual(2, len(out))
      self.assertEqual(tf.float32, out[0].dtype)
      self.assertEqual(tf.float32, out[1].dtype)
      self.assertAllEqual(np.array([[0, 2, 0, -2, 0, 4, 0, -4]]).T,
                          out[0].numpy())
      self.assertAllEqual(np.array([[1, 3, -1, -3, 2, 6, -2, -6]]).T,
                          out[1].numpy())

  def testUnbunbdle(self):
    with tf.device(self._test_device):
      state = [np.array([[0, 1, 2], [3, 4, 5]], dtype=np.float32),
               np.array([[0, -1, -2], [-3, -4, -5]], dtype=np.float32)]
      out = spinn._unbundle(state)

      self.assertEqual(2, len(out))
      self.assertEqual(tf.float32, out[0].dtype)
      self.assertEqual(tf.float32, out[1].dtype)
      self.assertAllEqual(np.array([[0, 1, 2, 0, -1, -2]]),
                          out[0].numpy())
      self.assertAllEqual(np.array([[3, 4, 5, -3, -4, -5]]),
                          out[1].numpy())

  def testReducer(self):
    with tf.device(self._test_device):
      batch_size = 3
      size = 10
      tracker_size = 8
      reducer = spinn.Reducer(size, tracker_size=tracker_size)

      left_in = []
      right_in = []
      tracking = []
      for _ in range(batch_size):
        left_in.append(tf.random_normal((1, size * 2)))
        right_in.append(tf.random_normal((1, size * 2)))
        tracking.append(tf.random_normal((1, tracker_size * 2)))

      out = reducer(left_in, right_in, tracking=tracking)
      self.assertEqual(batch_size, len(out))
      self.assertEqual(tf.float32, out[0].dtype)
      self.assertEqual((1, size * 2), out[0].shape)

  def testReduceTreeLSTM(self):
    with tf.device(self._test_device):
      size = 10
      tracker_size = 8
      reducer = spinn.Reducer(size, tracker_size=tracker_size)

      lstm_in = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                          [0, -1, -2, -3, -4, -5, -6, -7, -8, -9]],
                         dtype=np.float32)
      c1 = np.array([[0, 1], [2, 3]], dtype=np.float32)
      c2 = np.array([[0, -1], [-2, -3]], dtype=np.float32)

      h, c = reducer._tree_lstm(c1, c2, lstm_in)
      self.assertEqual(tf.float32, h.dtype)
      self.assertEqual(tf.float32, c.dtype)
      self.assertEqual((2, 2), h.shape)
      self.assertEqual((2, 2), c.shape)

  def testTracker(self):
    with tf.device(self._test_device):
      batch_size = 2
      size = 10
      tracker_size = 8
      buffer_length = 18
      stack_size = 3

      tracker = spinn.Tracker(tracker_size, False)
      tracker.reset_state()

      # Create dummy inputs for testing.
      bufs = []
      buf = []
      for _ in range(buffer_length):
        buf.append(tf.random_normal((batch_size, size * 2)))
      bufs.append(buf)
      self.assertEqual(1, len(bufs))
      self.assertEqual(buffer_length, len(bufs[0]))
      self.assertEqual((batch_size, size * 2), bufs[0][0].shape)

      stacks = []
      stack = []
      for _ in range(stack_size):
        stack.append(tf.random_normal((batch_size, size * 2)))
      stacks.append(stack)
      self.assertEqual(1, len(stacks))
      self.assertEqual(stack_size, len(stacks[0]))
      self.assertEqual((batch_size, size * 2), stacks[0][0].shape)

      for _ in range(2):
        out1, out2 = tracker(bufs, stacks)
        self.assertIsNone(out2)
        self.assertEqual(batch_size, len(out1))
        self.assertEqual(tf.float32, out1[0].dtype)
        self.assertEqual((1, tracker_size * 2), out1[0].shape)

        self.assertEqual(tf.float32, tracker.state.c.dtype)
        self.assertEqual((batch_size, tracker_size), tracker.state.c.shape)
        self.assertEqual(tf.float32, tracker.state.h.dtype)
        self.assertEqual((batch_size, tracker_size), tracker.state.h.shape)

  def testSPINN(self):
    with tf.device(self._test_device):
      embedding_dims = 10
      d_tracker = 8
      sequence_length = 15
      num_transitions = 27

      config_tuple = collections.namedtuple(
          "Config", ["d_hidden", "d_proj", "d_tracker", "predict"])
      config = config_tuple(
          embedding_dims, embedding_dims * 2, d_tracker, False)
      s = spinn.SPINN(config)

      # Create some fake data.
      buffers = tf.random_normal((sequence_length, 1, config.d_proj))
      transitions = tf.constant(
          [[3], [3], [2], [3], [3], [3], [2], [2], [2], [3], [3], [3],
           [2], [3], [3], [2], [2], [3], [3], [3], [2], [2], [2], [2],
           [3], [2], [2]], dtype=tf.int64)
      self.assertEqual(tf.int64, transitions.dtype)
      self.assertEqual((num_transitions, 1), transitions.shape)

      out = s(buffers, transitions, training=True)
      self.assertEqual(tf.float32, out.dtype)
      self.assertEqual((1, embedding_dims), out.shape)

  def testSNLIClassifierAndTrainer(self):
    with tf.device(self._test_device):
      vocab_size = 40
      batch_size = 2
      d_embed = 10
      sequence_length = 15
      d_out = 4

      config = _test_spinn_config(d_embed, d_out)

      # Create fake embedding matrix.
      embed = tf.random_normal((vocab_size, d_embed))

      model = spinn.SNLIClassifier(config, embed)
      trainer = spinn.SNLIClassifierTrainer(model, config.lr)

      (labels, prem, prem_trans, hypo,
       hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length,
                                                         batch_size,
                                                         vocab_size)

      # Invoke model under non-training mode.
      logits = model(prem, prem_trans, hypo, hypo_trans, training=False)
      self.assertEqual(tf.float32, logits.dtype)
      self.assertEqual((batch_size, d_out), logits.shape)

      # Invoke model under training mode.
      logits = model(prem, prem_trans, hypo, hypo_trans, training=True)
      self.assertEqual(tf.float32, logits.dtype)
      self.assertEqual((batch_size, d_out), logits.shape)

      # Calculate loss.
      loss1 = trainer.loss(labels, logits)
      self.assertEqual(tf.float32, loss1.dtype)
      self.assertEqual((), loss1.shape)

      loss2, logits = trainer.train_batch(
          labels, prem, prem_trans, hypo, hypo_trans)
      self.assertEqual(tf.float32, loss2.dtype)
      self.assertEqual((), loss2.shape)
      self.assertEqual(tf.float32, logits.dtype)
      self.assertEqual((batch_size, d_out), logits.shape)
      # Training on the batch should have led to a change in the loss value.
      self.assertNotEqual(loss1.numpy(), loss2.numpy())

  def testTrainSpinn(self):
    """Test with fake toy SNLI data and GloVe vectors."""

    # 1. Create and load a fake SNLI data file and a fake GloVe embedding file.
    snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0")
    fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt")
    os.makedirs(snli_1_0_dir)

    # Four sentences in total.
    with open(fake_train_file, "wt") as f:
      f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t"
              "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t"
              "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n")
      f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")
      f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t"
              "DummySentence1Parse\tDummySentence2Parse\t"
              "Foo bar.\tfoo baz.\t"
              "4705552913.jpg#2\t4705552913.jpg#2r1n\t"
              "neutral\tentailment\tneutral\tneutral\tneutral\n")

    glove_dir = os.path.join(self._temp_data_dir, "glove")
    os.makedirs(glove_dir)
    glove_file = os.path.join(glove_dir, "glove.42B.300d.txt")

    words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"]
    with open(glove_file, "wt") as f:
      for i, word in enumerate(words):
        f.write("%s " % word)
        for j in range(data.WORD_VECTOR_LEN):
          f.write("%.5f" % (i * 0.1))
          if j < data.WORD_VECTOR_LEN - 1:
            f.write(" ")
          else:
            f.write("\n")

    vocab = data.load_vocabulary(self._temp_data_dir)
    word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab)

    # The same fake file serves as the train, dev and test splits.
    train_data = data.SnliData(fake_train_file, word2index)
    dev_data = data.SnliData(fake_train_file, word2index)
    test_data = data.SnliData(fake_train_file, word2index)

    # 2. Create a fake config.
    config = _test_spinn_config(
        data.WORD_VECTOR_LEN, 4,
        logdir=os.path.join(self._temp_data_dir, "logdir"))

    # 3. Test training of a SPINN model.
    spinn.train_spinn(embed, train_data, dev_data, test_data, config)

    # 4. Load train loss values from the summary files and verify that they
    #    decrease with training.
    summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0]
    events = summary_test_util.events_from_file(summary_file)
    train_losses = [event.summary.value[0].simple_value for event in events
                    if event.summary.value
                    and event.summary.value[0].tag == "train/loss"]
    self.assertEqual(config.epochs, len(train_losses))
    self.assertLess(train_losses[-1], train_losses[0])


class EagerSpinnSNLIClassifierBenchmark(test.Benchmark):

  def benchmarkEagerSpinnSNLIClassifier(self):
    test_device = "gpu:0" if tfe.num_gpus() else "cpu:0"
    with tf.device(test_device):
      burn_in_iterations = 2
      benchmark_iterations = 10

      vocab_size = 1000
      batch_size = 128
      sequence_length = 15
      d_embed = 200
      d_out = 4

      embed = tf.random_normal((vocab_size, d_embed))

      config = _test_spinn_config(d_embed, d_out)
      model = spinn.SNLIClassifier(config, embed)
      trainer = spinn.SNLIClassifierTrainer(model, config.lr)

      (labels, prem, prem_trans, hypo,
       hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length,
                                                         batch_size,
                                                         vocab_size)

      for _ in range(burn_in_iterations):
        trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans)

      gc.collect()
      start_time = time.time()
      # `range` rather than `xrange`: the latter is not defined on Python 3
      # and is not imported from a compatibility library in this file.
      for _ in range(benchmark_iterations):
        trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans)
      wall_time = time.time() - start_time
      # Named "examples"_per_sec to conform with other benchmarks.
      extras = {"examples_per_sec": benchmark_iterations / wall_time}
      self.report_benchmark(
          name="Eager_SPINN_SNLIClassifier_Benchmark",
          iters=benchmark_iterations,
          wall_time=wall_time,
          extras=extras)


if __name__ == "__main__":
  test.main()


# ------------------------------------------------------------------------------
# Patch hunks for other files that follow in this dump (g3doc/guide.md and
# metrics_impl.py), carried over as comments from the original lines.
# ------------------------------------------------------------------------------
# diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md
# index 147b7047f42b7ccba5829b61370e82e217ce5838..0095ffa0db99d46d25654d73504d0d7d41c18b6f 100644
# --- a/tensorflow/contrib/eager/python/g3doc/guide.md
# +++ b/tensorflow/contrib/eager/python/g3doc/guide.md
# @@ -757,7 +757,7 @@ For example, to record summaries once every 100 global steps, use:
#  ```python
#  tf.train.get_or_create_global_step() # Ensuring the global step variable exists
# -writer = tf.contrib.summary.create_summary_file_writer(logdir)
# +writer = tf.contrib.summary.create_file_writer(logdir)
#
#  for _ in range(iterations):
#    with writer.as_default():
# diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
# index aa359b7a0d7d89e8788c323d1621798d1a22b658..2f8016ede3caee6dbb6fd8f5226f1464b5c3976b 100644
# --- a/tensorflow/contrib/eager/python/metrics_impl.py
# +++ b/tensorflow/contrib/eager/python/metrics_impl.py
# @@ -73,7 +73,7 @@ class Metric(object):
#    * `result()`: Computes and returns a final value for the metric
#      from the variables in `self`.
#
# -  Decendants may override `aggregate()`, but usually won't need to. It
# +  Descendants may override `aggregate()`, but usually won't need to. It
#    adds in the state from a list of metrics of the same type as `self`.
#    (Default is to sum all the variables.) Note that users should not
#    call `aggregate()`, it is for use by TensorFlow infrastructure.
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index b4f5973bd11a02230d30f8cf1b2961125f154283..1055f4563cd4608189281450aed512fbf5f31de1 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -67,12 +67,12 @@ class MetricsTest(test.TestCase): m([1, 10, 100]) training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - with summary_ops.create_summary_file_writer( + with summary_ops.create_file_writer( logdir, max_queue=0, name="t0").as_default(), summary_ops.always_record_summaries(): m.result() # As a side-effect will write summaries. - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 37.0) diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index c6e628b074e8638fd15a35f2df87609e0ad46000..e3c13cbd2e8ccd2ab79da74e0e97905c6ed5c02d 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -37,198 +37,98 @@ from tensorflow.python.training import training_util # functions in base.py which should be reused. -_DeferredRestoration = collections.namedtuple( - - "_DeferredRestoration", - [ - # The map_func to use (either user-specified or the default). - "map_func", - # Boolean, True if the user specified an explicit map_func, for error - # messages. - "map_func_is_user", - # A mapping from checkpoint names to initial values of not-yet-created - # variables which should be restored. These values come from parsing a - # checkpoint. - "checkpointed_variables_to_restore", - # A mapping from checkpoint name to variable objects of variables which - # have already been restored, for error checking. - "restored_variables", - # The session to restore with (if in graph mode). 
- "session", - # Names of the Network where the restore was requested, for error - # messages. - "network_name", - "network_scope_name" - ]) - - -def _default_naming_conflict_error_message( - mapped_name, first_variable, second_variable, - network_name, network_scope_name): - return ( - ("The default checkpoint variable name mapping strategy for Network " - "'%s' resulted in a naming conflict. We attempted to strip off the " - "variable prefix for the Network ('%s'), but this resulted in two " - "variables named '%s' (originally '%s' and '%s'). This should only " - "happen when using variable sharing (i.e. the Network contains Networks " - "or Layers which were first added to another Network, and therefore " - "have that Network's variable prefix). One solution is to pass " - "`map_func=lambda n: n` to Network.save and Network.restore to use " - "fully qualified variable names in the checkpoint, although this will " - "require that the variable prefix of the Network being restored into " - "is also '%s'. You may alternatively write an arbitrary mapping.") - % ( - network_name, network_scope_name, mapped_name, - first_variable._shared_name, - second_variable._shared_name, network_scope_name - )) - - -def _restore_custom_map_func_error_message( - mapped_name, first_variable, second_variable, - network_name, network_scope_name): - return ( - ("The map_func passed to Network.restore for the Network '%s' " - "resulted in two variables named '%s' (originally '%s' and '%s'). Since " - "this is also an error on Network.save, this Network was " - "probably not saved with this map_func. Note that map_func " - "always maps from full variable names to checkpoint names; " - "there is no need to specify an inverse mapping.\n\n" - "Try stripping less from the variable names, or renaming parts " - "of the Network. 
For reference, variables created by sub-Layers " - "of this Network are prefixed with '%s', but if they are " - "re-used after being added to another Network they will have " - "that Network's full variable prefix instead.") % ( - network_name, mapped_name, - first_variable._shared_name, - second_variable._shared_name, - network_scope_name)) - - -def _make_custom_getter_for_deferred_restorations(): - """Returns a custom getter which searches `deferred_restorations`. - - Returns: A tuple of (_custom_getter, deferred_restorations) - _custom_getter: The getter which should be added to variable_scopes where - variables will be created. - deferred_restorations: A list for _DeferredRestoration objects. Typically - empty when the getter is set, and expanded as deferred restorations are - requested. All new deferred restorations should be appended to the end of - the list, where they will have priority over older deferred restorations. - """ - deferred_restorations = [] - - def _custom_getter(getter, name, shape=None, dtype=None, - initializer=None, - *args, **kwargs): - """A custom getter which processes deferred restorations.""" - # Iterate over restorations, newest first (newer restorations will take - # precedence over older restorations, just like with immediate restorations - # into existing variables). - delayed_restoration = None - found_value = False - value_to_restore = None - for delayed_restoration in reversed( - deferred_restorations): - checkpoint_name = delayed_restoration.map_func(name) - if (checkpoint_name - in delayed_restoration.checkpointed_variables_to_restore): - found_value = True - value_to_restore = ( - delayed_restoration.checkpointed_variables_to_restore[ - checkpoint_name]) - if found_value: - break - # value_to_restore may be False because this variable is not in any - # checkpoint we are restoring, or None because we have explicitly set it to - # None when it was previously fetched. In either case, we don't need to - # set an initializer. 
- if found_value and value_to_restore is not None: - initializer = value_to_restore - shape = None - variable = getter(name, shape=shape, dtype=dtype, initializer=initializer, - *args, **kwargs) - if found_value and value_to_restore is not None: - # Mark as already restored from this checkpoint. - delayed_restoration.checkpointed_variables_to_restore[ - checkpoint_name] = None - if context.in_graph_mode(): - delayed_restoration.session.run(variable.initializer) - if found_value: - # Error checking should run even if we've already restored a value. - if delayed_restoration.restored_variables.setdefault( - checkpoint_name, variable) is not variable: - # Naming conflict. We've tried to initialize two variables with the - # same value from the checkpoint. - if delayed_restoration.map_func_is_user: - raise ValueError( - _restore_custom_map_func_error_message( - mapped_name=checkpoint_name, - first_variable=delayed_restoration.restored_variables[ - checkpoint_name], - second_variable=variable, - network_name=delayed_restoration.network_name, - network_scope_name=delayed_restoration.network_scope_name)) - else: - raise ValueError( - _default_naming_conflict_error_message( - mapped_name=checkpoint_name, - first_variable=delayed_restoration.restored_variables[ - checkpoint_name], - second_variable=variable, - network_name=delayed_restoration.network_name, - network_scope_name=delayed_restoration.network_scope_name)) - return variable - return _custom_getter, deferred_restorations +def _network_name_scope_naming(current_variable_scope): + """Name scope naming to match operation names to variable names. - -def _make_prefix_stripping_map_fn(scope_name): - """Closure for stripping the scope name of a Network. - - Implemented as a closure rather than a member function to avoid reference - cycles in deferred restorations (this function should not have a reference to - the Network which created it). 
+ Used in Networks and also applied to non-Network Layers which are added to + Networks before being built. Args: - scope_name: The Network.scope_name to strip from variables. + current_variable_scope: A VariableScope object. Returns: - A scope_name-stripping default `map_fn` for the Network. + A name scope name. """ - - def _strip_variable_prefix(original_variable_name): - """The default map_func for saving or restoring variables. - - Strips the variable prefix for the Network on which save/restore was called, - and leaves other variable names fully qualified in the checkpoint. - - Args: - original_variable_name: The _shared_name of the variable (no :0 - suffix) to map. - Returns: - The checkpoint name of the variable. - """ - scope_name_with_slash = scope_name + "/" - if original_variable_name.startswith(scope_name_with_slash): - return original_variable_name[len(scope_name_with_slash):] - else: - return original_variable_name - - return _strip_variable_prefix + return current_variable_scope.name + "/" class Network(base.Layer): """Represents the composition of a set of Layers. - TODO(josh11b,ashankar): - - Should "trainable" be changeable on the Network object? - - Do we allow add_variable in Network? - - Detect layers used in __call__ that weren't registered with track_layer. - - Convert inputs to __call__ to tensors. - - Prevent variables from being created after the first __call__? - (Think about restoring from a checkpoint). + `Network` implements the `Layer` interface and adds convenience methods for + managing sub-`Layer`s, such as listing variables. + + `Layer`s (including other `Network`s) should be added via `track_layer`. 
They + can then be used when overriding the `Network.call` method: + + ```python + class TwoLayerNetwork(tfe.Network): + + def __init__(self, name): + super(TwoLayerNetwork, self).__init__(name=name) + self.layer_one = self.track_layer(tf.layers.Dense(16, input_shape=(8,))) + self.layer_two = self.track_layer(tf.layers.Dense(1, input_shape=(16,))) + + def call(self, inputs): + return self.layer_two(self.layer_one(inputs)) + ``` + + After constructing an object and calling the `Network`, a list of variables + created by tracked `Layer`s is available via `Network.variables`: + + ```python + net = TwoLayerNetwork(name="net") + output = net(tf.ones([1, 8])) + print([v.name for v in net.variables]) + ``` + + This example prints variable names, one kernel and one bias per + `tf.layers.Dense` layer: + + ``` + ['net/dense/kernel:0', + 'net/dense/bias:0', + 'net/dense_1/kernel:0', + 'net/dense_1/bias:0'] + ``` + + These variables can be passed to a `Saver` (`tf.train.Saver`, or + `tf.contrib.eager.Saver` when executing eagerly) to save or restore the + `Network`, typically alongside a global step and `tf.train.Optimizer` + variables when checkpointing during training. + + Note that the semantics of calling a `Network` with graph execution (i.e. not + executing eagerly) may change slightly in the future. Currently stateful ops + are pruned from the graph unless they or something that depends on them is + executed in a session, but this behavior is not consistent with eager + execution (where stateful ops are executed eagerly). `Layer`s from `tf.layers` + do not depend on this pruning and so will not be affected, but `Network`s + which rely on stateful ops being added to the graph but not executed (e.g. via + custom `Layer`s which manage stateful ops) may break with this change. """ + # TODO(josh11b,ashankar,allenl): + # - Should 'trainable' be changeable on the Network object? + # - Do we allow add_variable in Network? 
+ # - Detect layers used in __call__ that weren't registered with track_layer. + # - Convert inputs to __call__ to tensors. def __init__(self, name=None): + """Configure the `Network`. + + Args: + name: The name to use for this `Network`. If specified, it must be unique + in the context where this `Network` is first + (1) added to another `Network` (in which case it must not share a name + with other `Layers` added to that `Network`), or + (2) built/called (in which case no other 'top-level' `Network`s may + share this name). + If unspecified or None, the `Network` will be named using its class + name, with a number appended if necessary for uniqueness (e.g. MyNetwork + -> 'my_network_1'). + + Raises: + ValueError: If `name` is not valid. Note that some naming errors will + instead be raised when the `Network` is called. + """ if isinstance(name, variable_scope.VariableScope): raise ValueError("VariableScopes are not valid Network names.") if name is not None and "/" in name: @@ -244,8 +144,17 @@ class Network(base.Layer): self._owned_layers = {} # The scope to use if we end up without a parent. self._default_parent_variable_scope = variable_scope.get_variable_scope() - self._custom_getter, self._deferred_restorations = ( - _make_custom_getter_for_deferred_restorations()) + # Hold on to the variable scope counts from init to check whether a scope + # with the name we want was ever created in our parent scope. Without this + # check we might have name collisions if the parent scope on init gets + # closed before build is called. 
+ self._variable_scope_counts_on_init = ( + variable_scope._get_default_variable_store().variable_scopes_count) + + def _name_scope_name(self, current_variable_scope): + """Overrides Layer op naming to match variable naming.""" + return _network_name_scope_naming( + current_variable_scope=current_variable_scope) def _init_set_name(self, name): # Anonymous Networks (name=None) defer setting a final name until they are @@ -261,18 +170,30 @@ class Network(base.Layer): def _finalize_name(self, parent_network): if not self._name: - if not parent_network: - name_uid_map = base._get_default_graph_uid_map() - else: - name_uid_map = parent_network._sub_layer_name_uids # Were were not passed a name explicitly (or it was blank), so this is an # anonymous Network. We make up a unique name. if parent_network: avoid_names = parent_network._owned_layers + name_uid_map = parent_network._sub_layer_name_uids else: - avoid_names = None + name_uid_map = base._get_default_graph_uid_map() + # Figure out which names we have to avoid based on which variable scope + # we're nested in. 
+ strip_name = self._default_parent_variable_scope.name + if strip_name: + strip_name += "/" + def _strip_on_init_scope(name): + if name.startswith(strip_name): + return name[len(strip_name):] + else: + return None + avoid_names = set( + _strip_on_init_scope(name) + for name in self._variable_scope_counts_on_init.keys() if name) self._name, self._base_name = self._make_unique_name( - name_uid_map=name_uid_map, avoid_names=avoid_names) + name_uid_map=name_uid_map, avoid_names=avoid_names, + namespace=self._default_parent_variable_scope.name, + zero_based=True) if self._first_parent is None or (self._first_parent # False = no parent and self._first_parent() is None): # Save a pointer to the parent Network so that we can later check that the @@ -302,7 +223,13 @@ class Network(base.Layer): parent_scope = first_parent._scope else: parent_scope = self._default_parent_variable_scope - with variable_scope.variable_scope(parent_scope): + with variable_scope.variable_scope(parent_scope) as parent_vs: + expected_scope_name = parent_vs.name + "/" + self._name + if expected_scope_name in self._variable_scope_counts_on_init: + raise ValueError( + ("A Network named '%s' already exists (or a variable_scope was " + "created with this name). Names must be unique.") % ( + self._name,)) # Make sure variables with this prefix will be unique. with variable_scope.variable_scope( None, use_resource=True, default_name=self._name) as scope: @@ -319,25 +246,22 @@ class Network(base.Layer): "created with this name). Names must be unique.") % ( self._name,)) if (first_parent - and scope_prefix[:-1] != first_parent._scope.name): + and scope_prefix[:-1] != first_parent.scope_name): raise ValueError( ("Network variable names must match a nesting of sub-Network " "names. Expected prefix '%s' from parent network, but got " "'%s' when attempting to create a variable_scope for Network " "'%s'. 
Likely an explicit variable_scope was inserted into " "the nesting.") % ( - first_parent._scope.name, + first_parent.scope_name, scope_prefix[:-1], self._name)) elif not first_parent and scope_prefix: # For the case when this Network is not nested inside any other - # Network, but is in a variable_scope. This is an error for now. - raise ValueError( - "Creating Networks inside named variable_scopes is currently " - "not supported (to ensure that variable names match the names " - "of Networks in which they were first created). To set " - "options, try `with tf.variable_scope(''):`. If this " - "limitation bothers you, please file a feature request.") + # Network, but is in a variable_scope. This Network's name takes on + # the full variable scope prefix. + self._name = scope_name + for non_network_sublayer in self._non_network_sublayers: self._set_scope_for_nonnetwork_sublayer(non_network_sublayer) @@ -355,8 +279,7 @@ class Network(base.Layer): raise ValueError( ("The parent of a Layer added to Network %s was garbage collected " "before the Layer was built. If this limitation bothers you " - "please, comment on " - "https://github.com/tensorflow/tensorflow/issues/14164.") % + "please file a feature request.") % (self.name,)) with variable_scope.variable_scope(parent_scope): # Horrid hack to make Layer variable names which are direct @@ -366,6 +289,9 @@ class Network(base.Layer): None, use_resource=True, default_name=sublayer.name) as sub_scope: sublayer._scope = sub_scope + # Also switch op naming for this Layer to match Network conventions, + # i.e. op naming matching variable naming. + sublayer._name_scope_name = _network_name_scope_naming @base.Layer.name.getter def name(self): @@ -420,7 +346,10 @@ class Network(base.Layer): # name, and we should respect it (subject to error checking). 
layer._name, layer._base_name = layer._make_unique_name( name_uid_map=self._sub_layer_name_uids, - avoid_names=self._owned_layers) + avoid_names=self._owned_layers, + zero_based=True + # No namespace required, since we've specified our own UID map. + ) layer._first_parent = weakref.ref(self) self._non_network_sublayers.append(layer) if (not layer.built @@ -522,254 +451,30 @@ class Network(base.Layer): "at https://github.com/tensorflow/tensorflow/issues/new if this is " "important to you") - def save(self, save_path, global_step=None, map_func=None): - """Save variables from the Network to a checkpoint. + def add_loss(self, losses, inputs=None): + raise RuntimeError( + "add_loss is not supported in Network class yet. Please file an issue " + "at https://github.com/tensorflow/tensorflow/issues/new if this is " + "important to you") - Args: - save_path: Either a checkpoint prefix or the name of a directory to save - the checkpoint in (in which case the checkpoint will be named based on - the Network name). - global_step: The global step to use when naming the checkpoint. If None - (default), we will first try to get the default global step. If that - fails because no default global step exists, then the checkpoint is - created without a global step suffix. - map_func: A function mapping fully qualified variable names - (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By - default (if `map_func=None`), the variable prefix for the network being - restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped - and all other variable names (shared with other Networks) are left - unchanged. - Returns: - The checkpoint prefix for the saved checkpoint, which may be passed to - `Network.restore`. - Raises: - ValueError: If the Network has not yet been called, or if map_func results - in a name collision. - """ - if not self.built: - raise ValueError( - "Attempt to save the Network before it was first called. 
This means " - "variables have not yet been created, so there is nothing to save.") - self._set_scope() # scope_name should be available to map_funcs - if global_step is None: - global_step = training_util.get_global_step() - if os.path.isdir(save_path): - # If we were passed a directory, default to naming based on the Network - # name. - save_path = os.path.join(save_path, self.name) - user_map_func = map_func - if map_func is None: - map_func = _make_prefix_stripping_map_fn(self.scope_name) - variable_map = {} - for variable in self.variables: - mapped_name = map_func(variable._shared_name) - if variable_map.setdefault(mapped_name, variable) is not variable: - if user_map_func is None: - # Instead of erroring out, we could just re-try and silently use the - # full variable names in the checkpoint. This could be odd for deeply - # nested sub-Networks (since the full prefix from the nesting would - # get added), so for now we'll let the user deal with this case. - raise ValueError(_default_naming_conflict_error_message( - mapped_name=mapped_name, - first_variable=variable_map[mapped_name], - second_variable=variable, - network_name=self.name, - network_scope_name=self.scope_name)) - else: - # The user passed their own problematic map_func. - raise ValueError( - ("The map_func passed to Network.save for the Network '%s' " - "resulted in two variables named '%s' ('%s' and '%s'). Try " - "stripping less from the variable names, or renaming parts of " - "the Network. 
For reference, variables created by sub-Layers of " - "this Network are prefixed with '%s', but if they are re-used " - "after being added to another Network, they will have that " - "Network's full variable prefix instead.") % ( - self.name, mapped_name, - variable_map[mapped_name]._shared_name, - variable._shared_name, - self.scope_name)) - if context.in_eager_mode(): - sess = None - else: - sess = ops.get_default_session() - return saver_lib.Saver(variable_map).save( - sess=sess, save_path=save_path, write_meta_graph=False, - global_step=global_step) + @property + def losses(self): + """Gather losses from `Layer`s in the `Network`. - def _restore_existing_variables(self, save_path, map_func, user_map_func): - """Use a standard Saver to restore existing variables from a checkpoint. + Note that when executing eagerly, `Layer.losses` evaluates + regularizers. When using graph execution, variable regularization ops have + already been created and are simply returned here. - Args: - save_path: The checkpoint prefix or directory to read from. - map_func: The function to use when mapping from variable names to - checkpoint names. - user_map_func: The original map_func passed by the user, for error - checking. Returns: - A dictionary mapping from checkpoint names to variable objects which have - been restored (for bookkeeping to avoid deferred restorations on these - variables). - Raises: - ValueError: If there is a name collision. + A list of tensors. 
""" - existing_variables_by_checkpoint_name = {} - for variable in self.variables: - checkpoint_name = map_func(variable._shared_name) - if existing_variables_by_checkpoint_name.setdefault( - checkpoint_name, variable) is not variable: - if user_map_func is None: - raise ValueError(_default_naming_conflict_error_message( - mapped_name=checkpoint_name, - first_variable=existing_variables_by_checkpoint_name[ - checkpoint_name], - second_variable=variable, - network_name=self.name, - network_scope_name=self.scope_name)) - else: - raise ValueError(_restore_custom_map_func_error_message( - mapped_name=checkpoint_name, - first_variable=existing_variables_by_checkpoint_name[ - checkpoint_name], - second_variable=variable, - network_name=self.name, - network_scope_name=self.scope_name)) - if existing_variables_by_checkpoint_name: - if context.in_eager_mode(): - sess = None - else: - sess = ops.get_default_session() - saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore( - sess=sess, save_path=save_path) - return existing_variables_by_checkpoint_name - - def _set_restore_on_create(self, save_path, map_func, user_map_func, - existing_variables_by_checkpoint_name): - """If necessary, request deferred restorations of variables.""" - checkpoint_reader = checkpoint_utils.load_checkpoint(save_path) - checkpointed_variables_to_restore = {} - for checkpoint_name, _ in checkpoint_utils.list_variables(save_path): - if checkpoint_name in existing_variables_by_checkpoint_name: - # This variable was already created and restored. - continue - # Save the variable for later restoration in a custom getter. - checkpointed_variables_to_restore[checkpoint_name] = ( - checkpoint_reader.get_tensor(checkpoint_name)) - # Only set a deferred restoration if there are checkpoint variables which - # have not been assigned to existing variables. 
Note that this loses out on - # some opportunity for error checking, but avoids creating - # _DeferredRestoration objects once a Network has been built (so that - # restoring in a loop does not take increasing amounts of memory). - if checkpointed_variables_to_restore: - if context.in_eager_mode(): - sess = None - else: - sess = ops.get_default_session() - # We need a name for error messages. If we haven't been added to another - # Network yet, we're top-level. - self._finalize_name(False) - self._set_scope() - # Save a record of this restoration for use in the custom getter. - deferred_restoration = _DeferredRestoration( - map_func=map_func, - map_func_is_user=(user_map_func is not None), - checkpointed_variables_to_restore=checkpointed_variables_to_restore, - restored_variables={}, - session=sess, - network_name=self.name, - network_scope_name=self.scope_name) - self._deferred_restorations.append(deferred_restoration) - # Add the deferred registration to non-Network children, and request that - # Networks propagate the request to their children. - self._add_deferred_restoration(deferred_restoration) - - def _add_deferred_restoration(self, deferred_restoration): - """Add a deferred restoration to this Network and all children. - - Restorations which are requested later have higher priority, and the highest - priority matching restoration is applied to a variable when it is created. - - Args: - deferred_restoration: A _DeferredRestoration object. - """ - # Networks don't create variables at the moment, so this append isn't - # strictly necessary. We could get by with only adding deferred restorations - # to non-Network Layers. - self._set_scope() - # We use set_custom_getter because it avoids recursively calling up the - # variable_scope tree. We've done the tree traversal ourselves and have - # added the request to each Layer which needs it. 
- self._scope.set_custom_getter(self._custom_getter) - self._deferred_restorations.append(deferred_restoration) + layer_losses = [] for layer in self.layers: - if isinstance(layer, Network): - # For Networks, request that they propagate this deferred restoration - # to all of their children recursively. - layer._add_deferred_restoration(deferred_restoration) - else: - # For non-Network Layers, make sure they have a deferred restoration - # queue and a custom getter, then add our request to it. - if not hasattr(layer, "_custom_getter"): - assert not hasattr(layer, "_deferred_restorations") - layer._custom_getter, layer._deferred_restorations = ( - _make_custom_getter_for_deferred_restorations()) - self._set_scope_for_nonnetwork_sublayer(layer) - layer._scope.set_custom_getter(layer._custom_getter) - layer._deferred_restorations.append(deferred_restoration) - - def restore(self, save_path, map_func=None): - """Restore the Network from a checkpoint. - - If variables have already been created (typically when some or all of the - `Network` is built), they are assigned values from the checkpoint - immediately, overwriting any existing values (in graph mode the default - session is used for the assignments). - - If there are checkpoint entries which do not correspond to any existing - variables in the `Network`, these values are saved for deferred restoration; - their initial values will be the checkpointed values once they are - created. Requests for multiple deferred restorations behave the same way as - immediate restorations, in that later requests will take priority over - earlier requests relevant to the same variable. - - If this `Network` shares `Layer`s with another network, those `Layer`s will - also have their variables restored from the checkpoint. - - Args: - save_path: The return value of `Network.save`, or a directory to search - for a checkpoint. - map_func: A function mapping fully qualified variable names - (e.g. 
'my_network_1/dense_1/kernel') to names in the checkpoint. By - default (if `map_func=None`), the variable prefix for the network being - restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped - and all other variable names (shared with other Networks) are left - unchanged. Note that this is the _same_ map_func as `Network.save`, not - an inverse mapping. - """ - self._finalize_name(parent_network=False) - self._set_scope() # scope_name should be available to map_funcs - if os.path.isdir(save_path): - # If we don't have a name yet, set no parent. - save_path = os.path.join(save_path, self.name) - user_map_func = map_func - if map_func is None: - map_func = _make_prefix_stripping_map_fn(self.scope_name) - # Step one is to restore any existing variables from the checkpoint. - existing_variables_by_checkpoint_name = self._restore_existing_variables( - save_path=save_path, - map_func=map_func, - user_map_func=user_map_func) - # Step two is to set a custom getter which restores variables on creation, - # for those variables which have not been added to sub-Layers yet. - self._set_restore_on_create( - save_path=save_path, - map_func=map_func, - user_map_func=user_map_func, - existing_variables_by_checkpoint_name=( - existing_variables_by_checkpoint_name)) + layer_losses.extend(layer.losses) + return layer_losses - # TODO(josh11b): Support other Layer methods needed for graph mode, such as for - # losses and updates + # TODO(allenl): Support other Layer methods needed for graph mode, such as for + # updates class Sequential(Network): @@ -817,3 +522,436 @@ class Sequential(Network): else: inputs = l(inputs) return inputs + + +_DeferredRestoration = collections.namedtuple( + + "_DeferredRestoration", + [ + # The map_func to use (either user-specified or the default). + "map_func", + # Boolean, True if the user specified an explicit map_func, for error + # messages. 
+ "map_func_is_user", + # A mapping from checkpoint names to initial values of not-yet-created + # variables which should be restored. These values come from parsing a + # checkpoint. + "checkpointed_variables_to_restore", + # A mapping from checkpoint name to variable objects of variables which + # have already been restored, for error checking. + "restored_variables", + # The session to restore with (if in graph mode). + "session", + # Names of the Network where the restore was requested, for error + # messages. + "network_name", + "network_scope_name" + ]) + + +def _default_naming_conflict_error_message( + mapped_name, first_variable, second_variable, + network_name, network_scope_name): + return ( + ("The default checkpoint variable name mapping strategy for Network " + "'%s' resulted in a naming conflict. We attempted to strip off the " + "variable prefix for the Network ('%s'), but this resulted in two " + "variables named '%s' (originally '%s' and '%s'). This should only " + "happen when using variable sharing (i.e. the Network contains Networks " + "or Layers which were first added to another Network, and therefore " + "have that Network's variable prefix). One solution is to pass " + "`map_func=lambda n: n` to save and restore to use fully qualified " + "variable names in the checkpoint, although this will require that the " + "variable prefix of the Network being restored into is also '%s'. You " + "may alternatively write an arbitrary mapping.") + % ( + network_name, network_scope_name, mapped_name, + first_variable._shared_name, + second_variable._shared_name, network_scope_name + )) + + +def _restore_custom_map_func_error_message( + mapped_name, first_variable, second_variable, + network_name, network_scope_name): + return ( + ("The map_func passed to restore_network_checkpoint for the Network '%s' " + "resulted in two variables named '%s' (originally '%s' and '%s'). 
Since " + "this is also an error when saving, this Network was " + "probably not saved with this map_func. Note that map_func " + "always maps from full variable names to checkpoint names; " + "there is no need to specify an inverse mapping.\n\n" + "Try stripping less from the variable names, or renaming parts " + "of the Network. For reference, variables created by sub-Layers " + "of this Network are prefixed with '%s', but if they are " + "re-used after being added to another Network they will have " + "that Network's full variable prefix instead.") % ( + network_name, mapped_name, + first_variable._shared_name, + second_variable._shared_name, + network_scope_name)) + + +def _make_custom_getter_for_deferred_restorations(): + """Returns a custom getter which searches `deferred_restorations`. + + Returns: A tuple of (_custom_getter, deferred_restorations) + _custom_getter: The getter which should be added to variable_scopes where + variables will be created. + deferred_restorations: A list for _DeferredRestoration objects. Typically + empty when the getter is set, and expanded as deferred restorations are + requested. All new deferred restorations should be appended to the end of + the list, where they will have priority over older deferred restorations. + """ + deferred_restorations = [] + + def _custom_getter(getter, name, shape=None, dtype=None, + initializer=None, + *args, **kwargs): + """A custom getter which processes deferred restorations.""" + # Iterate over restorations, newest first (newer restorations will take + # precedence over older restorations, just like with immediate restorations + # into existing variables). 
+ delayed_restoration = None + found_value = False + value_to_restore = None + for delayed_restoration in reversed( + deferred_restorations): + checkpoint_name = delayed_restoration.map_func(name) + if (checkpoint_name + in delayed_restoration.checkpointed_variables_to_restore): + found_value = True + value_to_restore = ( + delayed_restoration.checkpointed_variables_to_restore[ + checkpoint_name]) + if found_value: + break + # value_to_restore may be False because this variable is not in any + # checkpoint we are restoring, or None because we have explicitly set it to + # None when it was previously fetched. In either case, we don't need to + # set an initializer. + if found_value and value_to_restore is not None: + initializer = value_to_restore + shape = None + variable = getter(name, shape=shape, dtype=dtype, initializer=initializer, + *args, **kwargs) + if found_value and value_to_restore is not None: + # Mark as already restored from this checkpoint. + delayed_restoration.checkpointed_variables_to_restore[ + checkpoint_name] = None + if context.in_graph_mode(): + delayed_restoration.session.run(variable.initializer) + if found_value: + # Error checking should run even if we've already restored a value. + if delayed_restoration.restored_variables.setdefault( + checkpoint_name, variable) is not variable: + # Naming conflict. We've tried to initialize two variables with the + # same value from the checkpoint. 
+ if delayed_restoration.map_func_is_user: + raise ValueError( + _restore_custom_map_func_error_message( + mapped_name=checkpoint_name, + first_variable=delayed_restoration.restored_variables[ + checkpoint_name], + second_variable=variable, + network_name=delayed_restoration.network_name, + network_scope_name=delayed_restoration.network_scope_name)) + else: + raise ValueError( + _default_naming_conflict_error_message( + mapped_name=checkpoint_name, + first_variable=delayed_restoration.restored_variables[ + checkpoint_name], + second_variable=variable, + network_name=delayed_restoration.network_name, + network_scope_name=delayed_restoration.network_scope_name)) + return variable + return _custom_getter, deferred_restorations + + +def _make_prefix_stripping_map_fn(scope_name): + """Closure for stripping the scope name of a Network. + + Implemented as a closure rather than a member function to avoid reference + cycles in deferred restorations (this function should not have a reference to + the Network which created it). + + Args: + scope_name: The Network.scope_name to strip from variables. + Returns: + A scope_name-stripping default `map_fn` for the Network. + """ + + def _strip_variable_prefix(original_variable_name): + """The default map_func for saving or restoring variables. + + Strips the variable prefix for the Network on which save/restore was called, + and leaves other variable names fully qualified in the checkpoint. + + Args: + original_variable_name: The _shared_name of the variable (no :0 + suffix) to map. + Returns: + The checkpoint name of the variable. + """ + scope_name_with_slash = scope_name + "/" + if original_variable_name.startswith(scope_name_with_slash): + return original_variable_name[len(scope_name_with_slash):] + else: + return original_variable_name + + return _strip_variable_prefix + + +def save_network_checkpoint( + network, save_path, global_step=None, map_func=None): + """Save variables from the Network to a checkpoint. 
+ + Args: + network: A Network object to save. + save_path: Either a checkpoint prefix or the name of a directory to save + the checkpoint in (in which case the checkpoint will be named based on + the Network name). + global_step: The global step to use when naming the checkpoint. If None + (default), we will first try to get the default global step. If that + fails because no default global step exists, then the checkpoint is + created without a global step suffix. + map_func: A function mapping fully qualified variable names + (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By + default (if `map_func=None`), the variable prefix for the network being + restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped + and all other variable names (shared with other Networks) are left + unchanged. + Returns: + The checkpoint prefix for the saved checkpoint, which may be passed to + `Network.restore`. + Raises: + ValueError: If the Network has not yet been called, or if map_func results + in a name collision. + """ + if not network.built: + raise ValueError( + "Attempt to save the Network before it was first called. This means " + "variables have not yet been created, so there is nothing to save.") + network._set_scope() # scope_name should be available to map_funcs + if global_step is None: + global_step = training_util.get_global_step() + if os.path.isdir(save_path): + # If we were passed a directory, default to naming based on the Network + # name. + save_path = os.path.join(save_path, network.name.replace("/", "_")) + user_map_func = map_func + if map_func is None: + map_func = _make_prefix_stripping_map_fn(network.scope_name) + variable_map = {} + for variable in network.variables: + mapped_name = map_func(variable._shared_name) + if variable_map.setdefault(mapped_name, variable) is not variable: + if user_map_func is None: + # Instead of erroring out, we could just re-try and silently use the + # full variable names in the checkpoint. 
This could be odd for deeply + # nested sub-Networks (since the full prefix from the nesting would + # get added), so for now we'll let the user deal with this case. + raise ValueError(_default_naming_conflict_error_message( + mapped_name=mapped_name, + first_variable=variable_map[mapped_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) + else: + # The user passed their own problematic map_func. + raise ValueError( + ("The map_func passed to save_network_checkpoint for the Network " + "'%s' resulted in two variables named '%s' ('%s' and '%s'). Try " + "stripping less from the variable names, or renaming parts of " + "the Network. For reference, variables created by sub-Layers of " + "this Network are prefixed with '%s', but if they are re-used " + "after being added to another Network, they will have that " + "Network's full variable prefix instead.") % ( + network.name, mapped_name, + variable_map[mapped_name]._shared_name, + variable._shared_name, + network.scope_name)) + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + return saver_lib.Saver(variable_map).save( + sess=sess, save_path=save_path, write_meta_graph=False, + global_step=global_step) + + +def _add_deferred_restoration(layer, deferred_restoration): + """Add a deferred restoration to this Layer and all children. + + Restorations which are requested later have higher priority, and the highest + priority matching restoration is applied to a variable when it is created. + + Args: + layer: The Layer (may not be a Network) to operate on. + deferred_restoration: A _DeferredRestoration object. + """ + # Networks don't create variables at the moment, so this append isn't strictly + # necessary. We could get by with only adding deferred restorations to + # non-Network Layers. 
+ if isinstance(layer, Network): + layer._set_scope() + # Make sure this Layer has a deferred restoration queue and a custom getter, + # then add our request to it. + if not hasattr(layer, "_custom_getter"): + assert not hasattr(layer, "_deferred_restorations") + layer._custom_getter, layer._deferred_restorations = ( + _make_custom_getter_for_deferred_restorations()) + # We use set_custom_getter because it avoids recursively calling up the + # variable_scope tree. We've done the tree traversal ourselves and have added + # the request to each Layer which needs it. + layer._scope.set_custom_getter(layer._custom_getter) + layer._deferred_restorations.append(deferred_restoration) + if isinstance(layer, Network): + for sublayer in layer.layers: + if not isinstance(sublayer, Network): + layer._set_scope_for_nonnetwork_sublayer(sublayer) + _add_deferred_restoration(sublayer, deferred_restoration) + + +def _restore_existing_variables(network, save_path, map_func, user_map_func): + """Use a standard Saver to restore existing variables from a checkpoint. + + Args: + network: A Network object to restore. + save_path: The checkpoint prefix or directory to read from. + map_func: The function to use when mapping from variable names to + checkpoint names. + user_map_func: The original map_func passed by the user, for error + checking. + Returns: + A dictionary mapping from checkpoint names to variable objects which have + been restored (for bookkeeping to avoid deferred restorations on these + variables). + Raises: + ValueError: If there is a name collision. 
+ """ + existing_variables_by_checkpoint_name = {} + for variable in network.variables: + checkpoint_name = map_func(variable._shared_name) + if existing_variables_by_checkpoint_name.setdefault( + checkpoint_name, variable) is not variable: + if user_map_func is None: + raise ValueError(_default_naming_conflict_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) + else: + raise ValueError(_restore_custom_map_func_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) + if existing_variables_by_checkpoint_name: + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore( + sess=sess, save_path=save_path) + return existing_variables_by_checkpoint_name + + +def _set_restore_on_create(network, save_path, map_func, user_map_func, + existing_variables_by_checkpoint_name): + """If necessary, request deferred restorations of variables.""" + checkpoint_reader = checkpoint_utils.load_checkpoint(save_path) + checkpointed_variables_to_restore = {} + for checkpoint_name, _ in checkpoint_utils.list_variables(save_path): + if checkpoint_name in existing_variables_by_checkpoint_name: + # This variable was already created and restored. + continue + # Save the variable for later restoration in a custom getter. + checkpointed_variables_to_restore[checkpoint_name] = ( + checkpoint_reader.get_tensor(checkpoint_name)) + # Only set a deferred restoration if there are checkpoint variables which + # have not been assigned to existing variables. 
Note that this loses out on + # some opportunity for error checking, but avoids creating + # _DeferredRestoration objects once a Network has been built (so that + # restoring in a loop does not take increasing amounts of memory). + if checkpointed_variables_to_restore: + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + # We need a name for error messages. If we haven't been added to another + # Network yet, we're top-level. + network._finalize_name(False) + network._set_scope() + # Save a record of this restoration for use in the custom getter. + deferred_restoration = _DeferredRestoration( + map_func=map_func, + map_func_is_user=(user_map_func is not None), + checkpointed_variables_to_restore=checkpointed_variables_to_restore, + restored_variables={}, + session=sess, + network_name=network.name, + network_scope_name=network.scope_name) + # Add the deferred registration to non-Network children, and request that + # Networks propagate the request to their children. + _add_deferred_restoration(network, deferred_restoration) + + +def restore_network_checkpoint(network, save_path, map_func=None): + """Restore the Network from a checkpoint. + + If variables have already been created (typically when some or all of the + `Network` is built), they are assigned values from the checkpoint immediately, + overwriting any existing values (in graph mode the default session is used for + the assignments). + + If there are checkpoint entries which do not correspond to any existing + variables in the `Network`, these values are saved for deferred restoration; + their initial values will be the checkpointed values once they are + created. Requests for multiple deferred restorations behave the same way as + immediate restorations, in that later requests will take priority over earlier + requests relevant to the same variable. 
+ + If this `Network` shares `Layer`s with another network, those `Layer`s will + also have their variables restored from the checkpoint. + + Args: + network: A Network object to restore. + save_path: The return value of `tfe.save_network_checkpoint`, or a directory + to search for a checkpoint. + map_func: A function mapping fully qualified variable names + (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By + default (if `map_func=None`), the variable prefix for the network being + restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped + and all other variable names (shared with other Networks) are left + unchanged. Note that this is the _same_ map_func as + `tfe.save_network_checkpoint`, not an inverse mapping. + """ + network._finalize_name(parent_network=False) + network._set_scope() # scope_name should be available to map_funcs + if os.path.isdir(save_path): + # If we don't have a name yet, set no parent. + save_path = os.path.join(save_path, network.name.replace("/", "_")) + user_map_func = map_func + if map_func is None: + map_func = _make_prefix_stripping_map_fn(network.scope_name) + # Step one is to restore any existing variables from the checkpoint. + existing_variables_by_checkpoint_name = _restore_existing_variables( + network=network, + save_path=save_path, + map_func=map_func, + user_map_func=user_map_func) + # Step two is to set a custom getter which restores variables on creation, + # for those variables which have not been added to sub-Layers yet. 
+ _set_restore_on_create( + network=network, + save_path=save_path, + map_func=map_func, + user_map_func=user_map_func, + existing_variables_by_checkpoint_name=( + existing_variables_by_checkpoint_name)) diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 14adbafe5735bd2a3d3961402e8ef3e6a7be333b..3eb4f5f8b3954a7ed04d2ef1d4f119ad137e1e65 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -19,9 +19,13 @@ from __future__ import print_function import gc from tensorflow.contrib.eager.python import network +from tensorflow.contrib.layers.python.layers import regularizers +from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core from tensorflow.python.ops import math_ops @@ -42,12 +46,28 @@ class MyNetwork(network.Network): return self.l1(x) +class RegularizedNetwork(network.Network): + + def __init__(self): + super(RegularizedNetwork, self).__init__() + self.l1 = self.track_layer(core.Dense( + 1, + bias_regularizer=regularizers.l1_regularizer(2.0), + kernel_regularizer=regularizers.l1_regularizer(2.0))) + self.l2 = self.track_layer(core.Dense( + 1, + bias_regularizer=regularizers.l1_regularizer(2.0))) + + def call(self, values): + return self.l2(self.l1(values)) + + class NetworkTest(test.TestCase): def _save_modify_load_network_built(self, net, global_step=None): checkpoint_directory = self.get_temp_dir() - checkpoint_path = net.save( - save_path=checkpoint_directory, global_step=global_step) + checkpoint_path = network.save_network_checkpoint( + network=net, save_path=checkpoint_directory, global_step=global_step) input_value = 
constant_op.constant([[42.0]]) original_output = self.evaluate(net(input_value)) for var in net.variables: @@ -56,13 +76,13 @@ class NetworkTest(test.TestCase): self.evaluate(net(input_value)), original_output) # Either the returned explicit checkpoint path or the directory should work. - net.restore(save_path=checkpoint_directory) + network.restore_network_checkpoint(net, save_path=checkpoint_directory) self.assertAllEqual( original_output, self.evaluate(net(input_value))) for var in net.variables: self.evaluate(var.assign(var + 2.)) - net.restore(save_path=checkpoint_path) + network.restore_network_checkpoint(net, save_path=checkpoint_path) self.assertAllEqual( original_output, self.evaluate(net(input_value))) @@ -85,13 +105,30 @@ class NetworkTest(test.TestCase): result = net(constant_op.constant([[2.0]])) self.assertEqual(34.0, self.evaluate(result)) + # TODO(akshayka): This test should be changed once an API for compiling + # `call` into a defun is implemented. + def testReplacingNetworkCallWithDefun(self): + net = MyNetwork(name="abcd") + x = constant_op.constant([[2.0]]) + net(x) # Force variables to be created. 
+ self.evaluate(net.trainable_variables[0].assign([[17.0]])) + + net.call = function.defun(net.call) + result = net(x) # Build and execute the TensorFlow function + self.assertEqual(34.0, self.evaluate(result)) + + # Force the creation of another TensorFlow function by changing input shape + y = constant_op.constant([[1.0], [2.0]]) + result = net(y) + self.assertAllEqual([[17.0], [34.0]], self.evaluate(result)) + # TODO(allenl): This test creates garbage in some Python versions @test_util.run_in_graph_and_eager_modes() def testNetworkSaveRestoreAlreadyBuilt(self): net = MyNetwork(name="abcd") with self.assertRaisesRegexp( ValueError, "Attempt to save the Network before it was first called"): - net.save(self.get_temp_dir()) + network.save_network_checkpoint(net, self.get_temp_dir()) net(constant_op.constant([[2.0]])) self.evaluate(net.trainable_variables[0].assign([[17.0]])) self._save_modify_load_network_built(net, global_step=None) @@ -105,7 +142,7 @@ class NetworkTest(test.TestCase): self.evaluate(net.variables[0].assign([[3.]])) default_global_step = training_util.get_or_create_global_step() self.evaluate(default_global_step.assign(4242)) - save_path = net.save(self.get_temp_dir()) + save_path = network.save_network_checkpoint(net, self.get_temp_dir()) self.assertIn("abcd-4242", save_path) # TODO(allenl): This test creates garbage in some Python versions @@ -116,16 +153,43 @@ class NetworkTest(test.TestCase): test_input = constant_op.constant([[2.0]]) net1(test_input) self.evaluate(net1.trainable_variables[0].assign([[17.0]])) - save_path = net1.save(save_dir) + save_path = network.save_network_checkpoint(net1, save_dir) # With a pre-build restore we should have the same value. 
net2 = MyNetwork() - net2.restore(save_path) + network.restore_network_checkpoint(net2, save_path) self.assertAllEqual(self.evaluate(net1(test_input)), self.evaluate(net2(test_input))) self.assertIsNot(net1.variables[0], net2.variables[0]) self.assertAllEqual(self.evaluate(net1.variables[0]), self.evaluate(net2.variables[0])) + @test_util.run_in_graph_and_eager_modes() + def testNetworkMatchesLayerVariableNames(self): + zero = constant_op.constant([[0.]]) + layer_one = core.Dense(1, use_bias=False) + layer_one(zero) + layer_two = core.Dense(1, use_bias=False) + layer_two(zero) + + class TwoLayerNet(network.Network): + + def __init__(self, name=None): + super(TwoLayerNet, self).__init__(name=name) + self.first = self.track_layer(core.Dense( + 1, use_bias=False)) + self.second = self.track_layer(core.Dense( + 1, use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + net = TwoLayerNet() + net(zero) + self.assertEqual("two_layer_net/" + layer_one.variables[0].name, + net.first.variables[0].name) + self.assertEqual("two_layer_net/" + layer_two.variables[0].name, + net.second.variables[0].name) + @test_util.run_in_graph_and_eager_modes() def testLoadIntoUnbuiltSharedLayer(self): @@ -173,14 +237,15 @@ class NetworkTest(test.TestCase): # Re-map the variable names so that with default restore mapping we'll # attempt to restore into the unbuilt Layer. 
name_mapping = { - "checkpoint_creator/first_layer/kernel": "owner_1/first_layer/kernel", + "checkpoint_creator/first_layer/kernel": "owner/first_layer/kernel", "checkpoint_creator/second_layer/kernel": "second_layer/kernel", } - save_path = checkpoint_creator.save( + save_path = network.save_network_checkpoint( + checkpoint_creator, self.get_temp_dir(), map_func=lambda full_name: name_mapping[full_name]) load_into = User(use_layer=first_owner.first) - load_into.restore(save_path) + network.restore_network_checkpoint(load_into, save_path) self.assertEqual(0, len(first_owner.variables)) self.assertAllEqual(self.evaluate(checkpoint_creator(one)), self.evaluate(load_into(one))) @@ -196,12 +261,13 @@ class NetworkTest(test.TestCase): del first_owner gc.collect() def _restore_map_func(original_name): - if original_name.startswith("owner_1"): - return original_name.replace("owner_1", "owner_2") + if original_name.startswith("owner/"): + return original_name.replace("owner/", "owner_1/") else: - return "user_2/" + original_name + return "user_1/" + original_name with self.assertRaisesRegexp(ValueError, "garbage collected"): - load_into.restore(save_path, map_func=_restore_map_func) + network.restore_network_checkpoint( + load_into, save_path, map_func=_restore_map_func) @test_util.run_in_graph_and_eager_modes() def testRestoreIntoSubNetwork(self): @@ -221,17 +287,18 @@ class NetworkTest(test.TestCase): whole_model_saver(one) self.evaluate(whole_model_saver.variables[0].assign([[15.]])) self.evaluate(whole_model_saver.variables[1].assign([[16.]])) - whole_model_checkpoint = whole_model_saver.save(self.get_temp_dir()) + whole_model_checkpoint = network.save_network_checkpoint( + whole_model_saver, self.get_temp_dir()) save_from = MyNetwork() save_from(one) self.evaluate(save_from.variables[0].assign([[5.]])) - checkpoint = save_from.save(self.get_temp_dir()) + checkpoint = network.save_network_checkpoint(save_from, self.get_temp_dir()) save_into_parent = Parent() - 
save_into_parent.restore(whole_model_checkpoint) - save_into_parent.first.restore(checkpoint) - save_into_parent.first.restore(checkpoint) # deferred loading multiple - # times is fine + network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint) + network.restore_network_checkpoint(save_into_parent.first, checkpoint) + # deferred loading multiple times is fine + network.restore_network_checkpoint(save_into_parent.first, checkpoint) save_into_parent(one) # deferred loading self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0])) self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1])) @@ -240,9 +307,9 @@ class NetworkTest(test.TestCase): # (deferred restoration should happen the same way non-deferred happens, # with later restorations overwriting older ones). save_into_parent = Parent() - save_into_parent.first.restore(checkpoint) # deferred loading multiple - # times is fine - save_into_parent.restore(whole_model_checkpoint) + # deferred loading multiple times is fine + network.restore_network_checkpoint(save_into_parent.first, checkpoint) + network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint) save_into_parent(one) # deferred loading # We've overwritten the sub-Network restore. self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0])) @@ -250,12 +317,12 @@ class NetworkTest(test.TestCase): self.evaluate(save_into_parent.variables[0].assign([[3.]])) self.evaluate(save_into_parent.variables[1].assign([[4.]])) - save_into_parent.second.restore(checkpoint) + network.restore_network_checkpoint(save_into_parent.second, checkpoint) self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1])) with self.assertRaisesRegexp(errors_impl.NotFoundError, "not found in checkpoint"): # The checkpoint is incompatible. 
- save_into_parent.restore(checkpoint) + network.restore_network_checkpoint(save_into_parent, checkpoint) @test_util.run_in_graph_and_eager_modes() def testCustomMapCollisionErrors(self): @@ -277,31 +344,36 @@ class NetworkTest(test.TestCase): self.evaluate(make_checkpoint.variables[1].assign([[3.]])) with self.assertRaisesRegexp( ValueError, - "The map_func passed to Network.save for the Network 'parent_1' " - "resulted in two variables named 'foo'"): - make_checkpoint.save(self.get_temp_dir(), map_func=lambda n: "foo") - checkpoint = make_checkpoint.first.save( - self.get_temp_dir(), map_func=lambda n: "foo") + "The map_func passed to save_network_checkpoint for the Network " + "'parent' resulted in two variables named 'foo'"): + network.save_network_checkpoint( + make_checkpoint, self.get_temp_dir(), map_func=lambda n: "foo") + checkpoint = network.save_network_checkpoint( + network=make_checkpoint.first, + save_path=self.get_temp_dir(), + map_func=lambda n: "foo") loader = Parent() - loader.restore(checkpoint, map_func=lambda n: "foo") + network.restore_network_checkpoint( + loader, checkpoint, map_func=lambda n: "foo") with self.assertRaisesRegexp( ValueError, - ("The map_func passed to Network.restore for the Network" - " 'parent_2' resulted in two variables named 'foo'")): + ("The map_func passed to restore_network_checkpoint for the Network" + " 'parent_1' resulted in two variables named 'foo'")): loader(one) loader = Parent() loader(one) with self.assertRaisesRegexp( ValueError, - ("The map_func passed to Network.restore for the Network" - " 'parent_3' resulted in two variables named 'foo'")): - loader.restore(checkpoint, map_func=lambda n: "foo") + ("The map_func passed to restore_network_checkpoint for the Network" + " 'parent_2' resulted in two variables named 'foo'")): + network.restore_network_checkpoint( + loader, checkpoint, map_func=lambda n: "foo") @test_util.run_in_graph_and_eager_modes() def testDefaultMapCollisionErrors(self): one = 
constant_op.constant([[1.]]) - first = core.Dense(1, name="dense_1", use_bias=False) + first = core.Dense(1, name="dense", use_bias=False) first(one) class Parent(network.Network): @@ -322,8 +394,8 @@ class NetworkTest(test.TestCase): with self.assertRaisesRegexp( ValueError, ("The default checkpoint variable name mapping strategy for Network " - "'parent_1' resulted in a naming conflict.")): - make_checkpoint.save(self.get_temp_dir()) + "'parent' resulted in a naming conflict.")): + network.save_network_checkpoint(make_checkpoint, self.get_temp_dir()) class Compatible(network.Network): @@ -337,14 +409,15 @@ class NetworkTest(test.TestCase): successful_checkpoint = Compatible() successful_checkpoint(one) self.evaluate(successful_checkpoint.variables[0].assign([[-1.]])) - checkpoint_path = successful_checkpoint.save(self.get_temp_dir()) + checkpoint_path = network.save_network_checkpoint( + successful_checkpoint, self.get_temp_dir()) load_checkpoint = Parent() load_checkpoint(one) with self.assertRaisesRegexp( ValueError, ("The default checkpoint variable name mapping strategy for Network " - "'parent_2' resulted in a naming conflict.")): - load_checkpoint.restore(checkpoint_path) + "'parent_1' resulted in a naming conflict.")): + network.restore_network_checkpoint(load_checkpoint, checkpoint_path) def testNoReferenceCyclesAfterCall(self): @@ -398,6 +471,48 @@ class NetworkTest(test.TestCase): self.assertIsInstance(net.trainable_weights[0], resource_variable_ops.ResourceVariable) + def testGraphOpNames(self): + """Network operation names should match variable naming.""" + + def _check_op_prefixes(expected_prefix, checked_ops): + for operation in ops.get_default_graph().get_operations(): + if operation.name == "ignore": + continue + if operation.name in checked_ops: + continue + checked_ops.add(operation.name) + self.assertStartsWith(expected_start=expected_prefix, + actual=operation.name) + self.assertNotIn("my_network", operation.name[len(expected_prefix):]) + 
self.assertNotIn("dense", operation.name[len(expected_prefix):]) + + with context.graph_mode(): + net = MyNetwork() + zero = constant_op.constant([[0.]], name="ignore") + net(zero) + checked_ops = set() + _check_op_prefixes(expected_prefix="my_network/dense/", + checked_ops=checked_ops) + net.net2 = net.track_layer(MyNetwork()) + net.net2(zero) + _check_op_prefixes(expected_prefix="my_network/my_network/dense/", + checked_ops=checked_ops) + MyNetwork()(zero) + _check_op_prefixes(expected_prefix="my_network_1/dense/", + checked_ops=checked_ops) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testVariableRegularizers(self): + net = RegularizedNetwork() + net(constant_op.constant([[1.]])) + self.evaluate(net.variables[0].assign([[2.]])) + self.evaluate(net.variables[1].assign([3.])) + self.evaluate(net.variables[2].assign([[-2.]])) + self.evaluate(net.variables[3].assign([4.])) + self.assertAllEqual([4., 6., 8.], self.evaluate(net.losses)) + self.evaluate(net.variables[3].assign([5.])) + self.assertAllEqual([4., 6., 10.], self.evaluate(net.losses)) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testDuplicateNameError(self): one = constant_op.constant([[1.]]) @@ -410,19 +525,103 @@ class NetworkTest(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testWrappingInVariableScope(self): + one = constant_op.constant([[1.]]) + # Naming happens in the order of first build rather than the order of + # construction, but for clarity they're the same here and construction is + # annotated. 
+ outside_net_before = MyNetwork() # name=my_network + outside_net_before(one) + captured_scope = variable_scope.get_variable_scope() with variable_scope.variable_scope("outside_scope"): - net = MyNetwork() - one = constant_op.constant([[1.]]) - with self.assertRaisesRegexp( - ValueError, - ("Creating Networks inside named variable_scopes is currently not " - "supported")): - net(one) - # Alternatively, we could re-name the Network to match the variable_scope: - # self.assertEqual("outside_scope/my_network_1", net.name) - # self.assertStartsWith( - # expected_start="outside_scope/my_network_1/dense/", - # actual=net.trainable_weights[0].name) + net1 = MyNetwork() # name=outside_scope/my_network + net1(one) + name_conflict1 = MyNetwork(name="name_conflict") # fine, unique so far + name_conflict2 = MyNetwork(name="name_conflict") # error on build + with variable_scope.variable_scope("inside_scope"): + # No issue here since the name is unique within its scope. + name_conflict3 = MyNetwork(name="name_conflict") + net2 = MyNetwork() # name=outside_scope/my_network_2 to avoid the + # variable_scope my_network_1 below. 
+ vs_name_conflict = MyNetwork(name="vs_name_conflict") # conflict below + with variable_scope.variable_scope("intervening_scope"): + with variable_scope.variable_scope(captured_scope): + with variable_scope.variable_scope("outside_scope"): + name_conflict4 = MyNetwork(name="name_conflict") # error on build + with variable_scope.variable_scope("my_network_1"): + pass + with variable_scope.variable_scope("vs_name_conflict"): + pass + net3 = MyNetwork() # name=outside_scope/my_network_4 + name_conflict1(one) + with self.assertRaisesRegexp( + ValueError, "named 'name_conflict' already exists"): + name_conflict2(one) + name_conflict3(one) + net2(one) + with self.assertRaisesRegexp( + ValueError, "or a variable_scope was created with this name"): + vs_name_conflict(one) + with self.assertRaisesRegexp( + ValueError, "named 'name_conflict' already exists"): + name_conflict4(one) + self.assertEqual("outside_scope/name_conflict", + name_conflict1.name) + self.assertStartsWith( + expected_start="outside_scope/name_conflict/dense/", + actual=name_conflict1.variables[0].name) + self.assertEqual("outside_scope/inside_scope/name_conflict", + name_conflict3.name) + self.assertStartsWith( + expected_start="outside_scope/inside_scope/name_conflict/dense/", + actual=name_conflict3.variables[0].name) + self.assertEqual("outside_scope/my_network", net1.name) + self.assertStartsWith( + expected_start="outside_scope/my_network/dense/", + actual=net1.trainable_weights[0].name) + self.assertEqual("outside_scope/my_network_2", net2.name) + self.assertStartsWith( + expected_start="outside_scope/my_network_2/dense/", + actual=net2.trainable_weights[0].name) + net3(one) + self.assertEqual("outside_scope/my_network_3", net3.name) + self.assertStartsWith( + expected_start="outside_scope/my_network_3/dense/", + actual=net3.trainable_weights[0].name) + outside_net_after = MyNetwork() + outside_net_after(one) + self.assertEqual("my_network", outside_net_before.name) + self.assertStartsWith( + 
expected_start="my_network/dense/", + actual=outside_net_before.trainable_weights[0].name) + self.assertEqual("my_network_1", outside_net_after.name) + self.assertStartsWith( + expected_start="my_network_1/dense/", + actual=outside_net_after.trainable_weights[0].name) + + @test_util.run_in_graph_and_eager_modes() + def testVariableScopeStripping(self): + with variable_scope.variable_scope("scope1"): + with variable_scope.variable_scope("scope2"): + net = MyNetwork() + net(constant_op.constant([[2.0]])) + self.evaluate(net.variables[0].assign([[42.]])) + self.assertEqual(net.name, "scope1/scope2/my_network") + self.assertStartsWith( + expected_start="scope1/scope2/my_network/dense/", + actual=net.trainable_weights[0].name) + save_path = network.save_network_checkpoint(net, self.get_temp_dir()) + self.assertIn("scope1_scope2_my_network", save_path) + restore_net = MyNetwork() + # Delayed restoration + network.restore_network_checkpoint(restore_net, save_path) + restore_net(constant_op.constant([[1.0]])) + self.assertAllEqual([[42.]], + self.evaluate(restore_net.variables[0])) + self.evaluate(restore_net.variables[0].assign([[-1.]])) + # Immediate restoration + network.restore_network_checkpoint(restore_net, save_path) + self.assertAllEqual([[42.]], + self.evaluate(restore_net.variables[0])) @test_util.run_in_graph_and_eager_modes() def testLayerNamesRespected(self): @@ -439,7 +638,7 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net = ParentNetwork() net(one) - self.assertStartsWith(expected_start="parent_network_1/explicit_name/", + self.assertStartsWith(expected_start="parent_network/explicit_name/", actual=net.trainable_weights[0].name) self.assertEqual("explicit_name", net.first.name) @@ -494,15 +693,15 @@ class NetworkTest(test.TestCase): # locally so that previous Layer consutrciton does not interfere with # variable naming (e.g. add a Layer construction before the Network, # suddenly your previously saved checkpoint is incompatible). 
- self.assertEqual("dense_1", net1.l1.name) - self.assertEqual("dense_1", net2.l1.name) + self.assertEqual("dense", net1.l1.name) + self.assertEqual("dense", net2.l1.name) self.evaluate(net1.trainable_weights[0].assign([[1.]])) self.evaluate(net2.trainable_weights[0].assign([[2.]])) self.assertEqual(2., self.evaluate(net2.trainable_weights[0])) self.assertEqual(1., self.evaluate(net1.trainable_weights[0])) - self.assertStartsWith(expected_start="my_network_1/dense_1/", + self.assertStartsWith(expected_start="my_network/dense/", actual=net1.trainable_weights[0].name) - self.assertStartsWith(expected_start="my_network_2/dense_1/", + self.assertStartsWith(expected_start="my_network_1/dense/", actual=net2.trainable_weights[0].name) @test_util.run_in_graph_and_eager_modes() @@ -523,31 +722,31 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net = ParentNetwork() net(one) - self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network/my_network/dense", actual=net.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network/my_network/dense", actual=net.first.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network/my_network_1/dense", actual=net.trainable_weights[1].name) - self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network/my_network_1/dense", actual=net.second.trainable_weights[0].name) - self.assertEqual("parent_network_1", net.name) - self.assertEqual("my_network_1", net.first.name) - self.assertEqual("my_network_2", net.second.name) + self.assertEqual("parent_network", net.name) + self.assertEqual("my_network", net.first.name) + self.assertEqual("my_network_1", net.second.name) net2 = 
ParentNetwork() net2(one) - self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network/dense", actual=net2.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network/dense", actual=net2.first.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", actual=net2.trainable_weights[1].name) - self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", actual=net2.second.trainable_weights[0].name) - self.assertEqual("parent_network_2", net2.name) - self.assertEqual("my_network_1", net2.first.name) - self.assertEqual("my_network_2", net2.second.name) + self.assertEqual("parent_network_1", net2.name) + self.assertEqual("my_network", net2.first.name) + self.assertEqual("my_network_1", net2.second.name) @test_util.run_in_graph_and_eager_modes() def testNestableExplicit(self): @@ -608,26 +807,26 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net = MixedLayerNetwork() net(one) - self.assertEqual("dense_1", net.first.name) - self.assertEqual("dense_2", net.second.name) - self.assertEqual("dense_3", net.third.name) - self.assertEqual("dense_4", net.fourth.name) - self.assertEqual("dense_5", net.fifth.name) + self.assertEqual("dense", net.first.name) + self.assertEqual("dense_1", net.second.name) + self.assertEqual("dense_2", net.third.name) + self.assertEqual("dense_3", net.fourth.name) + self.assertEqual("dense_4", net.fifth.name) # Note that this is _not_ the default naming behavior for Layers. Layers # which are added to Networks follow Network variable naming conventions # (i.e. variable names = network name unless variable sharing). 
Nested # Layers revert to Layer behavior. - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_1/", + self.assertStartsWith(expected_start="mixed_layer_network/dense/", actual=net.trainable_weights[0].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_2/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_1/", actual=net.trainable_weights[1].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_3/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_2/", actual=net.trainable_weights[2].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_4/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_3/", actual=net.trainable_weights[3].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_5/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_4/", actual=net.trainable_weights[4].name) - self.assertEqual("mixed_layer_network_1", net.name) + self.assertEqual("mixed_layer_network", net.name) @test_util.run_in_graph_and_eager_modes() def testNestableExplicitCollisions(self): @@ -680,24 +879,24 @@ class NetworkTest(test.TestCase): net = ParentNetwork() net(one) self.assertStartsWith( - expected_start="parent_network_1/first_unique_child_name/dense_1/", + expected_start="parent_network/first_unique_child_name/dense/", actual=net.trainable_weights[0].name) self.assertStartsWith( - expected_start="parent_network_1/second_unique_child_name/dense_1/", + expected_start="parent_network/second_unique_child_name/dense/", actual=net.trainable_weights[1].name) - self.assertEqual("parent_network_1", net.name) + self.assertEqual("parent_network", net.name) self.assertEqual("first_unique_child_name", net.first.name) self.assertEqual("second_unique_child_name", net.second.name) net2 = ParentNetwork() net2(one) self.assertStartsWith( - expected_start="parent_network_2/first_unique_child_name/dense", + 
expected_start="parent_network_1/first_unique_child_name/dense", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="parent_network_2/second_unique_child_name/dense", + expected_start="parent_network_1/second_unique_child_name/dense", actual=net2.trainable_weights[1].name) - self.assertEqual("parent_network_2", net2.name) + self.assertEqual("parent_network_1", net2.name) self.assertEqual("first_unique_child_name", net2.first.name) self.assertEqual("second_unique_child_name", net2.second.name) @@ -755,15 +954,15 @@ class NetworkTest(test.TestCase): net2(one) self.assertStartsWith( - expected_start="first_parent_network_1/my_network_1/dense_1/", + expected_start="first_parent_network/my_network/dense/", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="second_parent_network_1/my_network_1/dense_1/", + expected_start="second_parent_network/my_network/dense/", actual=net2.trainable_weights[1].name) - self.assertEqual("second_parent_network_1", net2.name) + self.assertEqual("second_parent_network", net2.name) self.assertTrue(net2.first is net.first) - self.assertEqual("my_network_1", net2.first.name) - self.assertEqual("my_network_1", net2.second.name) + self.assertEqual("my_network", net2.first.name) + self.assertEqual("my_network", net2.second.name) # No name collision; the owned Network is added first and has a different # name than the shared Network. 
@@ -781,15 +980,15 @@ class NetworkTest(test.TestCase): net3(one) self.assertStartsWith( - expected_start="third_parent_network_1/my_network_1/dense", + expected_start="third_parent_network/my_network/dense", actual=net3.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_parent_network_1/my_network_2/dense", + expected_start="first_parent_network/my_network_1/dense", actual=net3.trainable_weights[1].name) - self.assertEqual("third_parent_network_1", net3.name) + self.assertEqual("third_parent_network", net3.name) self.assertTrue(net3.second is net.second) - self.assertEqual("my_network_1", net3.first.name) - self.assertEqual("my_network_2", net3.second.name) + self.assertEqual("my_network", net3.first.name) + self.assertEqual("my_network_1", net3.second.name) # "Unavoidable" same-name Layer. The owned name is added first (fixed), then # a shared Network is added with the same name. @@ -807,15 +1006,15 @@ class NetworkTest(test.TestCase): net4(one) self.assertStartsWith( - expected_start="fourth_parent_network_1/my_network_1/dense_1/", + expected_start="fourth_parent_network/my_network/dense/", actual=net4.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_parent_network_1/my_network_1/dense_1/", + expected_start="first_parent_network/my_network/dense/", actual=net4.trainable_weights[1].name) - self.assertEqual("fourth_parent_network_1", net4.name) + self.assertEqual("fourth_parent_network", net4.name) self.assertTrue(net4.second is net.first) - self.assertEqual("my_network_1", net4.first.name) - self.assertEqual("my_network_1", net4.second.name) + self.assertEqual("my_network", net4.first.name) + self.assertEqual("my_network", net4.second.name) @test_util.run_in_graph_and_eager_modes() def testRecursiveLayerRenaming(self): @@ -846,28 +1045,28 @@ class NetworkTest(test.TestCase): net(one) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_1/" - "dense_1/"), + 
expected_start=("parent_network/network_with_layer_children/" + "dense/"), actual=net.trainable_weights[0].name) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_1/" - "dense_2/"), + expected_start=("parent_network/network_with_layer_children/" + "dense_1/"), actual=net.trainable_weights[1].name) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_2/" - "dense_1/"), + expected_start=("parent_network/network_with_layer_children_1/" + "dense/"), actual=net.trainable_weights[2].name) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_2/" - "dense_2/"), + expected_start=("parent_network/network_with_layer_children_1/" + "dense_1/"), actual=net.trainable_weights[3].name) - self.assertEqual("parent_network_1", net.name) - self.assertEqual("network_with_layer_children_1", net.first.name) - self.assertEqual("network_with_layer_children_2", net.second.name) - self.assertEqual("dense_1", net.first.first.name) - self.assertEqual("dense_2", net.first.second.name) - self.assertEqual("dense_1", net.second.first.name) - self.assertEqual("dense_2", net.second.second.name) + self.assertEqual("parent_network", net.name) + self.assertEqual("network_with_layer_children", net.first.name) + self.assertEqual("network_with_layer_children_1", net.second.name) + self.assertEqual("dense", net.first.first.name) + self.assertEqual("dense_1", net.first.second.name) + self.assertEqual("dense", net.second.first.name) + self.assertEqual("dense_1", net.second.second.name) @test_util.run_in_graph_and_eager_modes() def testCallInDifferentOrderThanConstruct(self): @@ -901,23 +1100,23 @@ class NetworkTest(test.TestCase): net1(one) self.assertStartsWith( - expected_start="first_network_1/my_network_1/dense_1/", + expected_start="first_network/my_network/dense/", actual=net1.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_network_1/my_network_2/dense_1/", + 
expected_start="first_network/my_network_1/dense/", actual=net1.trainable_weights[1].name) self.assertStartsWith( - expected_start="first_network_1/my_network_1/dense_1/", + expected_start="first_network/my_network/dense/", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="second_network_1/my_network_1/dense_1/", + expected_start="second_network/my_network/dense/", actual=net2.trainable_weights[1].name) self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0]) - self.assertEqual("first_network_1", net1.name) - self.assertEqual("my_network_1", net1.first.name) - self.assertEqual("my_network_2", net1.second.name) + self.assertEqual("first_network", net1.name) + self.assertEqual("my_network", net1.first.name) + self.assertEqual("my_network_1", net1.second.name) self.assertTrue(net2.first is net1.first) - self.assertEqual("my_network_1", net2.second.name) + self.assertEqual("my_network", net2.second.name) @test_util.run_in_graph_and_eager_modes() def testLayerCallInDifferentOrderThanConstruct(self): @@ -954,23 +1153,23 @@ class NetworkTest(test.TestCase): net1(one) self.assertStartsWith( - expected_start="first_network_1/dense_1/", + expected_start="first_network/dense/", actual=net1.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_network_1/dense_2/", + expected_start="first_network/dense_1/", actual=net1.trainable_weights[1].name) self.assertStartsWith( - expected_start="first_network_1/dense_1/", + expected_start="first_network/dense/", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="second_network_1/dense_1/", + expected_start="second_network/dense/", actual=net2.trainable_weights[1].name) self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0]) - self.assertEqual("first_network_1", net1.name) - self.assertEqual("dense_1", net1.first.name) - self.assertEqual("dense_2", net1.second.name) + self.assertEqual("first_network", net1.name) + 
self.assertEqual("dense", net1.first.name) + self.assertEqual("dense_1", net1.second.name) self.assertTrue(net2.first is net1.first) - self.assertEqual("dense_1", net2.second.name) + self.assertEqual("dense", net2.second.name) @test_util.run_in_graph_and_eager_modes() def testLayerAlreadyBuilt(self): @@ -999,13 +1198,13 @@ class NetworkTest(test.TestCase): # do not match their layer names. actual=net.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_network_1/dense_1/", + expected_start="first_network/dense/", actual=net.trainable_weights[1].name) self.assertTrue( net.trainable_weights[0] is shared_layer.trainable_weights[0]) - self.assertEqual("first_network_1", net.name) + self.assertEqual("first_network", net.name) self.assertEqual("dense_3", net.first.name) - self.assertEqual("dense_1", net.second.name) + self.assertEqual("dense", net.second.name) class SequentialTest(test.TestCase): diff --git a/tensorflow/contrib/eager/python/summary_writer.py b/tensorflow/contrib/eager/python/summary_writer.py deleted file mode 100644 index 5d8c41b545b3c9fd03af85f302ba05a394f085a4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/summary_writer.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""TensorBoard Summary Writer for TensorFlow Eager Execution.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import uuid - -from tensorflow.contrib.summary import gen_summary_ops -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_op_util -from tensorflow.python.ops import variable_scope - - -def _maybe_cpu(v): - if isinstance(v, (ops.EagerTensor, ops.Tensor)): - return v.cpu() - else: - return v - - -def _summary_writer_function(name, tensor, function, family=None): - def record(): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - function(tag, scope) - return True - return record - - -class SummaryWriter(object): - """Writes summaries for TensorBoard, compatible with eager execution. - - This class is the supported way of writing TensorBoard summaries under - eager execution. - """ - - _CPU_DEVICE = "cpu:0" - - def __init__(self, - logdir, - max_queue=10, - flush_secs=120, - filename_suffix=""): - """Summary writer for TensorBoard, compatible with eager execution. - - If necessary, multiple instances of `SummaryWriter` can be created, with - distinct `logdir`s and `name`s. Each `SummaryWriter` instance will retain - its independent `global_step` counter and data writing destination. - - Example: - ```python - writer = tfe.SummaryWriter("my_model") - - # ... Code that sets up the model and data batches ... 
- - for _ in xrange(train_iters): - loss = model.train_batch(batch) - writer.scalar("loss", loss) - writer.step() - ``` - - Args: - logdir: Directory in which summary files will be written. - max_queue: Number of summary items to buffer before flushing to - filesystem. If 0, summaries will be flushed immediately. - flush_secs: Number of secondsbetween forced commits to disk. - filename_suffix: Suffix of the event protobuf files in which the summary - data are stored. - - Raises: - ValueError: If this constructor is called not under eager execution. - """ - # TODO(apassos, ashankar): Make this class and the underlying - # contrib.summary_ops compatible with graph model and remove this check. - if not context.in_eager_mode(): - raise ValueError( - "Use of SummaryWriter is currently supported only with eager " - "execution enabled. File an issue at " - "https://github.com/tensorflow/tensorflow/issues/new to express " - "interest in fixing this.") - - # TODO(cais): Consider adding name keyword argument, which if None or empty, - # will register the global global_step that training_util.get_global_step() - # can find. 
- with context.device(self._CPU_DEVICE): - self._name = uuid.uuid4().hex - self._global_step = 0 - self._global_step_tensor = variable_scope.get_variable( - "global_step/summary_writer/" + self._name, - shape=[], dtype=dtypes.int64, - initializer=init_ops.zeros_initializer()) - self._global_step_dirty = False - self._resource = gen_summary_ops.summary_writer(shared_name=self._name) - gen_summary_ops.create_summary_file_writer( - self._resource, logdir, max_queue, flush_secs, filename_suffix) - # Delete the resource when this object is deleted - self._resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._resource, handle_device=self._CPU_DEVICE) - - def step(self): - """Increment the global step counter of this SummaryWriter instance.""" - self._global_step += 1 - self._global_step_dirty = True - - @property - def global_step(self): - """Obtain the current global_step value of this SummaryWriter instance. - - Returns: - An `int` representing the current value of the global_step of this - `SummaryWriter` instance. - """ - return self._global_step - - def _update_global_step_tensor(self): - with context.device(self._CPU_DEVICE): - if self._global_step_dirty: - self._global_step_dirty = False - return state_ops.assign(self._global_step_tensor, self._global_step) - else: - return self._global_step_tensor - - def generic(self, name, tensor, metadata, family=None): - """Write a generic-type summary. - - Args: - name: A name for the generated node. Will also serve as the series name in - TensorBoard. - tensor: A `Tensor` or compatible value type containing the value of the - summary. - metadata: Metadata about the summary. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. 
- """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_summary( - self._resource, - self._update_global_step_tensor(), - _maybe_cpu(tensor), - tag, - _maybe_cpu(metadata), - name=scope) - - def scalar(self, name, tensor, family=None): - """Write a scalar summary. - - Args: - name: A name for the generated node. Will also serve as the series name in - TensorBoard. - tensor: A real numeric `Tensor` or compatible value type containing a - single value. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - - Returns: - A summary writer function for scalars. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_scalar_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), name=scope) - - def histogram(self, name, tensor, family=None): - """Write a histogram summary. - - Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - tensor: A real numeric `Tensor` or compatible value type. Any shape. - Values to use to build the histogram. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. 
- """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_histogram_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), name=scope) - - def image(self, name, tensor, bad_color=None, max_images=3, family=None): - """Write an image summary.""" - with context.device(self._CPU_DEVICE): - if bad_color is None: - bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_image_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), bad_color_, max_images, - name=scope) - - def audio(self, name, tensor, sample_rate, max_outputs, family=None): - """Write an audio summary. - - Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - tensor: A 3-D `float32` `Tensor` of shape `[batch_size, frames, channels]` - or a 2-D `float32` `Tensor` of shape `[batch_size, frames]`, or - compatible value type. - sample_rate: A Scalar `float32` `Tensor` indicating the sample rate of the - signal in hertz. - max_outputs: Max number of batch elements to generate audio for. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. 
- """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_audio_summary( - self._resource, self._update_global_step_tensor(), - tag, - _maybe_cpu(tensor), - sample_rate=_maybe_cpu(sample_rate), - max_outputs=max_outputs, - name=scope) diff --git a/tensorflow/contrib/eager/python/summary_writer_test.py b/tensorflow/contrib/eager/python/summary_writer_test.py deleted file mode 100644 index 5ebb36d04fcba8f4558fa1c09716314af42f559f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/summary_writer_test.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Unit tests for eager execution SummaryWriter.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import shutil -import tempfile - -import numpy as np - -from tensorflow.contrib.eager.python import summary_writer -from tensorflow.core.util import event_pb2 -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.lib.io import tf_record -from tensorflow.python.platform import gfile - - -class SummaryWriterTest(test.TestCase): - - def setUp(self): - super(SummaryWriterTest, self).setUp() - self._test_device = "gpu:0" if context.num_gpus() else "cpu:0" - self._tmp_logdir = tempfile.mkdtemp() - with context.device(self._test_device): - # Use max_queue=0 so that summaries are immediately flushed to filesystem, - # making testing easier. 
- self._writer = summary_writer.SummaryWriter(self._tmp_logdir, max_queue=0) - - def tearDown(self): - if os.path.isdir(self._tmp_logdir): - shutil.rmtree(self._tmp_logdir) - super(SummaryWriterTest, self).tearDown() - - def _readLastEvent(self, logdir=None): - if not logdir: - logdir = self._tmp_logdir - files = [f for f in gfile.ListDirectory(logdir) - if not gfile.IsDirectory(os.path.join(logdir, f))] - file_path = os.path.join(logdir, files[0]) - records = list(tf_record.tf_record_iterator(file_path)) - event = event_pb2.Event() - event.ParseFromString(records[-1]) - return event - - def testGlobalStep(self): - with context.device(self._test_device): - orig_step = self._writer.global_step - self._writer.step() - self.assertEqual(orig_step + 1, self._writer.global_step) - self.assertEqual(orig_step + 1, self._writer.global_step) - self._writer.step() - self._writer.step() - self.assertEqual(orig_step + 3, self._writer.global_step) - - def testGenericSummary(self): - with context.device(self._test_device): - x = constant_op.constant(1337.0) - with context.device("cpu:0"): - metadata = constant_op.constant("foo") - self._writer.generic("x", x, metadata) - event = self._readLastEvent() - self.assertEqual("x", event.summary.value[0].tag) - - def testScalarSummary(self): - with context.device(self._test_device): - x = constant_op.constant(1337.0) - self._writer.scalar("x", x) - event = self._readLastEvent() - self.assertTrue("x", event.summary.value[0].tag) - self.assertEqual(1337.0, event.summary.value[0].simple_value) - - def testHistogramSummary(self): - with context.device(self._test_device): - y = constant_op.constant([1.0, 3.0, 3.0, 7.0]) - self._writer.histogram("y", y) - event = self._readLastEvent() - self.assertEqual("y", event.summary.value[0].tag) - self.assertTrue(event.summary.value[0].histo) - - def testImageSummary(self): - with context.device(self._test_device): - a = constant_op.constant([[10.0, 20.0], [-20.0, -10.0]]) - 
self._writer.histogram("image1", a) - event = self._readLastEvent() - self.assertEqual("image1", event.summary.value[0].tag) - self.assertTrue(event.summary.value[0].image) - - def testAudioSummary(self): - with context.device(self._test_device): - w = constant_op.constant(np.random.rand(3, 10, 2), dtype=dtypes.float32) - fs = constant_op.constant(44100.0, dtype=dtypes.float32) - max_outputs = 1 - self._writer.audio("audio1", w, fs, max_outputs) - event = self._readLastEvent() - self.assertTrue(event.summary.value[0].audio) - - def testTwoSummaryWritersGlobalStepsWorkWithoutCrosstalk(self): - tmp_logdir2 = os.path.join(self._tmp_logdir, "_writer2_") - writer2 = summary_writer.SummaryWriter(tmp_logdir2, max_queue=0) - - self.assertEqual(0, writer2.global_step) - self._writer.step() - self.assertEqual(0, writer2.global_step) - writer2.step() - writer2.step() - writer2.step() - self.assertEqual(3, writer2.global_step) - - x = constant_op.constant(1337.0) - writer_orig_step = self._writer.global_step - self._writer.step() - self._writer.scalar("x", x) - - event = self._readLastEvent() - self.assertEqual(writer_orig_step + 1, event.step) - - writer2.scalar("x", x) - event = self._readLastEvent(tmp_logdir2) - self.assertEqual(3, event.step) - - self._writer.step() - self._writer.scalar("x", x) - - event = self._readLastEvent() - self.assertEqual(writer_orig_step + 2, event.step) - - -# TODO(cais): Add performance benchmark for SummaryWriter. - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index b6c687c82946ec62ccb90165791587dc335f13c7..770a7e3e7a01f3351c229b7fb53383240dd1f1c8 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -23,6 +23,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. 
@@list_devices @@num_gpus +@@py_func @@defun @@implicit_gradients @@implicit_value_and_gradients @@ -30,9 +31,6 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@value_and_gradients_function @@GradientTape -@@enable_tracing -@@flush_trace - @@run @@enable_eager_execution @@ -46,13 +44,16 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@seterr @@Iterator -@@Network @@Saver @@restore_variables_on_create @@Variable @@get_optimizer_variables @@EagerVariableStore +@@Network +@@save_network_checkpoint +@@restore_network_checkpoint + @@in_eager_mode @@in_graph_mode @@ -74,6 +75,8 @@ from __future__ import print_function from tensorflow.contrib.eager.python import metrics from tensorflow.contrib.eager.python.datasets import Iterator from tensorflow.contrib.eager.python.network import Network +from tensorflow.contrib.eager.python.network import save_network_checkpoint +from tensorflow.contrib.eager.python.network import restore_network_checkpoint from tensorflow.contrib.eager.python.saver import get_optimizer_variables from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver @@ -86,7 +89,6 @@ from tensorflow.python.eager.context import in_eager_mode from tensorflow.python.eager.context import in_graph_mode from tensorflow.python.eager.context import list_devices from tensorflow.python.eager.context import num_gpus -from tensorflow.python.eager.core import enable_tracing from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks @@ -100,8 +102,10 @@ from tensorflow.python.framework.test_util import IsolateTest from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.ops.resource_variable_ops import 
ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore +from tensorflow.python.ops import script_ops from tensorflow.python.util.all_util import remove_undocumented +py_func = script_ops.eager_py_func defun = function.defun implicit_gradients = backprop.implicit_grad implicit_value_and_gradients = backprop.implicit_val_and_grad diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 6eb2cfdaca7840c4a5dd8cffc9620aaf3f96a1de..ba272d7e885434eb556cbafd3d9e64a50d21f9b2 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -27,8 +27,10 @@ py_library( srcs_version = "PY2AND3", deps = [ ":dnn", + ":dnn_linear_combined", ":extenders", ":head", + ":linear", ":logit_fns", ":multi_head", ":replicate_model_fn", @@ -73,6 +75,46 @@ py_test( ], ) +py_library( + name = "dnn_linear_combined", + srcs = ["python/estimator/dnn_linear_combined.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:nn", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:dnn_linear_combined", + ], +) + +py_test( + name = "dnn_linear_combined_test", + size = "medium", + srcs = ["python/estimator/dnn_linear_combined_test.py"], + shard_count = 3, + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":dnn_linear_combined", + ":head", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:nn", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python/estimator:dnn_testing_utils", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:linear_testing_utils", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "extenders", srcs = [ @@ -169,6 +211,42 @@ py_test( ], ) +py_library( 
+ name = "linear", + srcs = ["python/estimator/linear.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:linear", + ], +) + +py_test( + name = "linear_test", + size = "small", + srcs = ["python/estimator/linear_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":head", + ":linear", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:linear_testing_utils", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "logit_fns", srcs = [ @@ -204,10 +282,14 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:summary", "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/saved_model:signature_constants", "@six_archive//:six", @@ -229,7 +311,7 @@ py_test( "//tensorflow/python:string_ops", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/ops/losses", + "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/saved_model:signature_constants", "//third_party/py/numpy", "@six_archive//:six", @@ -265,7 +347,7 @@ py_library( cuda_py_test( name = "replicate_model_fn_test", - size = "small", + size = "medium", srcs = ["python/estimator/replicate_model_fn_test.py"], additional_deps = [ "//tensorflow/python/estimator", @@ -293,5 +375,5 @@ cuda_py_test( "//tensorflow/python:variables", 
":replicate_model_fn", ], - tags = ["requires-gpu-sm35"], + tags = ["multi_gpu"], ) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index cf727264cd5116915f6bd7f285e470cbc2e2742a..28c1f8b1809d27db697365b7bb50441f7820d2b4 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -20,10 +20,13 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import from tensorflow.contrib.estimator.python.estimator.dnn import * +from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * +from tensorflow.contrib.estimator.python.estimator.linear import * from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * +from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import @@ -38,9 +41,12 @@ _allowed_symbols = [ 'multi_label_head', 'regression_head', 'DNNEstimator', + 'DNNLinearCombinedEstimator', + 'LinearEstimator', 'call_logit_fn', 'dnn_logit_fn_builder', 'linear_logit_fn_builder', + 'replicate_model_fn', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..ccaf1128bf23af734f7a5722a4dd8c1f0304fab7 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py @@ -0,0 +1,164 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow estimator for Linear and DNN joined training models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import dnn_linear_combined as dnn_linear_combined_lib +from tensorflow.python.ops import nn + + +class DNNLinearCombinedEstimator(estimator.Estimator): + """An estimator for TensorFlow Linear and DNN joined models with custom head. + + Note: This estimator is also known as wide-n-deep. + + Example: + + ```python + numeric_feature = numeric_column(...) + categorical_column_a = categorical_column_with_hash_bucket(...) + categorical_column_b = categorical_column_with_hash_bucket(...) + + categorical_feature_a_x_categorical_feature_b = crossed_column(...) + categorical_feature_a_emb = embedding_column( + categorical_column=categorical_feature_a, ...) + categorical_feature_b_emb = embedding_column( + categorical_column=categorical_feature_b, ...) 
+ + estimator = DNNLinearCombinedEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + # wide settings + linear_feature_columns=[categorical_feature_a_x_categorical_feature_b], + linear_optimizer=tf.train.FtrlOptimizer(...), + # deep settings + dnn_feature_columns=[ + categorical_feature_a_emb, categorical_feature_b_emb, + numeric_feature], + dnn_hidden_units=[1000, 500, 100], + dnn_optimizer=tf.train.ProximalAdagradOptimizer(...)) + + # To apply L1 and L2 regularization, you can set optimizers as follows: + tf.train.ProximalAdagradOptimizer( + learning_rate=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=0.001) + # It is same for FtrlOptimizer. + + # Input builders + def input_fn_train: # returns x, y + pass + estimator.train(input_fn=input_fn_train, steps=100) + + def input_fn_eval: # returns x, y + pass + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + pass + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * for each `column` in `dnn_feature_columns` + `linear_feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + Loss is calculated by using mean squared error. + + @compatibility(eager) + Estimators are not compatible with eager execution. 
+ @end_compatibility + """ + + def __init__(self, + head, + model_dir=None, + linear_feature_columns=None, + linear_optimizer='Ftrl', + dnn_feature_columns=None, + dnn_optimizer='Adagrad', + dnn_hidden_units=None, + dnn_activation_fn=nn.relu, + dnn_dropout=None, + input_layer_partitioner=None, + config=None): + """Initializes a DNNLinearCombinedEstimator instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + linear_feature_columns: An iterable containing all the feature columns + used by linear part of the model. All items in the set must be + instances of classes derived from `FeatureColumn`. + linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to + the linear part of the model. Defaults to FTRL optimizer. + dnn_feature_columns: An iterable containing all the feature columns used + by deep part of the model. All items in the set must be instances of + classes derived from `FeatureColumn`. + dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to + the deep part of the model. Defaults to Adagrad optimizer. + dnn_hidden_units: List of hidden units per layer. All layers are fully + connected. + dnn_activation_fn: Activation function applied to each layer. If None, + will use `tf.nn.relu`. + dnn_dropout: When not None, the probability we will drop out + a given coordinate. + input_layer_partitioner: Partitioner for input layer. Defaults to + `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: RunConfig object to configure the runtime settings. + + Raises: + ValueError: If both linear_feature_columns and dnn_features_columns are + empty at the same time. 
+ """ + linear_feature_columns = linear_feature_columns or [] + dnn_feature_columns = dnn_feature_columns or [] + self._feature_columns = ( + list(linear_feature_columns) + list(dnn_feature_columns)) + if not self._feature_columns: + raise ValueError('Either linear_feature_columns or dnn_feature_columns ' + 'must be defined.') + + def _model_fn(features, labels, mode, config): + return dnn_linear_combined_lib._dnn_linear_combined_model_fn( # pylint: disable=protected-access + features=features, + labels=labels, + mode=mode, + head=head, + linear_feature_columns=linear_feature_columns, + linear_optimizer=linear_optimizer, + dnn_feature_columns=dnn_feature_columns, + dnn_optimizer=dnn_optimizer, + dnn_hidden_units=dnn_hidden_units, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + input_layer_partitioner=input_layer_partitioner, + config=config) + + super(DNNLinearCombinedEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e4d34dc70ccaa4806ae8b8ed5001bd971ee7b4 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py @@ -0,0 +1,220 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for dnn_linear_combined.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import dnn_linear_combined +from tensorflow.contrib.estimator.python.estimator import head as head_lib +from tensorflow.python.estimator.canned import dnn_testing_utils +from tensorflow.python.estimator.canned import linear_testing_utils +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache + + +def _dnn_only_estimator_fn( + hidden_units, + feature_columns, + model_dir=None, + label_dimension=1, + weight_column=None, + optimizer='Adagrad', + activation_fn=nn.relu, + dropout=None, + input_layer_partitioner=None, + config=None): + return dnn_linear_combined.DNNLinearCombinedEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension), + model_dir=model_dir, + dnn_feature_columns=feature_columns, + dnn_optimizer=optimizer, + dnn_hidden_units=hidden_units, + dnn_activation_fn=activation_fn, + dnn_dropout=dropout, + input_layer_partitioner=input_layer_partitioner, + config=config) + + +class DNNOnlyEstimatorEvaluateTest( + dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__( + 
self, _dnn_only_estimator_fn) + + +class DNNOnlyEstimatorPredictTest( + dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + dnn_testing_utils.BaseDNNRegressorPredictTest.__init__( + self, _dnn_only_estimator_fn) + + +class DNNOnlyEstimatorTrainTest( + dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + dnn_testing_utils.BaseDNNRegressorTrainTest.__init__( + self, _dnn_only_estimator_fn) + + +def _linear_only_estimator_fn( + feature_columns, + model_dir=None, + label_dimension=1, + weight_column=None, + optimizer='Ftrl', + config=None, + partitioner=None): + return dnn_linear_combined.DNNLinearCombinedEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension), + model_dir=model_dir, + linear_feature_columns=feature_columns, + linear_optimizer=optimizer, + input_layer_partitioner=partitioner, + config=config) + + +class LinearOnlyEstimatorEvaluateTest( + linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__( + self, _linear_only_estimator_fn) + + +class LinearOnlyEstimatorPredictTest( + linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorPredictTest.__init__( + self, _linear_only_estimator_fn) + + +class LinearOnlyEstimatorTrainTest( + linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + 
test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorTrainingTest.__init__( + self, _linear_only_estimator_fn) + + +class DNNLinearCombinedEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, + label_dimension, batch_size): + linear_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,))] + dnn_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,))] + feature_columns = linear_feature_columns + dnn_feature_columns + est = dnn_linear_combined.DNNLinearCombinedEstimator( + head=head_lib.regression_head(label_dimension=label_dimension), + linear_feature_columns=linear_feature_columns, + dnn_feature_columns=dnn_feature_columns, + dnn_hidden_units=(2, 2), + model_dir=self._model_dir) + + # TRAIN + num_steps = 10 + est.train(train_input_fn, steps=num_steps) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + # PREDICT + predictions = np.array([ + x[prediction_keys.PredictionKeys.PREDICTIONS] + for x in est.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, label_dimension), predictions.shape) + + # EXPORT + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + label_dimension = 2 + batch_size = 10 + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + 
data = data.reshape(batch_size, label_dimension) + # learn y = x + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=label_dimension, + label_dimension=label_dimension, + batch_size=batch_size) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index e344ee3c3eab22d217570a8c8073f72998e77b03..a9311a20f127d92f02a95b8b48082fc90850635a 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops @@ -48,7 +49,20 @@ def multi_class_head(n_classes, Uses `sparse_softmax_cross_entropy` loss. - This head expects to be fed integer labels specifying the class index. + The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. + In many applications, the shape is `[batch_size, n_classes]`. + + `labels` must be a dense `Tensor` with shape matching `logits`, namely + `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string + `Tensor` with values from the vocabulary. 
If `label_vocabulary` is not given, + `labels` must be an integer `Tensor` with values specifying the class index. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. + + The loss is the weighted sum over the input dimensions. Namely, if the input + labels have shape `[batch_size, 1]`, the loss is the weighted sum over + `batch_size`. Args: n_classes: Number of classes, must be greater than 2 (for 2 classes, use @@ -57,11 +71,11 @@ def multi_class_head(n_classes, `tf.feature_column.numeric_column` defining feature column representing weights. It is used to down weight or boost examples during training. It will be multiplied by the loss of the example. - label_vocabulary: A list of strings represents possible label values. If it - is not given, that means labels are already encoded as integer within - [0, n_classes). If given, labels must be string type and have any value in - `label_vocabulary`. Also there will be errors if vocabulary is not - provided and labels are string. + label_vocabulary: A list or tuple of strings representing possible label + values. If it is not given, that means labels are already encoded as an + integer within [0, n_classes). If given, labels must be of string type and + have any value in `label_vocabulary`. Note that errors will be raised if + `label_vocabulary` is not provided but labels are strings. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -84,7 +98,20 @@ def binary_classification_head( This head uses `sigmoid_cross_entropy_with_logits` loss. - This head expects to be fed float labels of shape `(batch_size, 1)`. + The head expects `logits` with shape `[D0, D1, ... DN, 1]`. + In many applications, the shape is `[batch_size, 1]`. + + `labels` must be a dense `Tensor` with shape matching `logits`, namely + `[D0, D1, ... DN, 1]`. 
If `label_vocabulary` given, `labels` must be a string + `Tensor` with values from the vocabulary. If `label_vocabulary` is not given, + `labels` must be float `Tensor` with values in the interval `[0, 1]`. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. + + The loss is the weighted sum over the input dimensions. Namely, if the input + labels have shape `[batch_size, 1]`, the loss is the weighted sum over + `batch_size`. Args: weight_column: A string or a `_NumericColumn` created by @@ -96,11 +123,11 @@ def binary_classification_head( generated for each threshold value. This threshold is applied to the logistic values to determine the binary classification (i.e., above the threshold is `true`, below is `false`. - label_vocabulary: A list of strings represents possible label values. If it - is not given, that means labels are already encoded within [0, 1]. If - given, labels must be string type and have any value in - `label_vocabulary`. Also there will be errors if vocabulary is not - provided and labels are string. + label_vocabulary: A list or tuple of strings representing possible label + values. If it is not given, labels must be float with values within + [0, 1]. If given, labels must be string type and have any value in + `label_vocabulary`. Note that errors will be raised if `label_vocabulary` + is not provided but labels are strings. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -120,9 +147,22 @@ def binary_classification_head( def regression_head(weight_column=None, label_dimension=1, name=None): - """Creates a `_Head` for regression using the mean squared loss. + """Creates a `_Head` for regression using the `mean_squared_error` loss. + + The loss is the weighted sum over all input dimensions. 
Namely, if the input + labels have shape `[batch_size, label_dimension]`, the loss is the weighted + sum over both `batch_size` and `label_dimension`. + + The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`. + In many applications, the shape is `[batch_size, label_dimension]`. + + The `labels` shape must match `logits`, namely + `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape + `[D0, D1, ... DN]` is also supported. - Uses `mean_squared_error` loss. + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or + `[D0, D1, ... DN, label_dimension]`. Args: weight_column: A string or a `_NumericColumn` created by @@ -156,15 +196,29 @@ def multi_label_head(n_classes, or more associated labels, from a discrete set. This is distinct from `multi_class_head` which has exactly one label per example. - Uses `sigmoid_cross_entropy` loss averaged over classes. Expects labels as a - multi-hot tensor of shape `[batch_size, n_classes]`, or as an integer - `SparseTensor` of class indices. + Uses `sigmoid_cross_entropy` loss average over classes and weighted sum over + the batch. Namely, if the input logits have shape `[batch_size, n_classes]`, + the loss is the average over `n_classes` and the weighted sum over + `batch_size`. + + The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many + applications, the shape is `[batch_size, label_n_classes]`. + + Labels can be: + * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]` + * An integer `SparseTensor` of class indices. The `dense_shape` must be + `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`. + * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape` + must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. Also supports custom `loss_fn`. 
`loss_fn` takes `(labels, logits)` or `(labels, logits, features)` as arguments and returns unreduced loss with - shape `[batch_size, 1]`. `loss_fn` must support indicator `labels` with shape - `[batch_size, n_classes]`. Namely, the head applies `label_vocabulary` to the - input labels before passing them to `loss_fn`. + shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with + shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies + `label_vocabulary` to the input labels before passing them to `loss_fn`. Args: n_classes: Number of classes, must be greater than 1 (for 1 class, use @@ -191,7 +245,7 @@ def multi_label_head(n_classes, An instance of `_Head` for multi-label classification. Raises: - ValueError: if `n_classes` or `thresholds` is invalid. + ValueError: if `n_classes`, `thresholds`, or `loss_fn` is invalid. """ thresholds = tuple(thresholds) if thresholds else tuple() if n_classes is None or n_classes < 2: @@ -259,26 +313,36 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access indices=labels.indices, values=label_ids_values, dense_shape=labels.dense_shape) + return math_ops.to_int64( + sparse_ops.sparse_to_indicator(label_ids, self._n_classes)) else: - label_ids = labels - return math_ops.to_int64( - sparse_ops.sparse_to_indicator(label_ids, self._n_classes)) - msg = ('labels shape must be [batch_size, {}]. 
' - 'Given: ').format(self._n_classes) - labels_shape = array_ops.shape(labels) - check_rank_op = control_flow_ops.Assert( - math_ops.equal(array_ops.rank(labels), 2), - data=[msg, labels_shape]) - check_label_dim = control_flow_ops.Assert( - math_ops.equal(labels_shape[-1], self._n_classes), - data=[msg, labels_shape]) - with ops.control_dependencies([check_rank_op, check_label_dim]): - return array_ops.identity(labels) + err_msg = ( + r'labels must be an integer SparseTensor with values in ' + r'[0, {})'.format(self._n_classes)) + assert_int = check_ops.assert_integer( + labels.values, message=err_msg) + assert_less = check_ops.assert_less( + labels.values, + ops.convert_to_tensor(self._n_classes, dtype=labels.dtype), + message=err_msg) + assert_greater = check_ops.assert_non_negative( + labels.values, message=err_msg) + with ops.control_dependencies( + [assert_int, assert_less, assert_greater]): + return math_ops.to_int64( + sparse_ops.sparse_to_indicator(labels, self._n_classes)) + err_msg = ( + r'labels must be an integer indicator Tensor with values in [0, 1]') + return head_lib._assert_range(labels, 2, message=err_msg) # pylint:disable=protected-access, def create_loss(self, features, mode, logits, labels): """See `Head`.""" del mode # Unused for this head. + logits = ops.convert_to_tensor(logits) processed_labels = self._process_labels(labels) + processed_labels = head_lib._check_dense_labels_match_logits_and_reshape( # pylint:disable=protected-access + labels=processed_labels, logits=logits, + expected_labels_dimension=self.logits_dimension) if self._loss_fn: unweighted_loss = _call_loss_fn( loss_fn=self._loss_fn, labels=processed_labels, logits=logits, @@ -290,7 +354,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access # Averages loss over classes. 
unweighted_loss = math_ops.reduce_mean( unweighted_loss, axis=-1, keep_dims=True) - weights = head_lib._weights(features, self._weight_column) # pylint:disable=protected-access, + weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, + features=features, weight_column=self._weight_column, logits=logits) weighted_sum_loss = losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) # _weights() can return 1. @@ -305,7 +370,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access self, features, mode, logits, labels=None, train_op_fn=None): """See `Head`.""" with ops.name_scope(self._name, 'head'): - logits = head_lib._check_logits(logits, self.logits_dimension) # pylint:disable=protected-access + logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access # Predict. pred_keys = prediction_keys.PredictionKeys @@ -335,6 +400,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access # Eval. 
if mode == model_fn.ModeKeys.EVAL: + weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, + features=features, weight_column=self._weight_column, logits=logits) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, @@ -342,7 +409,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access eval_metric_ops=self._eval_metric_ops( labels=processed_labels, probabilities=probabilities, - weights=head_lib._weights(features, self._weight_column), # pylint:disable=protected-access, + weights=weights, weighted_sum_loss=weighted_sum_loss, example_weight_sum=example_weight_sum)) diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index fd8c53f6a94bf741c02e814ca96bfcea050589c4..d1cf9090048470181818c573647923c9f5824dfa 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -316,13 +316,14 @@ class MultiLabelHead(test.TestCase): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, - r'labels shape must be \[batch_size, 2\]\. Given: \] \[2 1\]'): + r'\[expected_labels_shape: \] \[2 2\] \[labels_shape: \] \[2 1\]'): actual_weighted_sum_loss.eval({ labels_placeholder: np.array([[1], [1]], dtype=np.int64) }) with self.assertRaisesRegexp( errors.InvalidArgumentError, - r'labels shape must be \[batch_size, 2\]\. Given: \] \[2\]'): + r'labels shape must be \[D0, D1, ... 
DN, 2\]\..*' + r'\[Received shape: \] \[2\]'): actual_weighted_sum_loss.eval({ labels_placeholder: np.array([1, 1], dtype=np.int64) }) @@ -387,9 +388,11 @@ class MultiLabelHead(test.TestCase): logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), labels=None) - def _test_eval(self, head, logits, labels, expected_loss, expected_metrics): + def _test_eval( + self, head, logits, labels, expected_loss, expected_metrics, + features=None): spec = head.create_estimator_spec( - features={'x': np.array(((42,),), dtype=np.int32)}, + features=features or {}, mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels) @@ -655,6 +658,54 @@ class MultiLabelHead(test.TestCase): labels=None, train_op_fn=_no_op_train_fn) + def test_train_invalid_indicator_labels(self): + head = head_lib.multi_label_head(n_classes=2) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + # The value 2 is outside the allowed range. + labels = np.array([[2, 0], [1, 1]], dtype=np.int64) + def _train_op_fn(loss): + del loss + return control_flow_ops.no_op() + + spec = head.create_estimator_spec( + features={}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'labels must be an integer indicator Tensor with values in ' + r'\[0, 1\]'): + sess.run(spec.loss) + + def test_train_invalid_sparse_labels(self): + head = head_lib.multi_label_head(n_classes=2) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + # The value 2 is outside the allowed range. 
+ labels = sparse_tensor.SparseTensor( + values=[2, 0, 1], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + def _train_op_fn(loss): + del loss + return control_flow_ops.no_op() + + spec = head.create_estimator_spec( + features={}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'labels must be an integer SparseTensor with values in \[0, 2\)'): + sess.run(spec.loss) + def _test_train(self, head, logits, labels, expected_loss): expected_train_result = 'my_train_op' def _train_op_fn(loss): @@ -791,6 +842,153 @@ class MultiLabelHead(test.TestCase): metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3, }, summary_str, tol) + def test_multi_dim_weighted_train_create_loss(self): + """Logits and labels of shape [2, 2, 3], weights [2, 2].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) + # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 + # = [[20/3, 10/3], [4, 8]] + # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 + expected_weighted_sum_loss = 39.6667 + expected_example_weight_sum = np.sum(weights) + actual_weighted_sum_loss, actual_example_weight_sum, _ = head.create_loss( + features={'weights': weights}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels) + atol = 1.e-3 + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + self.assertAllClose( + expected_weighted_sum_loss, actual_weighted_sum_loss.eval(), + atol=atol) + self.assertAllClose( + expected_example_weight_sum, 
actual_example_weight_sum.eval(), + atol=atol) + + def test_multi_dim_weighted_train(self): + """Logits and labels of shape [2, 2, 3], weights [2, 2].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) + # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 + # = [[20/3, 10/3], [4, 8]] + # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 + expected_loss = 39.6667 + expected_train_result = 'my_train_op' + def _train_op_fn(loss): + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = head.create_estimator_spec( + features={'weights': weights}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + atol = 1.e-3 + with self.test_session() as sess: + _initialize_variables(self, monitored_session.Scaffold()) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, atol=atol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + + def test_multi_dim_weights_wrong_inner_dim(self): + """Logits and labels of shape [2, 2, 3], weights [2, 1].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[1.], [2.]], dtype=np.float32) + def _train_op_fn(loss): + del loss + return control_flow_ops.no_op() + + spec = head.create_estimator_spec( + features={'weights': weights}, + 
mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 1\]'): + spec.loss.eval() + + def test_multi_dim_weights_wrong_outer_dim(self): + """Logits and labels of shape [2, 2, 3], weights [2, 2, 3].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[[1., 1., 1.], [1.5, 1.5, 1.5]], + [[2., 2., 2.], [2.5, 2.5, 2.5]]], dtype=np.float32) + weights_placeholder = array_ops.placeholder(dtype=dtypes.float32) + def _train_op_fn(loss): + del loss + return control_flow_ops.no_op() + + spec = head.create_estimator_spec( + features={'weights': weights_placeholder}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 2 3\]'): + spec.loss.eval({weights_placeholder: weights}) + + def test_multi_dim_weighted_eval(self): + """Logits and labels of shape [2, 2, 3], weights [2, 2].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) + # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 + # = [[20/3, 10/3], [4, 8]] + # weighted_sum_loss = 
1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 + expected_loss = 39.6667 + keys = metric_keys.MetricKeys + expected_metrics = { + keys.LOSS_MEAN: expected_loss / np.sum(weights), + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.4977, + keys.AUC_PR: 0.6645, + } + self._test_eval( + head=head, + features={'weights': weights}, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/linear.py b/tensorflow/contrib/estimator/python/estimator/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf4abe83d54504d55de73b63f369cceaf149dd2 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/linear.py @@ -0,0 +1,118 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import linear as linear_lib + + +class LinearEstimator(estimator.Estimator): + """An estimator for TensorFlow linear models with user-specified head. 
+ + Example: + + ```python + categorical_column_a = categorical_column_with_hash_bucket(...) + categorical_column_b = categorical_column_with_hash_bucket(...) + + categorical_feature_a_x_categorical_feature_b = crossed_column(...) + + # Estimator using the default optimizer. + estimator = LinearEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b]) + + # Or estimator using the FTRL optimizer with regularization. + estimator = LinearEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=tf.train.FtrlOptimizer( + learning_rate=0.1, + l1_regularization_strength=0.001 + )) + + def input_fn_train: # returns x, y (where y represents label's class index). + ... + estimator.train(input_fn=input_fn_train, steps=100) + def input_fn_eval: # returns x, y (where y represents label's class index). + ... + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + ... + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have the following features, + otherwise there will be a `KeyError`: + + * if `weight_column` is not `None`, a feature with + `key=weight_column` whose value is a `Tensor`. + * for each `column` in `feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + Loss and predicted output are determined by the specified head.
+ + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility + """ + + def __init__(self, + head, + feature_columns, + model_dir=None, + optimizer='Ftrl', + config=None, + partitioner=None): + """Initializes a `LinearEstimator` instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `FeatureColumn`. + model_dir: Directory to save model parameters, graph, etc. This can + also be used to load checkpoints from the directory into an estimator + to continue training a previously saved model. + optimizer: An instance of `tf.Optimizer` used to train the model. Defaults + to FTRL optimizer. + config: `RunConfig` object to configure the runtime settings. + partitioner: Optional. Partitioner for input layer. + """ + def _model_fn(features, labels, mode, config): + return linear_lib._linear_model_fn( # pylint: disable=protected-access + features=features, + labels=labels, + mode=mode, + head=head, + feature_columns=tuple(feature_columns or []), + optimizer=optimizer, + partitioner=partitioner, + config=config) + super(LinearEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/linear_test.py b/tensorflow/contrib/estimator/python/estimator/linear_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c63514eb688af48577f0a3b7ce9e7478309f2c30 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/linear_test.py @@ -0,0 +1,153 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for linear.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import head as head_lib +from tensorflow.contrib.estimator.python.estimator import linear +from tensorflow.python.estimator.canned import linear_testing_utils +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache + + +def _linear_estimator_fn( + weight_column=None, label_dimension=1, *args, **kwargs): + """Returns a LinearEstimator that uses regression_head.""" + return linear.LinearEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension), + *args, **kwargs) + + +class LinearEstimatorEvaluateTest( + linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__( + self, _linear_estimator_fn) + + +class 
LinearEstimatorPredictTest( + linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorPredictTest.__init__( + self, _linear_estimator_fn) + + +class LinearEstimatorTrainTest( + linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorTrainingTest.__init__( + self, _linear_estimator_fn) + + +class LinearEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, + label_dimension, batch_size): + feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,))] + est = linear.LinearEstimator( + head=head_lib.regression_head(label_dimension=label_dimension), + feature_columns=feature_columns, + model_dir=self._model_dir) + + # TRAIN + num_steps = 10 + est.train(train_input_fn, steps=num_steps) + + # EVALUATE + scores = est.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + # PREDICT + predictions = np.array([ + x[prediction_keys.PredictionKeys.PREDICTIONS] + for x in est.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, label_dimension), predictions.shape) + + # EXPORT + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) +
self.assertTrue(gfile.Exists(export_dir)) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + label_dimension = 2 + batch_size = 10 + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + # learn y = x + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=label_dimension, + label_dimension=label_dimension, + batch_size=batch_size) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index 69dbfcee62af526cc92f8699f7137acbcdc03052..f2a6eae03ec021e5c28d48b3887870d8a057e077 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -22,10 +22,14 @@ import six from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.saved_model import signature_constants +from tensorflow.python.summary import summary _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -72,6 +76,23 @@ def multi_head(heads, head_weights=None): 
estimator.train(input_fn=input_fn, steps=100) ``` + Also supports `logits` as a `Tensor` of shape + `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the + last dimension and distribute it appropriately among the heads. E.g.: + + ```python + def model_fn(features, labels, mode): + # Create simple heads and specify head name. + head1 = multi_class_head(n_classes=3, name='head1') + head2 = binary_classification_head(name='head2') + # Create multi-head from two simple heads. + head = multi_head([head1, head2]) + # Create logits for the multihead. + logits = logit_fn(logits_dimension=head.logits_dimension) + # Return the merged EstimatorSpec + return head.create_estimator_spec(..., logits=logits, ...) + ``` + Args: heads: List or tuple of `_Head` instances. All heads must have `name` specified. The first head in the list is the default used at serving time. @@ -161,18 +182,17 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access def create_loss(self, features, mode, logits, labels): """See `Head`.""" - # TODO(roumposg): Add support for logits as single Tensor (with - # _split_logits utility). - if not isinstance(logits, dict): - raise ValueError('logits must be a dict. 
Single Tensor support coming ' - 'soon.') + if isinstance(logits, dict): + logits_dict = logits + else: + logits_dict = self._split_logits(logits) weighted_sum_losses = [] example_weight_sums = [] labels_by_head = {} for head in self._heads: (weighted_sum_loss, example_weight_sum, processed_labels) = head.create_loss( - features, mode, logits[head.name], labels[head.name]) + features, mode, logits_dict[head.name], labels[head.name]) weighted_sum_losses.append(weighted_sum_loss) example_weight_sums.append(example_weight_sum) labels_by_head[head.name] = processed_labels @@ -205,10 +225,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access def create_estimator_spec( self, features, mode, logits, labels=None, train_op_fn=None): """See `_Head`.""" - # TODO(roumposg): Add support for logits as single Tensor (with - # _split_logits utility). - if not isinstance(logits, dict): - raise ValueError('logits must be a dict. Given: {}'.format(logits)) + if isinstance(logits, dict): + logits_dict = logits + else: + logits_dict = self._split_logits(logits) if labels and not isinstance(labels, dict): raise ValueError('labels must be a dict. Given: {}'.format(labels)) @@ -219,22 +239,42 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access head.create_estimator_spec( features=features, mode=mode, - logits=logits[head_name], + logits=logits_dict[head_name], labels=labels[head_name] if labels else None, train_op_fn=_no_op_train_fn)) - # TODO(roumposg): Add LOSS and LOSS_MEAN summaries for the total head- - # combined loss. 
if mode == model_fn.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError('train_op_fn can not be None in TRAIN mode.') - return self._merge_train(all_estimator_spec, train_op_fn) + spec = self._merge_train(all_estimator_spec, train_op_fn) + with ops.name_scope(''): + summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss) + return spec if mode == model_fn.ModeKeys.PREDICT: return self._merge_predict(all_estimator_spec) if mode == model_fn.ModeKeys.EVAL: return self._merge_eval(all_estimator_spec) raise ValueError('mode={} unrecognized'.format(mode)) + def _split_logits(self, logits): + """Splits logits along the last dimension and returns a dict.""" + logits_dict = {} + with ops.name_scope(None, 'split_logits', values=[logits]): + logits = ops.convert_to_tensor(logits) + batch_shape = array_ops.shape(logits)[:-1] + zeros_like_batch_shape = array_ops.zeros_like(batch_shape) + minus_ones_like_batch_shape = -1 * array_ops.ones_like(batch_shape) + begin_idx = 0 + for head in self._heads: + begin_tensor = array_ops.concat( + [zeros_like_batch_shape, [begin_idx]], axis=0) + size_tensor = array_ops.concat( + [minus_ones_like_batch_shape, [head.logits_dimension]], axis=0) + logits_dict[head.name] = array_ops.slice( + logits, begin=begin_tensor, size=size_tensor) + begin_idx += head.logits_dimension + return logits_dict + def _merge_train(self, all_estimator_spec, train_op_fn): """Merges list of `EstimatorSpec` for training. @@ -303,14 +343,19 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access predictions = {} metrics = {} losses = [] - for head, spec in zip(self._heads, all_estimator_spec): - losses.append(spec.loss) - head_name = head.name - # Metric keys already contain head.name. 
- metrics.update(spec.eval_metric_ops or {}) - for k, v in six.iteritems(spec.predictions): - predictions[(head_name, k)] = v - loss = _merge_losses(losses, self._head_weights) + with ops.name_scope('merge_eval'): + for head, spec in zip(self._heads, all_estimator_spec): + losses.append(spec.loss) + head_name = head.name + # Loss metric is not added by default. + loss_name = head_lib._summary_key( # pylint:disable=protected-access + head_name, metric_keys.MetricKeys.LOSS) + metrics[loss_name] = metrics_lib.mean(spec.loss, name=loss_name) + # Metric keys already contain head.name. + metrics.update(spec.eval_metric_ops or {}) + for k, v in six.iteritems(spec.predictions): + predictions[(head_name, k)] = v + loss = _merge_losses(losses, self._head_weights) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 16177aebd53cbff5c8fd727477ac5d18c9f8bce5..68f2d5d1cd53456f7dd82222e171b3619052321a 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -106,7 +106,8 @@ class MultiHeadTest(test.TestCase): multi_head = multi_head_lib.multi_head([head1, head2]) self.assertEqual('head1_head2', multi_head.name) - def test_predict_two_heads(self): + def test_predict_two_heads_logits_dict(self): + """Tests predict with logits as dict.""" head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') multi_head = multi_head_lib.multi_head([head1, head2]) @@ -158,6 +159,111 @@ class MultiHeadTest(test.TestCase): expected_probabilities['head2'], sess.run(spec.export_outputs['head2'].scores)) + def test_predict_two_heads_logits_tensor(self): + """Tests predict with logits as Tensor.""" + head1 = head_lib.multi_label_head(n_classes=2, name='head1') + head2 = 
head_lib.multi_label_head(n_classes=3, name='head2') + multi_head = multi_head_lib.multi_head([head1, head2]) + + logits = np.array( + [[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]], dtype=np.float32) + expected_logits1 = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) + expected_logits2 = np.array([[2., -2., 2.], [-3., 2., -2.]], + dtype=np.float32) + expected_probabilities = { + 'head1': _sigmoid(expected_logits1), + 'head2': _sigmoid(expected_logits2), + } + + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + self.assertItemsEqual( + (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', + 'head2', 'classification/head2', 'predict/head2'), + spec.export_outputs.keys()) + + # Assert predictions and export_outputs. + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertIsNone(spec.scaffold.summary_op) + predictions = sess.run(spec.predictions) + self.assertAllClose( + expected_logits1, + predictions[('head1', prediction_keys.PredictionKeys.LOGITS)]) + self.assertAllClose( + expected_logits2, + predictions[('head2', prediction_keys.PredictionKeys.LOGITS)]) + self.assertAllClose( + expected_probabilities['head1'], + predictions[('head1', prediction_keys.PredictionKeys.PROBABILITIES)]) + self.assertAllClose( + expected_probabilities['head2'], + predictions[('head2', prediction_keys.PredictionKeys.PROBABILITIES)]) + + self.assertAllClose( + expected_probabilities['head1'], + sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores)) + self.assertAllClose( + expected_probabilities['head1'], + sess.run(spec.export_outputs['head1'].scores)) + self.assertAllClose( + expected_probabilities['head2'], + sess.run(spec.export_outputs['head2'].scores)) + + def test_predict_two_heads_logits_tensor_multi_dim(self): + """Tests predict with multi-dimensional logits of shape [2, 2, 5].""" + head1 = 
head_lib.regression_head(label_dimension=2, name='head1') + head2 = head_lib.regression_head(label_dimension=3, name='head2') + multi_head = multi_head_lib.multi_head([head1, head2]) + + logits = np.array( + [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]], + [[-1.5, 1., -3., 2., -2.], [-1.5, 1., -3., 2., -2.]]], + dtype=np.float32) + expected_logits1 = np.array( + [[[-1., 1.], [-1., 1.]], + [[-1.5, 1.], [-1.5, 1.]]], + dtype=np.float32) + expected_logits2 = np.array( + [[[2., -2., 2.], [2., -2., 2.]], + [[-3., 2., -2.], [-3., 2., -2.]]], + dtype=np.float32) + + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + self.assertItemsEqual( + (_DEFAULT_SERVING_KEY, 'head1', 'regression/head1', 'predict/head1', + 'head2', 'regression/head2', 'predict/head2'), + spec.export_outputs.keys()) + + # Assert predictions and export_outputs. + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertIsNone(spec.scaffold.summary_op) + predictions = sess.run(spec.predictions) + self.assertAllClose( + expected_logits1, + predictions[('head1', prediction_keys.PredictionKeys.PREDICTIONS)]) + self.assertAllClose( + expected_logits2, + predictions[('head2', prediction_keys.PredictionKeys.PREDICTIONS)]) + + self.assertAllClose( + expected_logits1, + sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].value)) + self.assertAllClose( + expected_logits1, + sess.run(spec.export_outputs['head1'].value)) + self.assertAllClose( + expected_logits2, + sess.run(spec.export_outputs['head2'].value)) + def test_eval_two_heads_with_weights(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') @@ -191,6 +297,8 @@ class MultiHeadTest(test.TestCase): keys = metric_keys.MetricKeys expected_metrics = { + keys.LOSS + '/head1': expected_loss_head1, + keys.LOSS + '/head2': expected_loss_head2, # Average 
loss over examples. keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, @@ -284,6 +392,84 @@ class MultiHeadTest(test.TestCase): # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13 self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol) + def test_train_create_loss_logits_tensor(self): + """Tests create_loss with logits Tensor.""" + weights1 = np.array([[1.], [2.]], dtype=np.float32) + weights2 = np.array([[2.], [3.]]) + head1 = head_lib.multi_label_head(n_classes=2, name='head1', + weight_column='weights1') + head2 = head_lib.multi_label_head(n_classes=3, name='head2', + weight_column='weights2') + multi_head = multi_head_lib.multi_head( + [head1, head2], head_weights=[1., 2.]) + + logits = np.array([[-10., 10., 20., -20., 20.], + [-15., 10., -30., 20., -20.]], dtype=np.float32) + labels = { + 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), + 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), + } + weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + features={ + 'x': np.array(((42,),), dtype=np.int32), + 'weights1': weights1, + 'weights2': weights2 + }, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels) + tol = 1e-3 + with self.test_session(): + # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] + # = [10, 7.5] + # weighted_sum_loss = 1 * 10 + 2 * 7.5 = 25 + # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] + # = [20, 10] + # weighted_sum_loss = 2 * 20 + 3 * 10 = 70 + # head-weighted merge = 1 * 25 + 2 * 70 = 165 + self.assertAllClose(165, weighted_sum_loss.eval(), rtol=tol, atol=tol) + # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13 + self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol) + + def test_train_create_loss_logits_tensor_multi_dim(self): + """Tests create_loss with multi-dimensional logits of shape [2, 2, 5].""" + head1 = head_lib.regression_head(label_dimension=2, name='head1') + 
head2 = head_lib.regression_head(label_dimension=3, name='head2') + multi_head = multi_head_lib.multi_head([head1, head2]) + + logits = np.array( + [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]], + [[-1.5, 1.5, -2., 2., -2.], [-1.5, 1.5, -2., 2., -2.]]], + dtype=np.float32) + labels = { + 'head1': np.array([[[1., 0.], [1., 0.]], + [[1.5, 1.5], [1.5, 1.5]]], dtype=np.float32), + 'head2': np.array([[[0., 1., 0.], [0., 1., 0.]], + [[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32), + } + # Loss for the first head: + # loss1 = (1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 + + # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2 + # = 28 + # Loss for the second head: + # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + + # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2 + # = 74 + expected_weighted_sum_loss = 28. + 74. + + weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + features={}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels) + tol = 1e-3 + with self.test_session(): + self.assertAllClose( + expected_weighted_sum_loss, weighted_sum_loss.eval(), + rtol=tol, atol=tol) + self.assertAllClose( + 2. * 2. * 5., example_weight_sum.eval(), rtol=tol, atol=tol) + def test_train_one_head(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') multi_head = multi_head_lib.multi_head([head1]) @@ -327,6 +513,7 @@ class MultiHeadTest(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) _assert_simple_summaries(self, { + metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss, # Average loss over examples. 
metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2, @@ -387,6 +574,7 @@ class MultiHeadTest(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) _assert_simple_summaries(self, { + metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1, metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2, # Average loss over examples. diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index 7005a647db599dfa386f34406911febe1d9d5651..ca3a2394ee227f2ab78e6d4d3d882f2b10954699 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -34,14 +34,15 @@ from tensorflow.python.estimator import util from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import device as framework_device from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gradients as gradients_lib from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import tf_logging +from tensorflow.python.training import device_setter as device_setter_lib from tensorflow.python.training import training_util @@ -109,7 +110,7 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): - For all other fields of `EstimatorSpec` the values of the first tower are taken. 
- On replication of variables: + On distribution of variables: Variables are not duplicated between towers. Instead, they are placed on a single device as defined above and shared across towers. @@ -133,20 +134,65 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): conforms to the requirements of `Estimator`'s `model_fn` and can be used instead of the supplied `model_fn`. """ + return _replicate_model_fn_with_mode( + model_fn, + optimizer_fn, + devices, + # TODO(isaprykin): Query system configuration to choose modes other than + # `SHARED_LOCAL_PARAMETER_SERVER`, even though it is often appropriate. + mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER) + + +class _VariableDistributionMode(object): + """Modes for variable distribution used for forcing a particular one. + + Forcing a mode is meant for performance experimentation purposes rather than + for general use cases. + """ + + SHARED_LOCAL_PARAMETER_SERVER = 1 + """Variables are placed on a single device and shared across all devices. + + Two ways to achieve this distribution over available GPUs are supported: + 1) If exactly 1 GPU is detected, then variables and operations are placed + onto GPU. + 2) If more than 1 GPU is detected, then variables are going to be placed on + the CPU. Replicas of operations are placed on each individual GPU. + """ + + SHARED_ROUND_ROBIN = 2 + """Variables are placed on all devices in a round-robin fashion. + + Every subsequent variable is placed on the next device. There is only one + copy of each variable that is shared across all devices. 
+ """ + + +def _replicate_model_fn_with_mode( + model_fn, + optimizer_fn, + devices=None, + mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER): + """A version of `replicate_model_fn` that allows to specify a `mode`.""" if not devices: devices = _get_local_devices('GPU') or _get_local_devices('CPU') is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0] - local_ps_device = '/{}:0'.format('GPU' if is_a_single_gpu_case else 'CPU') + consolidation_device = '/{}:0'.format('GPU' + if is_a_single_gpu_case else 'CPU') + + ps_devices = [consolidation_device] + if mode == _VariableDistributionMode.SHARED_ROUND_ROBIN: + ps_devices = devices - tf_logging.info('Replicating the `model_fn` across {}. Local parameter ' - 'server device is going to be {}.'.format( - devices, local_ps_device)) + tf_logging.info('Replicating the `model_fn` across {}. Variables are going ' + 'to be placed on {}. Consolidation device is going to be {}.' + .format(devices, ps_devices, consolidation_device)) - def replicated_model_fn(mode, features, labels, params=None, config=None): + def replicated_model_fn(features, labels, mode, params=None, config=None): """Replicated version of `model_fn` to be used instead.""" feature_shards, label_shards = _split_batch( - features, labels, len(devices), device=local_ps_device) + features, labels, len(devices), device=consolidation_device) tower_specs = _get_loss_towers( model_fn=model_fn, mode=mode, @@ -155,17 +201,17 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): params=params, config=config, devices=devices, - local_ps_device=local_ps_device) + local_ps_devices=ps_devices) if mode == model_fn_lib.ModeKeys.TRAIN: train_op = _minimize_towers(tower_specs, _call_optimizer_fn(optimizer_fn, params)) return _train_spec( - tower_specs, train_op, aggregation_device=local_ps_device) + tower_specs, train_op, aggregation_device=consolidation_device) elif mode == model_fn_lib.ModeKeys.EVAL: - return _eval_spec(tower_specs, 
aggregation_device=local_ps_device) + return _eval_spec(tower_specs, aggregation_device=consolidation_device) elif mode == model_fn_lib.ModeKeys.PREDICT: - return _predict_spec(tower_specs, aggregation_device=local_ps_device) + return _predict_spec(tower_specs, aggregation_device=consolidation_device) return replicated_model_fn @@ -183,10 +229,17 @@ def _split_batch(features, labels, number_of_shards, device): """Split input features and labes into batches.""" def split_dictionary(dictionary): + """Split a dictionary into shards.""" shards = [{} for _ in range(number_of_shards)] for name, tensor in six.iteritems(dictionary): - for i, shard in enumerate(array_ops.split(tensor, number_of_shards)): - shards[i][name] = shard + if isinstance(tensor, sparse_tensor.SparseTensor): + for i, shard in enumerate( + sparse_ops.sparse_split( + sp_input=tensor, num_split=number_of_shards, axis=0)): + shards[i][name] = shard + else: + for i, shard in enumerate(array_ops.split(tensor, number_of_shards)): + shards[i][name] = shard return shards with ops_lib.name_scope('split_inputs'): @@ -215,7 +268,7 @@ def _get_loss_towers(model_fn, params, config, devices, - local_ps_device, + local_ps_devices, name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN): """Replicate the loss computation across devices.""" tower_specs = [] @@ -227,15 +280,22 @@ def _get_loss_towers(model_fn, if 'config' in model_fn_args: optional_params['config'] = copy.deepcopy(config) + # pylint: disable=protected-access + round_robin_strategy = device_setter_lib._RoundRobinStrategy( + num_tasks=len(local_ps_devices)) + # pylint: enable=protected-access + for i, device in enumerate(devices): is_the_first_tower = (i == 0) device_setter = _local_device_setter( - worker_device=device, ps_device=local_ps_device) + worker_device=device, + ps_devices=local_ps_devices, + ps_strategy=round_robin_strategy) - # We would like to preserve the names of the variables and ops that a user - # might be relying on. 
Names with prefix are going to resolve to variables - # and ops of the first tower. + # We would like to preserve the names of the variables and ops that the user + # might be relying on. Names without a prefix are going to resolve to + # variables and ops of the first tower. name_scope = name_scope_pattern if is_the_first_tower: name_scope = '' @@ -256,7 +316,7 @@ def _get_loss_towers(model_fn, return tower_specs -def _local_device_setter(ps_device, worker_device): +def _local_device_setter(worker_device, ps_devices, ps_strategy): """A device setter that puts distributes Var/Ops to PS/workers.""" ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] @@ -266,7 +326,7 @@ def _local_device_setter(ps_device, worker_device): node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def if node_def.op in ps_ops: ps_device_spec = framework_device.DeviceSpec.from_string( - '{}'.format(ps_device)) + '{}'.format(ps_devices[ps_strategy(op)])) ps_device_spec.merge_from(current_device) return ps_device_spec.to_string() @@ -284,10 +344,7 @@ def _minimize_towers(tower_specs, optimizer): grad_lists = {} for tower_spec in tower_specs: with ops_lib.device(tower_spec.loss.device): - variables = variables_lib.trainable_variables() - gradients = gradients_lib.gradients(tower_spec.loss, variables) - - for var, grad in zip(variables, gradients): + for grad, var in optimizer.compute_gradients(tower_spec.loss): if grad is not None: grad_lists.setdefault(var, []).append(grad) @@ -313,7 +370,17 @@ def _call_optimizer_fn(optimizer_fn, params): def _compute_sum_on_device(values, device, name=None): with ops_lib.device(device): - return math_ops.add_n(values, name=name) + if isinstance(values[0], ops_lib.IndexedSlices): + if name: + raise ValueError('The name {} is not expected to be given to ' + 'IndexedSlices {}'.format(name, values)) + + values_concat = array_ops.concat([v.values for v in values], axis=0) + indices_concat = array_ops.concat([v.indices for v in values], axis=0) + 
return ops_lib.IndexedSlices(values_concat, indices_concat, + values[0].dense_shape) + else: + return math_ops.add_n(values, name=name) def _train_spec(tower_specs, @@ -338,25 +405,17 @@ def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): [spec.loss for spec in tower_specs], aggregation_device, aggregated_loss_name) - eval_metric_ops_lists = {} + update_ops = [] for tower_spec in tower_specs: - metrics = tower_spec.eval_metric_ops or {} - for name, (_, update_op) in six.iteritems(metrics): - update_ops = eval_metric_ops_lists.setdefault(name, ([])) + for name, (_, update_op) in six.iteritems(tower_spec.eval_metric_ops): update_ops.append(update_op) + with ops_lib.control_dependencies(update_ops): + reduced_update_op = _reduce_metric_variables(len(tower_specs)) + eval_metric_ops = {} for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops): - with ops_lib.control_dependencies(eval_metric_ops_lists[name]): - # This operation reduces local variables across all metrics, yet is - # called for every metric. This is redundant and it's done because - # it is hard to know what local variables correspond to what metric. - # Estimator is going to execute all `reduced_update_op`s as part of - # a group inside a single `Session.run()` call, which will avoid duplicate - # computation. 
- reduced_update_op = _reduce_metric_variables(len(tower_specs)) eval_metric_ops[name] = (metric_tensor, reduced_update_op) - estimator_spec['eval_metric_ops'] = eval_metric_ops return model_fn_lib.EstimatorSpec(**estimator_spec) diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index 10b47fba5af0f2a036df637a4f4f996d388270c6..a83a1b84079f115f94be33297f0ab0e2e8f2f7e3 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -49,15 +49,30 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import device_setter from tensorflow.python.training import gradient_descent +# TODO(isaprykin): Parametrize all the tests on +# replicate_model_fn._VariableDistributionMode when it's supported. class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): def setUp(self): self._model_dir = tempfile.mkdtemp() - def test_complete_flow(self): + def test_complete_flow_with_public_version(self): + return self._complete_flow_with_mode(mode=None) + + def test_complete_flow_with_mode_local_ps_server(self): + return self._complete_flow_with_mode( + replicate_model_fn._VariableDistributionMode. 
+ SHARED_LOCAL_PARAMETER_SERVER) + + def test_complete_flow_with_mode_round_robin(self): + return self._complete_flow_with_mode( + replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN) + + def _complete_flow_with_mode(self, mode): n_classes = 3 input_dimension = 2 batch_size = 12 @@ -65,20 +80,35 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): data = np.linspace( 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) x_data = data.reshape(batch_size, input_dimension) + categorical_data = np.random.random_integers( + 0, len(x_data), size=len(x_data)) y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1)) train_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, + x={'x': x_data, + 'categories': categorical_data}, y=y_data, batch_size=batch_size, num_epochs=None, shuffle=True) eval_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False) + x={'x': x_data, + 'categories': categorical_data}, + y=y_data, + batch_size=batch_size, + shuffle=False) predict_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, batch_size=batch_size, shuffle=False) + x={'x': x_data, + 'categories': categorical_data}, + batch_size=batch_size, + shuffle=False) feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) + feature_column.numeric_column('x', shape=(input_dimension,)), + feature_column.embedding_column( + feature_column.categorical_column_with_vocabulary_list( + 'categories', + vocabulary_list=np.linspace( + 0., len(x_data), len(x_data), dtype=np.int64)), 1) ] estimator = dnn.DNNClassifier( @@ -90,14 +120,20 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): def optimizer_fn(): return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05) - # TODO(isaprykin): Switch Estimator to use allow_soft_placement=True - # during export_savedmodel and then switch this test to replicate over - # GPUs instead of CPUs. 
+ if not mode: # Use the public `replicate_model_fn`. + model_fn = replicate_model_fn.replicate_model_fn( + estimator.model_fn, + optimizer_fn, + devices=['/gpu:0', '/gpu:1', '/gpu:2']) + else: + model_fn = replicate_model_fn._replicate_model_fn_with_mode( + estimator.model_fn, + optimizer_fn, + devices=['/gpu:0', '/gpu:1', '/gpu:2'], + mode=mode) + estimator = estimator_lib.Estimator( - model_fn=replicate_model_fn.replicate_model_fn( - estimator.model_fn, - optimizer_fn, - devices=['/cpu:0', '/cpu:0', '/cpu:0']), + model_fn=model_fn, model_dir=estimator.model_dir, config=estimator.config, params=estimator.params) @@ -177,8 +213,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) session.run(variables.global_variables_initializer()) # loss = feature * c - label @@ -207,8 +243,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): devices=['/gpu:0', '/gpu:1']) # This call is going to fail if `replicated_model_fn` is still passing # `params` inside `optimizer_fn`, even though the latter doesn't take any: - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) del estimator_spec def test_eval(self): @@ -218,8 +254,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features, - labels, self.params) + estimator_spec = 
replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.params) session.run(variables.local_variables_initializer()) session.run(variables.global_variables_initializer()) @@ -230,6 +266,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): accuracy = session.run(accuracy) auc = session.run(auc) + # loss[i] = features[i] * 10 - labels[i]. # Accuracy is 0.0 (no match) in the first tower. # Accuracy is 1.0 (match) in the second tower, since the feature # times weight "c" happened to be equal to the label. @@ -246,8 +283,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) session.run(variables.global_variables_initializer()) self.assertAllClose({ @@ -260,9 +297,9 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, - features, labels, self.params) + self.model_fn, self.optimizer_fn, devices=['/gpu:0']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) session.run(variables.global_variables_initializer()) # loss = feature * c - label @@ -283,8 +320,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features, - labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, 
model_fn_lib.ModeKeys.EVAL, self.params) session.run(variables.local_variables_initializer()) session.run(variables.global_variables_initializer()) @@ -311,8 +348,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) session.run(variables.global_variables_initializer()) self.assertAllClose({ @@ -346,7 +383,7 @@ class GetLossTowersTest(test_util.TensorFlowTestCase): params=None, config=None, devices=['/gpu:0', '/gpu:1'], - local_ps_device='/gpu:0', + local_ps_devices=['/gpu:0'], name_scope_pattern='test_tower_{}') session.run(variables.global_variables_initializer()) @@ -369,6 +406,54 @@ class GetLossTowersTest(test_util.TensorFlowTestCase): c = variable_scope.get_variable('c', dtype=dtypes.float64) self.assertEqual(0.25, session.run(c)) + def test_variables_are_round_robined_correctly(self): + """Test that creates multiple variables and tests round-robin placement.""" + + def model_fn(mode, features, labels, params): + del params + for variable_name in ['a', 'b', 'c', 'd']: + c = variable_scope.get_variable( + variable_name, + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c) + labels = np.array([0.1, 0.2, 0.3, labels[0]]) + loss = losses.absolute_difference( + labels=labels, + predictions=predictions, + reduction=losses.Reduction.SUM) + return model_fn_lib.EstimatorSpec( + mode=mode, loss=math_ops.reduce_sum(loss)) + + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + model_fn, + mode=None, + features=[[0.6], [1.6], [2.6]], + labels=[[0.6], [0.6], [2.6]], 
+ params=None, + config=None, + devices=['/gpu:0', '/gpu:1', '/gpu:3'], + local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'], + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 3) + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('/device:GPU:3', tower_specs[2].loss.device) + + with variable_scope.variable_scope('', reuse=True): + a = variable_scope.get_variable('a', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', a.device) + b = variable_scope.get_variable('b', dtype=dtypes.float64) + self.assertEqual('/device:GPU:1', b.device) + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:3', c.device) + d = variable_scope.get_variable('d', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', d.device) + class SplitBatchTest(test_util.TensorFlowTestCase): @@ -531,8 +616,7 @@ class EvalSpecTest(test_util.TensorFlowTestCase): self.assertEqual('/device:CPU:0', auc.device) session.run([a, b]) - accuracy = session.run(accuracy) - auc = session.run(auc) + accuracy, auc = session.run([accuracy, auc]) self.assertNear((12 - 2) / 12, accuracy, 0.01) self.assertEqual(0, auc) @@ -592,7 +676,7 @@ class PredictSpecTest(test_util.TensorFlowTestCase): params=None, config=None, devices=['/gpu:0', '/gpu:1'], - local_ps_device='/gpu:0', + local_ps_devices=['/gpu:0'], ) session.run(variables.global_variables_initializer()) @@ -766,8 +850,8 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase): replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, - features, labels, {}) + estimator_spec = replicated_model_fn(features, labels, + model_fn_lib.ModeKeys.PREDICT, {}) session.run(variables.global_variables_initializer()) 
return estimator_spec @@ -831,37 +915,77 @@ class GetLocalDevicesTest(test_util.TensorFlowTestCase): replicate_model_fn._get_local_devices('XPU')) # XPU doesn't exist. def test_whether_there_is_a_gpu(self): - self.assertEqual( - len(replicate_model_fn._get_local_devices('GPU')), - test.is_gpu_available()) + if test.is_gpu_available(): + self.assertTrue(len(replicate_model_fn._get_local_devices('GPU'))) class LocalDeviceSetterTest(test_util.TensorFlowTestCase): def test_vars_are_on_ps_but_ops_are_on_workers(self): + ps_devices = ['/device:GPU:3'] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + local_device_setter = replicate_model_fn._local_device_setter( - ps_device='/device:GPU:3', worker_device='/device:GPU:2') + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') with ops_lib.device(local_device_setter): - c = variables.Variable(0.01) + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:3', a.device) + + b = variables.Variable(0.02) + self.assertEqual('/device:GPU:3', b.device) + + c = variables.Variable(0.03) self.assertEqual('/device:GPU:3', c.device) - cc = variables.Variable(0.02) - self.assertEqual('/device:GPU:3', cc.device) + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) + + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) - ccc = variables.Variable(0.03) - self.assertEqual('/device:GPU:3', ccc.device) + def test_round_robin_placement(self): + ps_devices = [ + '/device:GPU:0', '/device:GPU:1', '/device:GPU:3', '/device:GPU:4' + ] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + + local_device_setter = replicate_model_fn._local_device_setter( + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') + + with ops_lib.device(local_device_setter): + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:0', a.device) + + b = variables.Variable(0.02) + 
self.assertEqual('/device:GPU:1', b.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:3', c.device) + + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) + + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:4', c.device) + + d = variables.Variable(0.03) + self.assertEqual('/device:GPU:0', d.device) c_op = array_ops.concat(c, axis=0) self.assertEqual('/device:GPU:2', c_op.device) - cc_op = array_ops.concat(cc, axis=0) - self.assertEqual('/device:GPU:2', cc_op.device) - class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): - def test_example(self): + def test_vectors(self): with self.test_session() as session: total = replicate_model_fn._compute_sum_on_device( [1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum') @@ -870,6 +994,68 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): self.assertEqual('test_sum', total.op.name) self.assertEqual(10.0, session.run(total)) + def test_tensors(self): + with self.test_session() as session: + total = replicate_model_fn._compute_sum_on_device( + [[1.0, 2.0], [3.0, 4.0]], device='/device:GPU:0', name='test_sum') + + self.assertEqual('/device:GPU:0', total.device) + self.assertEqual('test_sum', total.op.name) + self.assertAllEqual([4.0, 6.0], session.run(total)) + + def test_indexedslices(self): + with self.test_session() as session: + a = ops_lib.IndexedSlices( + constant_op.constant([1.0, 2.0]), [0, 1], + dense_shape=constant_op.constant([2])) + b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) + + total = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0') + + self.assertEqual('/device:GPU:0', total.device) + self.assertAllEqual([4.0, 6.0], + session.run(ops_lib.convert_to_tensor(total))) + + def test_indexedslices_higher_dimensions(self): + with self.test_session() as session: + a = 
ops_lib.IndexedSlices( + constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1], + dense_shape=constant_op.constant([2, 4])) + b = ops_lib.IndexedSlices( + constant_op.constant([[3.0, 7.0], [4.0, 8.0]]), [0, 1]) + + total = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0') + + self.assertEqual('/device:GPU:0', total.device) + self.assertAllEqual([[4.0, 12.0], [6.0, 14.0]], + session.run(ops_lib.convert_to_tensor(total))) + + def test_indexedslices_some_dont_overlap(self): + with self.test_session() as session: + a = ops_lib.IndexedSlices( + constant_op.constant([1.0, 2.0]), [0, 3], + dense_shape=constant_op.constant([4])) + b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) + + total = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0') + + self.assertEqual('/device:GPU:0', total.device) + self.assertAllEqual([4.0, 4.0, 0.0, 2.0], + session.run(ops_lib.convert_to_tensor(total))) + + def test_no_name_for_indexslices(self): + a = ops_lib.IndexedSlices( + constant_op.constant([1.0, 2.0]), [0, 1], + dense_shape=constant_op.constant([2])) + b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) + + with self.assertRaisesRegexp(ValueError, ''): + _ = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0', name='cant_name_indexslices') + class ConcatTensorDictsTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py index 0d67e09f8151b48c97094b6b48f26e63443707ef..f72280c4ecf19e33278ffe74061f44bbb7b21709 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm.py +++ b/tensorflow/contrib/factorization/python/ops/gmm.py @@ -24,7 +24,7 @@ import numpy as np from tensorflow.contrib import framework from tensorflow.contrib.factorization.python.ops import gmm_ops from tensorflow.contrib.framework.python.framework import checkpoint_utils -from 
tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.python.framework import constant_op @@ -167,7 +167,7 @@ class GMM(estimator.Estimator): self._num_clusters, self._random_seed, self._covariance_type, self._params) - incr_step = state_ops.assign_add(variables.get_global_step(), 1) + incr_step = state_ops.assign_add(training_util.get_global_step(), 1) loss = math_ops.reduce_sum(losses) training_op = with_dependencies([training_op, incr_step], loss) training_hooks = [_InitializeClustersHook( diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 3976395d78e9188dd56d5b3b32fa8a3daf43c37d..4fe22ea26ec5f5a43f1c99d1fee518b1d326c5c9 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.factorization.python.ops import factorization_ops -from tensorflow.contrib.framework.python.ops import variables as framework_variables from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.framework import dtypes @@ -32,175 +31,81 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util class _SweepHook(session_run_hook.SessionRunHook): """Keeps track of row/col sweeps, and runs prep ops before each sweep.""" - def __init__(self, is_row_sweep_var, train_ops, num_rows, num_cols, - input_row_indices, 
input_col_indices, row_prep_ops, - col_prep_ops, init_op, completed_sweeps_var): + def __init__(self, is_row_sweep_var, is_sweep_done_var, init_op, + row_prep_ops, col_prep_ops, row_train_op, col_train_op, + switch_op): """Initializes SweepHook. Args: is_row_sweep_var: A Boolean tf.Variable, determines whether we are currently doing a row or column sweep. It is updated by the hook. - train_ops: A list of ops. The ops created by this hook will have - control dependencies on `train_ops`. - num_rows: int, the total number of rows to be processed. - num_cols: int, the total number of columns to be processed. - input_row_indices: A Tensor of type int64. The indices of the input rows - that are processed during the current sweep. All elements of - `input_row_indices` must be in [0, num_rows). - input_col_indices: A Tensor of type int64. The indices of the input - columns that are processed during the current sweep. All elements of - `input_col_indices` must be in [0, num_cols). - row_prep_ops: list of ops, to be run before the beginning of each row - sweep, in the given order. - col_prep_ops: list of ops, to be run before the beginning of each column - sweep, in the given order. + is_sweep_done_var: A Boolean tf.Variable, determines whether we are + starting a new sweep (this is used to determine when to run the prep ops + below). init_op: op to be run once before training. This is typically a local initialization op (such as cache initialization). - completed_sweeps_var: An integer tf.Variable, indicates the number of - completed sweeps. It is updated by the hook. + row_prep_ops: A list of TensorFlow ops, to be run before the beginning of + each row sweep (and during initialization), in the given order. + col_prep_ops: A list of TensorFlow ops, to be run before the beginning of + each column sweep (and during initialization), in the given order. + row_train_op: A TensorFlow op to be run during row sweeps. + col_train_op: A TensorFlow op to be run during column sweeps. 
+ switch_op: A TensorFlow op to be run before each sweep. """ - self._num_rows = num_rows - self._num_cols = num_cols + self._is_row_sweep_var = is_row_sweep_var + self._is_sweep_done_var = is_sweep_done_var + self._init_op = init_op self._row_prep_ops = row_prep_ops self._col_prep_ops = col_prep_ops - self._init_op = init_op - self._is_row_sweep_var = is_row_sweep_var - self._completed_sweeps_var = completed_sweeps_var - # Boolean variable that determines whether the init_ops have been run. + self._row_train_op = row_train_op + self._col_train_op = col_train_op + self._switch_op = switch_op + # Boolean variable that determines whether the init_op has been run. self._is_initialized = False - # Ops to run jointly with train_ops, responsible for updating - # `is_row_sweep_var` and incrementing the `global_step` and - # `completed_sweeps` counters. - self._update_op, self._is_sweep_done_var, self._switch_op = ( - self._create_hook_ops(input_row_indices, input_col_indices, train_ops)) - - def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops): - """Creates ops to update is_row_sweep_var, global_step and completed_sweeps. - - Creates two boolean tensors `processed_rows` and `processed_cols`, which - keep track of which rows/cols have been processed during the current sweep. - Returns ops that should be run after each row / col update. - - When `self._is_row_sweep_var` is True, it sets - processed_rows[input_row_indices] to True. - - When `self._is_row_sweep_var` is False, it sets - processed_cols[input_col_indices] to True. - - Args: - input_row_indices: A Tensor. The indices of the input rows that are - processed during the current sweep. - input_col_indices: A Tensor. The indices of the input columns that - are processed during the current sweep. - train_ops: A list of ops. The ops created by this function have control - dependencies on `train_ops`. - - Returns: - A tuple consisting of: - update_op: An op to be run jointly with training. 
It updates the state - and increments counters (global step and completed sweeps). - is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is - done, i.e. all rows (during a row sweep) or all columns (during a - column sweep) have been processed. - switch_op: An op to be run in `self.before_run` when the sweep is done. - """ - processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False) - with ops.colocate_with(processed_rows_init): - processed_rows = variable_scope.variable( - processed_rows_init, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="sweep_hook_processed_rows") - processed_cols_init = array_ops.fill(dims=[self._num_cols], value=False) - with ops.colocate_with(processed_cols_init): - processed_cols = variable_scope.variable( - processed_cols_init, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="sweep_hook_processed_cols") - switch_ops = control_flow_ops.group( - state_ops.assign( - self._is_row_sweep_var, - math_ops.logical_not(self._is_row_sweep_var)), - state_ops.assign(processed_rows, processed_rows_init), - state_ops.assign(processed_cols, processed_cols_init)) - is_sweep_done_var = variable_scope.variable( - False, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="is_sweep_done") - - # After running the `train_ops`, updates `processed_rows` or - # `processed_cols` tensors, depending on whether this is a row or col sweep. 
- with ops.control_dependencies(train_ops): - with ops.colocate_with(processed_rows): - update_processed_rows = state_ops.scatter_update( - processed_rows, - input_row_indices, - math_ops.logical_and( - self._is_row_sweep_var, - array_ops.ones_like(input_row_indices, dtype=dtypes.bool))) - with ops.colocate_with(processed_cols): - update_processed_cols = state_ops.scatter_update( - processed_cols, - input_col_indices, - math_ops.logical_and( - math_ops.logical_not(self._is_row_sweep_var), - array_ops.ones_like(input_col_indices, dtype=dtypes.bool))) - update_processed_op = control_flow_ops.group( - update_processed_rows, update_processed_cols) - - with ops.control_dependencies([update_processed_op]): - is_sweep_done = math_ops.logical_or( - math_ops.reduce_all(processed_rows), - math_ops.reduce_all(processed_cols)) - # Increments global step. - global_step = framework_variables.get_global_step() - if global_step is not None: - global_step_incr_op = state_ops.assign_add( - global_step, 1, name="global_step_incr").op - else: - global_step_incr_op = control_flow_ops.no_op() - # Increments completed sweeps. - completed_sweeps_incr_op = state_ops.assign_add( - self._completed_sweeps_var, - math_ops.cast(is_sweep_done, dtypes.int32), - use_locking=True).op - update_ops = control_flow_ops.group( - global_step_incr_op, - completed_sweeps_incr_op, - state_ops.assign(is_sweep_done_var, is_sweep_done)) - - return update_ops, is_sweep_done_var, switch_ops def before_run(self, run_context): """Runs the appropriate prep ops, and requests running update ops.""" - # Runs the appropriate init ops and prep ops. 
sess = run_context.session is_sweep_done = sess.run(self._is_sweep_done_var) if not self._is_initialized: - logging.info("SweepHook running cache init op.") + logging.info("SweepHook running init op.") sess.run(self._init_op) if is_sweep_done: + logging.info("SweepHook starting the next sweep.") sess.run(self._switch_op) + is_row_sweep = sess.run(self._is_row_sweep_var) if is_sweep_done or not self._is_initialized: - logging.info("SweepHook running sweep prep ops.") - row_sweep = sess.run(self._is_row_sweep_var) - prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops + logging.info("SweepHook running prep ops for the {} sweep.".format( + "row" if is_row_sweep else "col")) + prep_ops = self._row_prep_ops if is_row_sweep else self._col_prep_ops for prep_op in prep_ops: sess.run(prep_op) - self._is_initialized = True - - # Requests running `self._update_op` jointly with the training op. logging.info("Next fit step starting.") - return session_run_hook.SessionRunArgs(fetches=[self._update_op]) + return session_run_hook.SessionRunArgs( + fetches=[self._row_train_op if is_row_sweep else self._col_train_op]) - def after_run(self, run_context, run_values): - logging.info("Fit step done.") + +class _IncrementGlobalStepHook(session_run_hook.SessionRunHook): + """Hook that increments the global step (if one exists) before each run.""" + + def __init__(self): + global_step = training_util.get_global_step() + if global_step is not None: + self._global_step_incr_op = state_ops.assign_add( + global_step, 1, name="global_step_incr").op + else: + self._global_step_incr_op = None + + def before_run(self, run_context): + if self._global_step_incr_op is not None: + run_context.session.run(self._global_step_incr_op) class _StopAtSweepHook(session_run_hook.SessionRunHook): @@ -246,6 +151,9 @@ def _wals_factorization_model_function(features, labels, mode, params): Returns: A ModelFnOps object. + + Raises: + ValueError: If `mode` is not recognized. 
""" assert labels is None use_factors_weights_cache = (params["use_factors_weights_cache_for_training"] @@ -269,86 +177,145 @@ def _wals_factorization_model_function(features, labels, mode, params): use_gramian_cache=use_gramian_cache) # Get input rows and cols. We either update rows or columns depending on - # the value of row_sweep, which is maintained using a session hook + # the value of row_sweep, which is maintained using a session hook. input_rows = features[WALSMatrixFactorization.INPUT_ROWS] input_cols = features[WALSMatrixFactorization.INPUT_COLS] - input_row_indices, _ = array_ops.unique(input_rows.indices[:, 0]) - input_col_indices, _ = array_ops.unique(input_cols.indices[:, 0]) - - # Train ops, controlled using the SweepHook - # We need to run the following ops: - # Before a row sweep: - # row_update_prep_gramian_op - # initialize_row_update_op - # During a row sweep: - # update_row_factors_op - # Before a col sweep: - # col_update_prep_gramian_op - # initialize_col_update_op - # During a col sweep: - # update_col_factors_op - - is_row_sweep_var = variable_scope.variable( - True, - trainable=False, - name="is_row_sweep", - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - completed_sweeps_var = variable_scope.variable( - 0, - trainable=False, - name=WALSMatrixFactorization.COMPLETED_SWEEPS, - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - - # The row sweep is determined by is_row_sweep_var (controlled by the - # sweep_hook) in TRAIN mode, and manually in EVAL mode. 
- is_row_sweep = (features[WALSMatrixFactorization.PROJECT_ROW] - if mode == model_fn.ModeKeys.EVAL else is_row_sweep_var) - - def update_row_factors(): - return model.update_row_factors(sp_input=input_rows, transpose_input=False) - - def update_col_factors(): - return model.update_col_factors(sp_input=input_cols, transpose_input=True) - - (_, train_op, - unregularized_loss, regularization, sum_weights) = control_flow_ops.cond( - is_row_sweep, update_row_factors, update_col_factors) - loss = unregularized_loss + regularization - root_weighted_squared_error = math_ops.sqrt(unregularized_loss / sum_weights) - - row_prep_ops = [ - model.row_update_prep_gramian_op, model.initialize_row_update_op - ] - col_prep_ops = [ - model.col_update_prep_gramian_op, model.initialize_col_update_op - ] - init_ops = [model.worker_init] - - sweep_hook = _SweepHook( - is_row_sweep_var, - [train_op, loss], - params["num_rows"], - params["num_cols"], - input_row_indices, - input_col_indices, - row_prep_ops, - col_prep_ops, - init_ops, - completed_sweeps_var) - training_hooks = [sweep_hook] - if max_sweeps is not None: - training_hooks.append(_StopAtSweepHook(max_sweeps)) - - # The root weighted squared error = - # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij ) - summary.scalar("loss", loss) # the estimated total training loss - summary.scalar("root_weighted_squared_error", root_weighted_squared_error) - summary.scalar("completed_sweeps", completed_sweeps_var) - - # Prediction ops (only return predictions in INFER mode) - predictions = {} - if mode == model_fn.ModeKeys.INFER: - project_row = features[WALSMatrixFactorization.PROJECT_ROW] + + # TRAIN mode: + if mode == model_fn.ModeKeys.TRAIN: + # Training consists of the following ops (controlled using a SweepHook). 
+ # Before a row sweep: + # row_update_prep_gramian_op + # initialize_row_update_op + # During a row sweep: + # update_row_factors_op + # Before a col sweep: + # col_update_prep_gramian_op + # initialize_col_update_op + # During a col sweep: + # update_col_factors_op + + is_row_sweep_var = variable_scope.variable( + True, + trainable=False, + name="is_row_sweep", + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + is_sweep_done_var = variable_scope.variable( + False, + trainable=False, + name="is_sweep_done", + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + completed_sweeps_var = variable_scope.variable( + 0, + trainable=False, + name=WALSMatrixFactorization.COMPLETED_SWEEPS, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + loss_var = variable_scope.variable( + 0., + trainable=False, + name=WALSMatrixFactorization.LOSS, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + # The root weighted squared error = + # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij ) + rwse_var = variable_scope.variable( + 0., + trainable=False, + name=WALSMatrixFactorization.RWSE, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + + summary.scalar("loss", loss_var) + summary.scalar("root_weighted_squared_error", rwse_var) + summary.scalar("completed_sweeps", completed_sweeps_var) + + def create_axis_ops(sp_input, num_items, update_fn, axis_name): + """Creates book-keeping and training ops for a given axis. + + Args: + sp_input: A SparseTensor corresponding to the row or column batch. + num_items: An integer, the total number of items of this axis. + update_fn: A function that takes one argument (`sp_input`), and that + returns a tuple of + * new_factors: A float Tensor of the factor values after update. + * update_op: a TensorFlow op which updates the factors. + * loss: A float Tensor, the unregularized loss. + * reg_loss: A float Tensor, the regularization loss. + * sum_weights: A float Tensor, the sum of factor weights. + axis_name: A string that specifies the name of the axis. 
+ + Returns: + A tuple consisting of: + * reset_processed_items_op: A TensorFlow op, to be run before the + beginning of any sweep. It marks all items as not-processed. + * axis_train_op: A Tensorflow op, to be run during this axis' sweeps. + """ + processed_items_init = array_ops.fill(dims=[num_items], value=False) + with ops.colocate_with(processed_items_init): + processed_items = variable_scope.variable( + processed_items_init, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False, + name="processed_" + axis_name) + _, update_op, loss, reg, sum_weights = update_fn(sp_input) + input_indices = sp_input.indices[:, 0] + with ops.control_dependencies([ + update_op, + state_ops.assign(loss_var, loss + reg), + state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))]): + with ops.colocate_with(processed_items): + update_processed_items = state_ops.scatter_update( + processed_items, + input_indices, + array_ops.ones_like(input_indices, dtype=dtypes.bool), + name="update_processed_{}_indices".format(axis_name)) + with ops.control_dependencies([update_processed_items]): + is_sweep_done = math_ops.reduce_all(processed_items) + axis_train_op = control_flow_ops.group( + state_ops.assign(is_sweep_done_var, is_sweep_done), + state_ops.assign_add( + completed_sweeps_var, + math_ops.cast(is_sweep_done, dtypes.int32)), + name="{}_sweep_train_op".format(axis_name)) + return processed_items.initializer, axis_train_op + + reset_processed_rows_op, row_train_op = create_axis_ops( + input_rows, + params["num_rows"], + lambda x: model.update_row_factors(sp_input=x, transpose_input=False), + "rows") + reset_processed_cols_op, col_train_op = create_axis_ops( + input_cols, + params["num_cols"], + lambda x: model.update_col_factors(sp_input=x, transpose_input=True), + "cols") + switch_op = control_flow_ops.group( + state_ops.assign( + is_row_sweep_var, math_ops.logical_not(is_row_sweep_var)), + reset_processed_rows_op, + reset_processed_cols_op, + name="sweep_switch_op") + 
row_prep_ops = [ + model.row_update_prep_gramian_op, model.initialize_row_update_op] + col_prep_ops = [ + model.col_update_prep_gramian_op, model.initialize_col_update_op] + init_op = model.worker_init + sweep_hook = _SweepHook( + is_row_sweep_var, is_sweep_done_var, init_op, + row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op) + global_step_hook = _IncrementGlobalStepHook() + training_hooks = [sweep_hook, global_step_hook] + if max_sweeps is not None: + training_hooks.append(_StopAtSweepHook(max_sweeps)) + + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.TRAIN, + predictions={}, + loss=loss_var, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=training_hooks) + + # INFER mode + elif mode == model_fn.ModeKeys.INFER: projection_weights = features.get( WALSMatrixFactorization.PROJECTION_WEIGHTS) @@ -364,17 +331,45 @@ def _wals_factorization_model_function(features, labels, mode, params): projection_weights=projection_weights, transpose_input=True) - predictions[WALSMatrixFactorization.PROJECTION_RESULT] = ( - control_flow_ops.cond(project_row, get_row_projection, - get_col_projection)) + predictions = { + WALSMatrixFactorization.PROJECTION_RESULT: control_flow_ops.cond( + features[WALSMatrixFactorization.PROJECT_ROW], + get_row_projection, + get_col_projection) + } - return model_fn.ModelFnOps( - mode=mode, - predictions=predictions, - loss=loss, - eval_metric_ops={}, - train_op=train_op, - training_hooks=training_hooks) + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.INFER, + predictions=predictions, + loss=None, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=[]) + + # EVAL mode + elif mode == model_fn.ModeKeys.EVAL: + def get_row_loss(): + _, _, loss, reg, _ = model.update_row_factors( + sp_input=input_rows, transpose_input=False) + return loss + reg + def get_col_loss(): + _, _, loss, reg, _ = model.update_col_factors( + sp_input=input_cols, transpose_input=True) + return loss + 
reg + loss = control_flow_ops.cond( + features[WALSMatrixFactorization.PROJECT_ROW], + get_row_loss, + get_col_loss) + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.EVAL, + predictions={}, + loss=loss, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=[]) + + else: + raise ValueError("mode=%s is not recognized." % str(mode)) class WALSMatrixFactorization(estimator.Estimator): @@ -452,6 +447,10 @@ class WALSMatrixFactorization(estimator.Estimator): PROJECTION_RESULT = "projection" # Name of the completed_sweeps variable COMPLETED_SWEEPS = "completed_sweeps" + # Name of the loss variable + LOSS = "WALS_loss" + # Name of the Root Weighted Squared Error variable + RWSE = "WALS_RWSE" def __init__(self, num_rows, diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index 8bd72b7025aad80e387171b93b9b264da3ed0f66..36b483c6d7a59bba78b7fa22aac0714e278f22cc 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -417,73 +417,67 @@ class WALSMatrixFactorizationUnsupportedTest(test.TestCase): class SweepHookTest(test.TestCase): - def setUp(self): - self._num_rows = 5 - self._num_cols = 7 - self._train_op = control_flow_ops.no_op() - self._row_prep_done = variables.Variable(False) - self._col_prep_done = variables.Variable(False) - self._init_done = variables.Variable(False) - self._row_prep_ops = [state_ops.assign(self._row_prep_done, True)] - self._col_prep_ops = [state_ops.assign(self._col_prep_done, True)] - self._init_ops = [state_ops.assign(self._init_done, True)] - self._input_row_indices_ph = array_ops.placeholder(dtypes.int64) - self._input_col_indices_ph = array_ops.placeholder(dtypes.int64) - def test_sweeps(self): - def ind_feed(row_indices, col_indices): - return { - self._input_row_indices_ph: row_indices, - self._input_col_indices_ph: col_indices - } + is_row_sweep_var = 
variables.Variable(True) + is_sweep_done_var = variables.Variable(False) + init_done = variables.Variable(False) + row_prep_done = variables.Variable(False) + col_prep_done = variables.Variable(False) + row_train_done = variables.Variable(False) + col_train_done = variables.Variable(False) + + init_op = state_ops.assign(init_done, True) + row_prep_op = state_ops.assign(row_prep_done, True) + col_prep_op = state_ops.assign(col_prep_done, True) + row_train_op = state_ops.assign(row_train_done, True) + col_train_op = state_ops.assign(col_train_done, True) + train_op = control_flow_ops.no_op() + switch_op = control_flow_ops.group( + state_ops.assign(is_sweep_done_var, False), + state_ops.assign(is_row_sweep_var, + math_ops.logical_not(is_row_sweep_var))) + mark_sweep_done = state_ops.assign(is_sweep_done_var, True) with self.test_session() as sess: - is_row_sweep_var = variables.Variable(True) - completed_sweeps_var = variables.Variable(0) sweep_hook = wals_lib._SweepHook( is_row_sweep_var, - [self._train_op], - self._num_rows, - self._num_cols, - self._input_row_indices_ph, - self._input_col_indices_ph, - self._row_prep_ops, - self._col_prep_ops, - self._init_ops, - completed_sweeps_var) + is_sweep_done_var, + init_op, + [row_prep_op], + [col_prep_op], + row_train_op, + col_train_op, + switch_op) mon_sess = monitored_session._HookedSession(sess, [sweep_hook]) sess.run([variables.global_variables_initializer()]) - # Init ops should run before the first run. Row sweep not completed. - mon_sess.run(self._train_op, ind_feed([0, 1, 2], [])) - self.assertTrue(sess.run(self._init_done), - msg='init ops not run by the sweep_hook') - self.assertTrue(sess.run(self._row_prep_done), - msg='row_prep not run by the sweep_hook') - self.assertTrue(sess.run(is_row_sweep_var), - msg='Row sweep is not complete but is_row_sweep is ' - 'False.') - # Row sweep completed. 
- mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6])) - self.assertTrue(sess.run(completed_sweeps_var) == 1, - msg='Completed sweeps should be equal to 1.') - self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is complete but is_sweep_done is False.') - # Col init ops should run. Col sweep not completed. - mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4])) - self.assertTrue(sess.run(self._col_prep_done), - msg='col_prep not run by the sweep_hook') - self.assertFalse(sess.run(is_row_sweep_var), - msg='Col sweep is not complete but is_row_sweep is ' - 'True.') - self.assertFalse(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is not complete but is_sweep_done is True.') - # Col sweep completed. - mon_sess.run(self._train_op, ind_feed([], [4, 5, 6])) - self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is complete but is_sweep_done is False.') - self.assertTrue(sess.run(completed_sweeps_var) == 2, - msg='Completed sweeps should be equal to 2.') + # Row sweep. + mon_sess.run(train_op) + self.assertTrue(sess.run(init_done), + msg='init op not run by the Sweephook') + self.assertTrue(sess.run(row_prep_done), + msg='row_prep_op not run by the SweepHook') + self.assertTrue(sess.run(row_train_done), + msg='row_train_op not run by the SweepHook') + self.assertTrue( + sess.run(is_row_sweep_var), + msg='Row sweep is not complete but is_row_sweep_var is False.') + # Col sweep. + mon_sess.run(mark_sweep_done) + mon_sess.run(train_op) + self.assertTrue(sess.run(col_prep_done), + msg='col_prep_op not run by the SweepHook') + self.assertTrue(sess.run(col_train_done), + msg='col_train_op not run by the SweepHook') + self.assertFalse( + sess.run(is_row_sweep_var), + msg='Col sweep is not complete but is_row_sweep_var is True.') + # Row sweep. 
+ mon_sess.run(mark_sweep_done) + mon_sess.run(train_op) + self.assertTrue( + sess.run(is_row_sweep_var), + msg='Col sweep is complete but is_row_sweep_var is False.') class StopAtSweepHookTest(test.TestCase): diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index 7a5a4cb8c9499b950a3ad89be710e48474d5791e..eccce99071dc1477cf4f3bb152f3304b3b0fc35a 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -47,10 +47,25 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "decode_video_op_cc", + srcs = ["decode_video_op.cc"], + copts = tf_copts(), + linkstatic = 1, + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/contrib/ffmpeg/default:ffmpeg_lib", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + ], + alwayslink = 1, +) + tf_custom_op_library( name = "ffmpeg.so", deps = [ ":decode_audio_op_cc", + ":decode_video_op_cc", ":encode_audio_op_cc", ], ) @@ -59,6 +74,7 @@ cc_library( name = "ffmpeg_op_lib", deps = [ ":decode_audio_op_cc", + ":decode_video_op_cc", ":encode_audio_op_cc", ], ) @@ -81,6 +97,15 @@ tf_gen_op_wrapper_py( ], ) +tf_gen_op_wrapper_py( + name = "decode_video_op_py", + require_shape_functions = True, + visibility = ["//visibility:private"], + deps = [ + ":decode_video_op_cc", + ], +) + tf_py_test( name = "decode_audio_op_test", srcs = ["decode_audio_op_test.py"], @@ -115,6 +140,27 @@ tf_py_test( tags = ["manual"], ) +tf_py_test( + name = "decode_video_op_test", + size = "small", + srcs = ["decode_video_op_test.py"], + additional_deps = [ + ":ffmpeg_ops_py", + "@six_archive//:six", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform", + "//tensorflow/python:image_ops", + ], + data = [ + ":test_data", + ], + tags = [ + "manual", + "notap", + ], +) + py_library( name = "ffmpeg_ops_py", srcs = [ @@ -126,6 +172,7 @@ py_library( visibility = ["//visibility:public"], deps = [ ":decode_audio_op_py", + 
":decode_video_op_py", ":encode_audio_op_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index 2bcb7284e10991b19ee5607147371e8d505c7732..daba965a98893b992abdc598ec713f13020d6e91 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -26,9 +26,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio +from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio +from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['decode_audio', 'encode_audio'] +_allowed_symbols = ['decode_audio', 'encode_audio', 'decode_video'] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op.cc b/tensorflow/contrib/ffmpeg/decode_audio_op.cc index 4b1c8a337e10c7025ca06e2ed6e1b934716dc1d0..92fad70b1f9cc55e0690a3fbb35abcf56aa68f16 100644 --- a/tensorflow/contrib/ffmpeg/decode_audio_op.cc +++ b/tensorflow/contrib/ffmpeg/decode_audio_op.cc @@ -37,29 +37,6 @@ namespace { // https://www.ffmpeg.org/ffmpeg-formats.html const char* kValidFileFormats[] = {"mp3", "mp4", "ogg", "wav"}; -// Writes binary data to a file. -Status WriteFile(const string& filename, tensorflow::StringPiece contents) { - Env& env = *Env::Default(); - std::unique_ptr file; - TF_RETURN_IF_ERROR(env.NewWritableFile(filename, &file)); - TF_RETURN_IF_ERROR(file->Append(contents)); - TF_RETURN_IF_ERROR(file->Close()); - return Status::OK(); -} - -// Cleans up a file on destruction. 
-class FileDeleter { - public: - explicit FileDeleter(const string& filename) : filename_(filename) {} - ~FileDeleter() { - Env& env = *Env::Default(); - env.DeleteFile(filename_).IgnoreError(); - } - - private: - const string filename_; -}; - /* * Decoding implementation, shared across V1 and V2 ops. Creates a new * output in the context. @@ -69,7 +46,7 @@ void Decode(OpKernelContext* context, const string& file_format, const int32 samples_per_second, const int32 channel_count) { // Write the input data to a temp file. - const string temp_filename = GetTempFilename(file_format); + const string temp_filename = io::GetTempFilename(file_format); OP_REQUIRES_OK(context, WriteFile(temp_filename, file_contents)); FileDeleter deleter(temp_filename); diff --git a/tensorflow/contrib/ffmpeg/decode_video_op.cc b/tensorflow/contrib/ffmpeg/decode_video_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d44032968d559bec14722902a4d47d22c46ea4aa --- /dev/null +++ b/tensorflow/contrib/ffmpeg/decode_video_op.cc @@ -0,0 +1,118 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include + +#include +#include + +#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace ffmpeg { + +class DecodeVideoOp : public OpKernel { + public: + explicit DecodeVideoOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + OP_REQUIRES( + context, context->num_inputs() == 1, + errors::InvalidArgument("DecodeVideo requires exactly 1 input.")); + const Tensor& contents_tensor = context->input(0); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents_tensor.shape()), + errors::InvalidArgument( + "contents must be a rank-0 tensor but got shape ", + contents_tensor.shape().DebugString())); + const tensorflow::StringPiece contents = contents_tensor.scalar()(); + + // Write the input data to a temp file. + string extension; + const string temp_filename = io::GetTempFilename(extension); + OP_REQUIRES_OK(context, WriteFile(temp_filename, contents)); + FileDeleter deleter(temp_filename); + + uint32 width = 0; + uint32 height = 0; + uint32 frames = 0; + + // Run FFmpeg on the data and verify results. + std::vector output_data; + const Status result = ffmpeg::ReadVideoFile(temp_filename, &output_data, + &width, &height, &frames); + if (result.code() == error::Code::NOT_FOUND) { + OP_REQUIRES( + context, result.ok(), + errors::Unavailable("FFmpeg must be installed to run this op. 
FFmpeg " + "can be found at http://www.ffmpeg.org.")); + } else if (result.code() == error::UNKNOWN) { + LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message() + << "'. Returning empty tensor."; + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({0, 0}), &output)); + return; + } else { + OP_REQUIRES_OK(context, result); + } + OP_REQUIRES(context, !output_data.empty(), + errors::Unknown("No output created by FFmpeg.")); + OP_REQUIRES( + context, output_data.size() == (frames * height * width * 3), + errors::Unknown("Output created by FFmpeg [", output_data.size(), + "] does not match description [", frames, ", ", height, + ", ", width, ", 3]")); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, TensorShape({frames, height, width, 3}), &output)); + auto output_flat = output->flat(); + std::copy_n(output_data.begin(), output_data.size(), &output_flat(0)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("DecodeVideo").Device(DEVICE_CPU), DecodeVideoOp); + +REGISTER_OP("DecodeVideo") + .Input("contents: string") + .Output("output: uint8") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->UnknownShapeOfRank(4)); + return Status::OK(); + }) + .Doc(R"doc( +Processes the contents of a video file into a tensor using FFmpeg to decode +the file. + +The output is a rank-4 `uint8` tensor of shape `[frames, height, width, 3]`, +with one entry per decoded frame and three values per pixel. If FFmpeg fails +with an unknown error while decoding the contents, the error is logged and an +empty tensor is returned instead. + +contents: The binary video file contents, as a string or rank-0 string + tensor. 
+)doc"); + +} // namespace ffmpeg +} // namespace tensorflow diff --git a/tensorflow/contrib/ffmpeg/decode_video_op_test.py b/tensorflow/contrib/ffmpeg/decode_video_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b43b6b8919223bd7731209d5423b142601396ea5 --- /dev/null +++ b/tensorflow/contrib/ffmpeg/decode_video_op_test.py @@ -0,0 +1,69 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for third_party.tensorflow.contrib.ffmpeg.decode_video_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path + +import six # pylint: disable=unused-import + +from tensorflow.contrib import ffmpeg +from tensorflow.python.ops import image_ops +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test + + +class DecodeVideoOpTest(test.TestCase): + + def _loadFileAndTest(self, filename, width, height, frames, bmp_filename, + index): + """Loads a video file and validates the output tensor. + + Args: + filename: The filename of the input file. + width: The width of the video. + height: The height of the video. + frames: The number of frames in the video. + bmp_filename: The filename of the bmp file holding the expected frame. 
+ index: Index of the frame inside the video to compare with the bmp. 
+ """ + with self.test_session(): + path = os.path.join(resource_loader.get_data_files_path(), 'testdata', + filename) + with open(path, 'rb') as f: + contents = f.read() + + bmp_path = os.path.join(resource_loader.get_data_files_path(), 'testdata', + bmp_filename) + with open(bmp_path, 'rb') as f: + bmp_contents = f.read() + + image_op = image_ops.decode_bmp(bmp_contents) + image = image_op.eval() + self.assertEqual(image.shape, (height, width, 3)) + video_op = ffmpeg.decode_video(contents) + video = video_op.eval() + self.assertEqual(video.shape, (frames, height, width, 3)) + self.assertAllEqual(video[index, :, :, :], image) + + def testMp4(self): + self._loadFileAndTest('small.mp4', 560, 320, 166, 'small_100.bmp', 99) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index 545a4386d043af604a747b8b5a8103101812b177..1e8af1458cea13b2ddb89b7d93a4ffb8b974ecd2 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -16,6 +16,7 @@ #include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h" #include +#include #include #include #include @@ -25,6 +26,7 @@ #include #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" @@ -38,28 +40,47 @@ namespace { const char kFfmpegExecutable[] = "ffmpeg"; const int32 kDefaultProbeSize = 5000000; // 5MB -std::vector FfmpegCommandLine(const string& input_filename, - const string& output_filename, - const string& input_format_id, - int32 samples_per_second, - int32 channel_count) { - return { - "-nostats", // No additional progress display. - "-nostdin", // No interactive commands accepted. 
- "-f", input_format_id, // eg: "mp3" - "-probesize", StrCat(kDefaultProbeSize), - "-i", input_filename, - "-loglevel", "info", // Enable verbose logging to support debugging. - "-map_metadata", "-1", // Copy global metadata from input to output. - "-vn", // No video recording. - "-ac:a:0", StrCat(channel_count), - "-ar:a:0", StrCat(samples_per_second), - // Output set (in several ways) to signed 16-bit little-endian ints. - "-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le", - "-sn", // No subtitle recording. - "-y", // Overwrite output file. - StrCat(output_filename) - }; +std::vector FfmpegAudioCommandLine(const string& input_filename, + const string& output_filename, + const string& input_format_id, + int32 samples_per_second, + int32 channel_count) { + return {"-nostats", // No additional progress display. + "-nostdin", // No interactive commands accepted. + "-f", input_format_id, // eg: "mp3" + "-probesize", StrCat(kDefaultProbeSize), "-i", input_filename, + "-loglevel", "error", // Print errors only. + "-hide_banner", // Skip printing build options, version, etc. + "-map_metadata", "-1", // Copy global metadata from input to output. + "-vn", // No video recording. + "-ac:a:0", StrCat(channel_count), "-ar:a:0", + StrCat(samples_per_second), + // Output set (in several ways) to signed 16-bit little-endian ints. + "-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le", + "-sn", // No subtitle recording. + "-y", // Overwrite output file. + StrCat(output_filename)}; +} + +std::vector FfmpegVideoCommandLine(const string& input_filename, + const string& output_filename) { + return {"-nostats", // No additional progress display. + "-nostdin", // No interactive commands accepted. + "-i", + input_filename, + "-f", + "image2pipe", + "-probesize", + StrCat(kDefaultProbeSize), + "-loglevel", + "error", // Print errors only. + "-hide_banner", // Skip printing build options, version, etc. 
+ "-vcodec", + "rawvideo", + "-pix_fmt", + "rgb24", + "-y", // Overwrite output file. + StrCat(output_filename)}; } // Is a named binary installed and executable by the current process? @@ -106,7 +127,7 @@ bool IsBinaryInstalled(const string& binary_name) { ::execvp(kFfmpegExecutable, args_chars.data()); // exec only returns on error. const int error = errno; - LOG(ERROR) << "FFmpeg could not be executed: " << error; + LOG(ERROR) << "FFmpeg could not be executed: " << strerror(error); ::_exit(error); } @@ -198,52 +219,101 @@ string BuildWavFile(int32 samples_per_second, int32 channel_count, return data; } -// Returns a unique number every time it is called. -int64 UniqueId() { - static mutex mu(LINKER_INITIALIZED); - static int64 id = 0; - mutex_lock l(mu); - return ++id; -} - -} // namespace - -string GetTempFilename(const string& extension) { - for (const char* dir : std::vector( - {getenv("TEST_TMPDIR"), getenv("TMPDIR"), getenv("TMP"), "/tmp"})) { - if (!dir || !dir[0]) { +Status ReadInfoFile(const string& filename, uint32* width, uint32* height, + uint32* frames) { + string data; + TF_QCHECK_OK(ReadFileToString(Env::Default(), filename, &data)) + << "Could not read FFmpeg file: " << filename; + bool in_output = false; + bool in_mapping = false; + uint32 frames_value = 0; + uint32 height_value = 0; + uint32 width_value = 0; + for (const string& line : str_util::Split(data, '\n')) { + // Output starts with the first line of `Output #..`. + // Further processing output region starts next line so we could continue + // the loop. + if (!in_output && line.find("Output #") == 0) { + in_output = true; + in_mapping = false; continue; } - struct stat statbuf; - if (!stat(dir, &statbuf) && S_ISDIR(statbuf.st_mode)) { - // UniqueId is added here because mkstemps is not as thread safe as it - // looks. https://github.com/tensorflow/tensorflow/issues/5804 shows - // the problem. 
- string tmp_filepath = io::JoinPath( - dir, - StrCat("tmp_file_tensorflow_", UniqueId(), "_XXXXXX.", extension)); - int fd = mkstemps(&tmp_filepath[0], extension.length() + 1); - if (fd < 0) { - LOG(FATAL) << "Failed to create temp file."; - } else { - close(fd); - return tmp_filepath; + // Stream mapping starts with the first line of `Stream mapping`, it also + // signals the end of Output section. + // Further processing of stream mapping region starts next line so we could + // continue the loop. + if (!in_mapping && line.find("Stream mapping:") == 0) { + in_output = false; + in_mapping = true; + continue; + } + if (in_output) { + // We only look for the first stream in output `Stream #0`. + // Once processed we will not further process output section. + if (line.find(" Stream #") == 0) { + size_t p = line.find(", rgb24, ", 24); + if (p != std::string::npos) { + string rgb24 = line.substr(p + 9, line.find(" ", p + 9)); + rgb24 = rgb24.substr(0, rgb24.find(",")); + string rgb24_width = rgb24.substr(0, rgb24.find("x")); + string rgb24_height = rgb24.substr(rgb24_width.length() + 1); + if (strings::safe_strtou32(rgb24_width, &width_value) && + strings::safe_strtou32(rgb24_height, &height_value)) { + in_output = false; + } + } + } + continue; + } + if (in_mapping) { + // We only look for the first stream mapping to have the number of the + // frames. + // Once processed we will not further process stream mapping section. 
+ if (line.find("frame= ") == 0) { + string number = line.substr(8, line.find(" ", 8)); + number = number.substr(0, number.find(" ")); + if (strings::safe_strtou32(number, &frames_value)) { + in_mapping = false; + } } + continue; } } - LOG(FATAL) << "No temp directory found."; + if (frames_value == 0 || height_value == 0 || width_value == 0) { + return errors::Unknown("Not enough video info returned by FFmpeg [", + frames_value, ", ", height_value, ", ", width_value, + ", 3]"); + } + *width = width_value; + *height = height_value; + *frames = frames_value; + return Status::OK(); } -Status ReadAudioFile(const string& filename, - const string& audio_format_id, - int32 samples_per_second, - int32 channel_count, +} // namespace + +FileDeleter::~FileDeleter() { + Env& env = *Env::Default(); + env.DeleteFile(filename_).IgnoreError(); +} + +Status WriteFile(const string& filename, StringPiece contents) { + Env& env = *Env::Default(); + std::unique_ptr file; + TF_RETURN_IF_ERROR(env.NewWritableFile(filename, &file)); + TF_RETURN_IF_ERROR(file->Append(contents)); + TF_RETURN_IF_ERROR(file->Close()); + return Status::OK(); +} + +Status ReadAudioFile(const string& filename, const string& audio_format_id, + int32 samples_per_second, int32 channel_count, std::vector* output_samples) { // Create an argument list. - string output_filename = GetTempFilename("raw"); + string output_filename = io::GetTempFilename("raw"); const std::vector args = - FfmpegCommandLine(filename, output_filename, audio_format_id, - samples_per_second, channel_count); + FfmpegAudioCommandLine(filename, output_filename, audio_format_id, + samples_per_second, channel_count); // Unfortunately, it's impossible to differentiate an exec failure due to the // binary being missing and an error from the binary's execution. Therefore, @@ -256,7 +326,8 @@ Status ReadAudioFile(const string& filename, // Execute ffmpeg and report errors. 
pid_t child_pid = ::fork(); if (child_pid < 0) { - return Status(error::Code::UNKNOWN, StrCat("fork failed: ", errno)); + return Status(error::Code::UNKNOWN, + StrCat("fork failed: ", strerror(errno))); } if (child_pid == 0) { ExecuteFfmpeg(args); @@ -285,5 +356,63 @@ Status CreateAudioFile(const string& audio_format_id, int32 bits_per_second, return Status::OK(); } +Status ReadVideoFile(const string& filename, std::vector* output_data, + uint32* width, uint32* height, uint32* frames) { + if (!IsBinaryInstalled(kFfmpegExecutable)) { + return Status(error::Code::NOT_FOUND, StrCat("FFmpeg could not be found.")); + } + + string output_filename = io::GetTempFilename("raw"); + string stderr_filename = io::GetTempFilename("err"); + + // Create an argument list. + const std::vector args = + FfmpegVideoCommandLine(filename, output_filename); + + // Execute ffmpeg and report errors. + pid_t child_pid = ::fork(); + if (child_pid < 0) { + return Status(error::Code::UNKNOWN, + StrCat("fork failed: ", strerror(errno))); + } + if (child_pid == 0) { + const int fd = + open(stderr_filename.c_str(), O_RDWR | O_CREAT | O_APPEND, 0600); + if (fd < 0) { + const int error = errno; + LOG(ERROR) << "FFmpeg stderr file could not be created: " + << strerror(error); + ::_exit(error); + } + close(STDERR_FILENO); + dup2(fd, STDERR_FILENO); + ExecuteFfmpeg(args); + } else { + int status_code; + if (::waitpid(child_pid, &status_code, 0) < 0) { + return Status(error::Code::UNKNOWN, + StrCat("waitpid failed: ", strerror(errno))); + } + if (status_code) { + return Status(error::Code::UNKNOWN, + StrCat("FFmpeg execution failed: ", status_code)); + } + + TF_QCHECK_OK(ReadInfoFile(stderr_filename, width, height, frames)) + << "Could not read FFmpeg stderr file: " << stderr_filename; + + string raw_data; + TF_QCHECK_OK(ReadFileToString(Env::Default(), output_filename, &raw_data)) + << "Could not read FFmpeg output file: " << output_filename; + output_data->resize(raw_data.size()); + 
std::copy_n(raw_data.data(), raw_data.size(), output_data->begin()); + + TF_QCHECK_OK(Env::Default()->DeleteFile(output_filename)) + << output_filename; + TF_QCHECK_OK(Env::Default()->DeleteFile(stderr_filename)) + << stderr_filename; + return Status::OK(); + } +} } // namespace ffmpeg } // namespace tensorflow diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc index 2871c1462894c6a4ddef63e9178272df0d14824c..85b61b26163d87a10d4e316720b4f633e038bbec 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc @@ -39,7 +39,7 @@ const char kTestMp3Filename[] = // Set to true via a command line flag iff the test is expected to have FFmpeg // installed. -mutex mu; +mutex mu(LINKER_INITIALIZED); bool should_ffmpeg_be_installed GUARDED_BY(mu) = false; string ParseTestFlags(int* argc, char** argv) { diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc index 7176f3b550679555d5ab3b70f2b360a90eaee253..36fc71794b06e0f3cb86c40b325ce50e8999c667 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc @@ -20,7 +20,10 @@ #include #include + +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/test.h" @@ -49,7 +52,7 @@ TEST(FfmpegLibTest, TestTempDirectoryThreading) { pool.Schedule([&mu, &temp_filenames, environment]() { std::array buffer; for (int32 j = 0; j < kStringsPerItem; ++j) { - buffer[j] = GetTempFilename("mp3"); + buffer[j] = io::GetTempFilename("mp3"); TF_QCHECK_OK(environment->DeleteFile(buffer[j])); } mutex_lock l(mu); diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h 
b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h index f64007c81d74276d42c9d6ebd7c8f46cda6b7d72..c5ea1432bf8b61c87615074a93a45325371c4c87 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h +++ b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h @@ -24,16 +24,24 @@ namespace tensorflow { namespace ffmpeg { -// Gets a temp filename in an appropriate location. -string GetTempFilename(const string& extension); +// Cleans up a file on destruction. +class FileDeleter { + public: + explicit FileDeleter(const string& filename) : filename_(filename) {} + ~FileDeleter(); + + private: + const string filename_; +}; + +// Writes binary data to a file. +Status WriteFile(const string& filename, tensorflow::StringPiece contents); // Reads an audio file using ffmpeg and converts it into an array of samples in // [-1.0, 1.0]. If there are multiple channels in the audio then each frame will // contain a separate sample for each channel. Frames are ordered by time. -Status ReadAudioFile(const string& filename, - const string& audio_format_id, - int32 samples_per_second, - int32 channel_count, +Status ReadAudioFile(const string& filename, const string& audio_format_id, + int32 samples_per_second, int32 channel_count, std::vector* output_samples); // Creates an audio file using ffmpeg in a specific format. The samples are in @@ -45,6 +53,11 @@ Status CreateAudioFile(const string& audio_format_id, int32 bits_per_second, int32 samples_per_second, int32 channel_count, const std::vector& samples, string* output_data); +// Reads a video file using ffmpeg and converts it into RGB24 data in uint8 +// [frames, height, width, 3]. The w, h, and frames are obtained from ffmpeg. 
+Status ReadVideoFile(const string& filename, std::vector* output_data, + uint32* width, uint32* height, uint32* frames); + } // namespace ffmpeg } // namespace tensorflow diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 18b0b8b812c908cff62a241aa59b3a53021123f4..08b5a6ea48c2d4959af68a2ee9d27d21c6245457 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -19,7 +19,9 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py +from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py +from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader @@ -89,3 +91,19 @@ def encode_audio(audio, file_format=None, samples_per_second=None): ops.NotDifferentiable('EncodeAudio') + + +def decode_video(contents): + """Create an op that decodes the contents of a video file. + + Args: + contents: The binary contents of the video file to decode. This is a + scalar. + + Returns: + A rank-4 `Tensor` that has `[frames, height, width, 3]` RGB as output. 
+ """ + return gen_decode_video_op_py.decode_video(contents) + + +ops.NotDifferentiable('DecodeVideo') diff --git a/tensorflow/contrib/ffmpeg/testdata/small.mp4 b/tensorflow/contrib/ffmpeg/testdata/small.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1fc478842f51e7519866f474a02ad605235bc6a6 Binary files /dev/null and b/tensorflow/contrib/ffmpeg/testdata/small.mp4 differ diff --git a/tensorflow/contrib/ffmpeg/testdata/small_100.bmp b/tensorflow/contrib/ffmpeg/testdata/small_100.bmp new file mode 100644 index 0000000000000000000000000000000000000000..61f53a2a21c933037f004d6ae4319dc6065fb886 Binary files /dev/null and b/tensorflow/contrib/ffmpeg/testdata/small_100.bmp differ diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index e8dad886a1409babdf4ea47b9cd05def1f1ce25e..5b659ddaa1386736eb8cc05a203ed1827ccd160e 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -276,6 +276,7 @@ py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:partitioned_variables", "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variable_scope", diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 3f592611830e40a30392239c85486a2fad15a2a2..4edc77f86ba786ca547b8d3842e2cf02833fbbac 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -65,6 +65,7 @@ See the @{$python/contrib.framework} guide. 
@@get_variable_full_name @@get_variables_to_restore @@get_variables +@@global_variable @@local_variable @@model_variable @@variable diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py index 8ab8711db4650921e0d366a91adfe2f68b5a42f9..a18ff2320d99726bb355ff6179fc97a070c2fec7 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util.py +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -24,12 +24,14 @@ import six # pylint: disable=unused-import from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.framework import ops from tensorflow.python.framework.graph_util_impl import _assert_nodes_are_present from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes from tensorflow.python.framework.graph_util_impl import _extract_graph_summary from tensorflow.python.framework.graph_util_impl import _node_name -__all__ = ["fuse_op"] + +__all__ = ["fuse_op", "get_placeholders"] def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, @@ -91,7 +93,7 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, (n, cur_node)) if cur_node not in input_nodes_set: next_to_visit += name_to_input_name[cur_node] - else: + elif n not in reachable_by_input: nodes_post_output.append(n) # Add all nodes upto the input nodes @@ -126,3 +128,27 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, out.library.CopyFrom(graph_def.library) out.versions.CopyFrom(graph_def.versions) return out + + +def get_placeholders(graph): + """Get placeholders of a graph. + + Args: + graph: A tf.Graph. + Returns: + A list contains all placeholders of given graph. + + Raises: + TypeError: If `graph` is not a tensorflow graph. 
+ """ + + if not isinstance(graph, ops.Graph): + raise TypeError("Input graph needs to be a Graph: %s" % graph) + + # For each placeholder() call, there is a corresponding + # operation of type 'Placeholder' registered to the graph. + # The return value (a Tensor) of placeholder() is the + # first output of this operation in fact. + operations = graph.get_operations() + result = [i.outputs[0] for i in operations if i.type == "Placeholder"] + return result diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py index 87b992e22e1ad3aa20389d0834eeb3a5972c676e..b8a6d109e19211d271c2b15bac66ddacd38fe395 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util_test.py +++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py @@ -21,6 +21,9 @@ from tensorflow.contrib.framework.python.framework import graph_util from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -56,6 +59,41 @@ class GraphUtilTest(test.TestCase): self.assertEqual(fused_graph_def.node[2].name, 'D') self.assertEqual(fused_graph_def.node[3].name, 'E') + def testGraphUtilArtificialDependencyInjection(self): + graph_def = graph_pb2.GraphDef() + node_a = GetNewNode('A', 'Placeholder', []) + node_a1 = GetNewNode('A1', 'Placeholder', []) + node_b = GetNewNode('B', 'Op1', ['A']) + node_c = GetNewNode('C', 'Op1', ['B']) + node_d = GetNewNode('D', 'Op1', ['C']) + node_e = GetNewNode('E', 'Op1', ['D']) + graph_def.node.extend([node_a, node_a1, node_b, node_c, node_d, node_e]) + fused_graph_def = graph_util.fuse_op(graph_def, ['A', 'A1'], ['D'], + [types_pb2.DT_FLOAT], True, 'FusedOp', + 'Op2') + 
self.assertEqual(len(fused_graph_def.node), 5) + self.assertEqual(fused_graph_def.node[0].name, 'A') + self.assertEqual(fused_graph_def.node[1].name, 'A1') + self.assertEqual(fused_graph_def.node[2].name, 'FusedOp') + self.assertEqual(fused_graph_def.node[2].input[0], 'A') + self.assertEqual(fused_graph_def.node[2].op, 'Op2') + self.assertEqual(fused_graph_def.node[2].attr['_output_quantized'].b, True) + self.assertEqual(fused_graph_def.node[2].attr['_output_types'].list.type, + [types_pb2.DT_FLOAT]) + self.assertEqual(fused_graph_def.node[3].name, 'D') + self.assertEqual(fused_graph_def.node[4].name, 'E') + + +class GetPlaceholdersTest(test.TestCase): + + def test_get_placeholders(self): + with ops.Graph().as_default() as g: + placeholders = [array_ops.placeholder(dtypes.float32) for _ in range(5)] + results = graph_util.get_placeholders(g) + self.assertEqual( + sorted(placeholders, key=lambda x: x._id), # pylint: disable=protected-access + sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py index a0667bd489213cf366e27114a91e8699ed9e7428..2375ee4f550616ff60d20b87b5773704d8fbbe1e 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py @@ -48,7 +48,7 @@ def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): tf.accumulate_n_v2([a, b, a]) # [[7, 4], [6, 14]] # Explicitly pass shape and type - tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) + tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) # [[7, 4], # [6, 14]] ``` @@ -93,7 +93,7 @@ def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): elif len(inputs) == 1 and name is not None: return array_ops.identity(inputs[0], name=name) elif context.in_eager_mode(): - # TemporaryVariable 
not currently supported in eager mode; fall back + # TemporaryVariable not currently supported in eager mode; fall back # onto AddN for now. # TODO(frreiss) remove this once the lifetime of eager variables gets # addressed @@ -101,7 +101,7 @@ def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): else: return gen_math_ops._accumulate_nv2(inputs, name=name, shape=shape) -# The following code should eventually be merged into +# The following code should eventually be merged into # tensorflow/python/ops/math_grad.py @ops.RegisterGradient("AccumulateNV2") def _AddNGrad(op, grad): diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py index c2229bb8ad3d5b38321d16f150ed94175ab9bdbe..8f44698da851b48abf831e957c80fa1643a58bda 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for new version of accumulate_n op that will eventually go into +"""Tests for new version of accumulate_n op that will eventually go into `ops.math_ops`. -These test cases spefically exercise the `eager` APIs. They need to be in a +These test cases specifically exercise the `eager` APIs. 
They need to be in a separate file from the remaining tests because eager mode is currently something you can turn on but can't turn off for the lifetime of the current process.""" from __future__ import absolute_import @@ -64,7 +64,7 @@ class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): np.random.seed(42) num_inputs = 3 input_vars = [ - resource_variable_ops.ResourceVariable(10.0 * np.random.random(), + resource_variable_ops.ResourceVariable(10.0 * np.random.random(), name="t%d" % i) for i in range(0, num_inputs) ] @@ -72,7 +72,7 @@ class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): def fn(first, second, third): return av2.accumulate_n_v2([first, second, third]) - grad_fn = backprop.gradients_function(fn) + grad_fn = backprop.gradients_function(fn) grad = grad_fn(input_vars[0], input_vars[1], input_vars[2]) self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 [elem.numpy() for elem in grad]) diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py index 3386e849d5cb8516ab3b1f6cb0429be3fc2fc960..b5e9f8df79262635bf579a6bf2260bc40c140c6f 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for new version of accumulate_n op that will eventually go into +"""Tests for new version of accumulate_n op that will eventually go into `ops.math_ops`.""" from __future__ import absolute_import from __future__ import division @@ -102,21 +102,21 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): a = variables.Variable(np.array([0.1,0.2])) b = variables.Variable(np.array([[0.3],[0.4]])) - tf_val = av2.accumulate_n_v2([a,b]) + tf_val = av2.accumulate_n_v2([a,b]) def testWrongType(self): with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) b = variables.Variable(0.1, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32) + tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32) def testWrongTypeOneInput(self): # Scenario that used to trigger a bug, even when testWrongType() worked with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a], tensor_dtype=np.int32) + tf_val = av2.accumulate_n_v2([a], tensor_dtype=np.int32) if __name__ == "__main__": diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index b7668379686b4f0ba2a3e415ddb44b287659baaa..3f1ece4510578b5ac39849c577fffbb2a3be45a7 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -60,6 +60,7 @@ __all__ = ['add_model_variable', 'get_variable_full_name', 'get_variables_to_restore', 'get_variables', + 'global_variable', 'local_variable', 'model_variable', 'variable', @@ -147,20 +148,48 @@ def get_or_create_global_step(graph=None): return training_util.get_or_create_global_step(graph) -def local_variable(initial_value, validate_shape=True, name=None): - """Create variable and add 
it to `GraphKeys.LOCAL_VARIABLES` collection. +def local_variable(initial_value, + validate_shape=True, + name=None, + use_resource=None): + """Create a variable with a value and add it to `GraphKeys.LOCAL_VARIABLES`. Args: initial_value: See variables.Variable.__init__. validate_shape: See variables.Variable.__init__. name: See variables.Variable.__init__. + use_resource: If `True` use a ResourceVariable instead of a Variable. Returns: New variable. """ return variable_scope.variable( initial_value, trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], - validate_shape=validate_shape, name=name) + validate_shape=validate_shape, + use_resource=use_resource, + name=name) + + +def global_variable(initial_value, + validate_shape=True, + name=None, + use_resource=None): + """Create a variable with a value and add it to `GraphKeys.GLOBAL_VARIABLES`. + + Args: + initial_value: See variables.Variable.__init__. + validate_shape: See variables.Variable.__init__. + name: See variables.Variable.__init__. + use_resource: If `True` use a ResourceVariable instead of a Variable. + Returns: + New variable. 
+ """ + return variable_scope.variable( + initial_value, trainable=False, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + validate_shape=validate_shape, + use_resource=use_resource, + name=name) @contrib_add_arg_scope @@ -412,7 +441,7 @@ def get_unique_variable(var_op_name): """ candidates = get_variables(scope=var_op_name) if not candidates: - raise ValueError('Couldnt find variable %s' % var_op_name) + raise ValueError('Couldn\'t find variable %s' % var_op_name) for candidate in candidates: if candidate.op.name == var_op_name: diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 6a74e4e8666e98ca3c97dc9ddd8a6c11613f708e..2f06df93acb0a4c0b36c68839ff531e3c22c5ee3 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import gfile @@ -102,6 +103,82 @@ class LocalVariableTest(test.TestCase): sess.run(variables_lib.local_variables_initializer()) self.assertAllEqual(a.eval(), [0] * 5) + def testResourceVariable(self): + a = variables_lib2.local_variable(0) + b = variables_lib2.local_variable(0, use_resource=True) + self.assertEqual(type(a), variables_lib.Variable) + self.assertEqual(type(b), resource_variable_ops.ResourceVariable) + + +class GlobalVariableTest(test.TestCase): + + def test_global_variable(self): + with self.test_session() as sess: + self.assertEquals([], variables_lib.global_variables()) + value0 = 42 + variables_lib2.global_variable(value0) + value1 = 43 + variables_lib2.global_variable(value1) + 
variables = variables_lib.global_variables() + self.assertEquals(2, len(variables)) + with self.assertRaisesOpError( + 'Attempting to use uninitialized value Variable'): + sess.run(variables) + variables_lib.variables_initializer(variables).run() + self.assertAllEqual(set([value0, value1]), set(sess.run(variables))) + + def testVariableNameAndShape(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + a = variables_lib2.global_variable([1, 1, 1, 1, 1], name='a') + self.assertEquals(a.op.name, 'A/a') + self.assertListEqual(a.get_shape().as_list(), [5]) + self.assertListEqual([a], variables_lib.global_variables()) + + def testGlobalVariableNotInLocalVariables(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + a = variables_lib2.global_variable(0) + self.assertFalse(a in variables_lib.local_variables()) + self.assertTrue(a in variables_lib.global_variables()) + + def testGlobalVariableInVariablesToRestore(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + a = variables_lib2.global_variable(0) + self.assertFalse(a in variables_lib.local_variables()) + self.assertTrue(a in variables_lib2.get_variables_to_restore()) + + def testGetVariablesReturnsThem(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + a = variables_lib2.global_variable(0) + with variable_scope.variable_scope('B'): + b = variables_lib2.global_variable(0) + self.assertEquals([a], variables_lib2.get_variables('A')) + self.assertEquals([b], variables_lib2.get_variables('B')) + + def testGetLocalVariablesDontReturnsThem(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + variables_lib2.global_variable(0) + with variable_scope.variable_scope('B'): + variables_lib2.global_variable(0) + self.assertEquals([], variables_lib2.get_local_variables('A')) + self.assertEquals([], variables_lib2.get_local_variables('B')) + + def testInitializedVariableValue(self): + with 
self.test_session() as sess: + a = variables_lib2.global_variable([0, 0, 0, 0, 0], name='a') + sess.run(variables_lib.global_variables_initializer()) + self.assertAllEqual(a.eval(), [0] * 5) + + def testResourceVariable(self): + a = variables_lib2.global_variable(0) + b = variables_lib2.global_variable(0, use_resource=True) + self.assertEqual(type(a), variables_lib.Variable) + self.assertEqual(type(b), resource_variable_ops.ResourceVariable) + class GlobalStepTest(test.TestCase): diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 88306094ab9947c9c78b03c0013f6afc88316803..5fec69ea4361a97c79ddc3188469e7ffb327f6cc 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -493,6 +493,8 @@ void LaunchFusedConv2DBiasActivationOp:: {{conv_input_rows, conv_input_cols}}, output_depth, {{filter_rows, filter_cols}}, + // TODO(yangzihao): Add support for arbitrary dilations for fused conv. 
+ {{1, 1}}, // dilation_rows, dilation_cols {{row_stride, col_stride}}, {{padding_rows, padding_cols}}, conv_input->dtype(), diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h index dc43af11580ce5fda74ee25da6c151a5b89c7aee..fa7a3c03aa35c756252b22a004be91fa24c10e41 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h @@ -30,11 +30,12 @@ class FusedConvParameters : public ConvParameters { public: FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, int64 out_depths, const SpatialArray& filter, - const SpatialArray& stride, const SpatialArray& padding, - DataType dtype, int device_id, bool has_side_input, + const SpatialArray& dilation, const SpatialArray& stride, + const SpatialArray& padding, DataType dtype, + int device_id, bool has_side_input, ActivationMode activation_mode) - : ConvParameters(batch, in_depths, in, out_depths, filter, stride, - padding, dtype, device_id), + : ConvParameters(batch, in_depths, in, out_depths, filter, dilation, + stride, padding, dtype, device_id), activation_mode_(activation_mode), has_side_input_(has_side_input) { hash_code_ = Hash64Combine(hash_code_, has_side_input); diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc index 887ebc5a6c35379476fa1a643c866d38e2b25699..6a56237f67c844a3daa546eb02d64c9e2658f639 100644 --- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc @@ -52,6 +52,7 @@ REGISTER_OP("FusedConv2DBiasActivation") .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'") .Attr("activation_mode: {'Relu'} = 'Relu'") + .Attr("dilations: list(int) = [1, 1, 1, 1]") 
.SetShapeFn([](shape_inference::InferenceContext* c) { using shape_inference::ShapeHandle; using shape_inference::DimensionHandle; @@ -151,6 +152,11 @@ REGISTER_OP("FusedConv2DBiasActivation") kernel_height, kernel_width, input_channels % 4 ]` activation_mode: The activation applied to the output. Currently must be "Relu". + dilations: 1-D tensor of length 4. The dilation factor for each dimension + of `input`. If set to k > 1, there will be k-1 skipped cells between + each filter element on that dimension. The dimension order is determined + by the value of `data_format`, see above for details. Dilations in the + batch and depth dimensions must be 1. )doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 1418c87023af0dbff890f46e10f0140d5b89e4b7..b355a79b1a5d967eb82a30d41c073bbb52e0364c 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -56,6 +56,7 @@ py_test( srcs = ["python/train_test.py"], srcs_version = "PY2AND3", deps = [ + ":features", ":namedtuples", ":train", "//tensorflow/contrib/framework:framework_py", @@ -82,6 +83,7 @@ py_library( deps = [ ":classifier_metrics", ":eval_utils", + ":sliced_wasserstein", ":summaries", "//tensorflow/python:util", ], @@ -116,6 +118,7 @@ py_library( deps = [ ":clip_weights", ":conditioning_utils", + ":random_tensor_pool", ":virtual_batchnorm", "//tensorflow/python:util", ], @@ -219,6 +222,37 @@ py_test( ], ) +py_library( + name = "random_tensor_pool", + srcs = [ + "python/features/python/random_tensor_pool.py", + "python/features/python/random_tensor_pool_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:util", + ], +) + +py_test( + name = "random_tensor_pool_test", + srcs = 
["python/features/python/random_tensor_pool_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":random_tensor_pool", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//third_party/py/numpy", + ], +) + py_library( name = "virtual_batchnorm", srcs = [ @@ -470,6 +504,41 @@ py_test( ], ) +py_library( + name = "sliced_wasserstein", + srcs = [ + "python/eval/python/sliced_wasserstein.py", + "python/eval/python/sliced_wasserstein_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:nn_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + +py_test( + name = "sliced_wasserstein_test", + srcs = ["python/eval/python/sliced_wasserstein_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":sliced_wasserstein", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md index 3ab84780705b35567169bd76fd3485ad355ba9d8..4bca0a1d62a2b404c6783c7cfe3b5c67cfc58221 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -8,7 +8,8 @@ explicitly model the distribution and without writing an explicit loss. For example, the generator could learn to draw samples from the distribution of natural images. For more details on this technique, see ['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by -Goodfellow et al. +Goodfellow et al. 
See [tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an +introduction. #### Usage ```python @@ -23,8 +24,8 @@ mix TFGAN, native TF, and other custom frameworks * Use already implemented [GAN losses and penalties](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/losses_impl.py) (ex Wasserstein loss, gradient penalty, mutual information penalty, etc) * [Monitor and visualize](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/summaries_impl.py) GAN progress during training, and [evaluate](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py) them * Use already-implemented [tricks](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/) to stabilize and improve training -* Develop based on examples of common GAN setups -* Use the TFGAN-backed tf.Learn Estimator to easily train a GAN model +* Develop based on examples of [common GAN setups](https://github.com/tensorflow/models/tree/master/research/gan/) +* Use the TFGAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model * Improvements in TFGAN infrastructure will automatically benefit your TFGAN project * Stay up-to-date with research as we add more algorithms @@ -51,7 +52,7 @@ network to evaluate your unconditional generative model. You can also use your own pretrained classifier for more specific performance numbers, or use other methods for evaluating conditional generative models. 
-* examples (coming soon): +* [examples](https://github.com/tensorflow/models/tree/master/research/gan/) and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TFGAN to make GAN training easier, or use the more complicated examples to jumpstart your own project. These include unconditional and conditional GANs, InfoGANs, adversarial losses on existing networks, and image-to-image translation. diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py index dff361fdc42708ea69999c2def4721f9d49fcf14..f1946c7f925660eae3aaa650c437e03da1f33d6c 100644 --- a/tensorflow/contrib/gan/__init__.py +++ b/tensorflow/contrib/gan/__init__.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN is a lightweight library for training and evaluating GANs. + +In addition to providing the infrastructure for easily training and evaluating +GANs, this library contains modules for a TFGAN-backed Estimator, +evaluation metrics, features (such as virtual batch normalization), and losses. +Please see README.md for details and usage. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py index 8c4a18228039cb4f2c06e0333f4b8408f1f631e9..c9f7bc61b25230e4159cf8cbc7c9cceead0aa706 100644 --- a/tensorflow/contrib/gan/python/estimator/__init__.py +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API.
Please see README.md for details and usage.""" +"""TFGAN estimator module. + +GANEstimator provides all the infrastructure support of a TensorFlow Estimator +with the feature support of TFGAN. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index e89993991a389d68254a95aded2d771f4c2627be..d3dca3d9e75fe1ef3be67143e18c0b51e84ad24c 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import enum from tensorflow.contrib.framework.python.ops import variables as variable_lib @@ -29,6 +30,7 @@ from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops from tensorflow.python.ops import variable_scope +from tensorflow.python.util import tf_inspect as inspect __all__ = [ @@ -76,7 +78,7 @@ class GANEstimator(estimator.Estimator): return logits # Create GAN estimator. - gan_estimator = estimator.GANEstimator( + gan_estimator = tfgan.estimator.GANEstimator( model_dir, generator_fn=generator_fn, discriminator_fn=discriminator_fn, @@ -105,6 +107,7 @@ class GANEstimator(estimator.Estimator): discriminator_loss_fn=None, generator_optimizer=None, discriminator_optimizer=None, + get_hooks_fn=None, add_summaries=None, use_loss_summaries=True, config=None): @@ -116,7 +119,10 @@ class GANEstimator(estimator.Estimator): to continue training a previously saved model. generator_fn: A python function that takes a Tensor, Tensor list, or Tensor dictionary as inputs and returns the outputs of the GAN - generator. See `TFGAN` for more details and examples. + generator. 
See `TFGAN` for more details and examples. Additionally, if + it has an argument called `mode`, the Estimator's `mode` will be passed + in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch + normalization. discriminator_fn: A python function that takes the output of `generator_fn` or real data in the GAN setup, and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details @@ -132,6 +138,10 @@ class GANEstimator(estimator.Estimator): work. discriminator_optimizer: Same as `generator_optimizer`, but for the discriminator updates. + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. These hooks are run on the generator and discriminator + train ops, and can be used to implement the GAN training scheme. + Defaults to `train.get_sequential_train_hooks()`. add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. 
@@ -146,7 +156,7 @@ class GANEstimator(estimator.Estimator): else discriminator_optimizer) gan_head = head_lib.gan_head( generator_loss_fn, discriminator_loss_fn, gopt, dopt, - use_loss_summaries) + use_loss_summaries, get_hooks_fn=get_hooks_fn) return _gan_model_fn( features, labels, mode, generator_fn, discriminator_fn, gan_head, add_summaries) @@ -155,11 +165,6 @@ class GANEstimator(estimator.Estimator): model_fn=_model_fn, model_dir=model_dir, config=config) -def _use_check_shapes(real_data): - """Determines whether TFGAN should check Tensor shapes.""" - return isinstance(real_data, ops.Tensor) - - def _gan_model_fn( features, labels, @@ -225,16 +230,19 @@ def _gan_model_fn( labels=None) -def _make_train_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries): - """Make a `GANModel` for training.""" +def _make_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries, mode): + """Make a `GANModel`, and optionally pass in `mode`.""" + # If `generator_fn` has an argument `mode`, pass mode to it. 
+ if 'mode' in inspect.getargspec(generator_fn).args: + generator_fn = functools.partial(generator_fn, mode=mode) gan_model = tfgan_train.gan_model( generator_fn, discriminator_fn, real_data, generator_inputs, generator_scope=generator_scope, - check_shapes=_use_check_shapes(real_data)) + check_shapes=False) if add_summaries: if not isinstance(add_summaries, (tuple, list)): add_summaries = [add_summaries] @@ -245,15 +253,28 @@ def _make_train_gan_model(generator_fn, discriminator_fn, real_data, return gan_model +def _make_train_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries): + """Make a `GANModel` for training.""" + return _make_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries, + model_fn_lib.ModeKeys.TRAIN) + + def _make_eval_gan_model(generator_fn, discriminator_fn, real_data, generator_inputs, generator_scope, add_summaries): """Make a `GANModel` for evaluation.""" - return _make_train_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries) + return _make_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries, + model_fn_lib.ModeKeys.EVAL) def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): """Make a `GANModel` from just the generator.""" + # If `generator_fn` has an argument `mode`, pass mode to it. 
+ if 'mode' in inspect.getargspec(generator_fn).args: + generator_fn = functools.partial(generator_fn, + mode=model_fn_lib.ModeKeys.PREDICT) with variable_scope.variable_scope(generator_scope) as gen_scope: generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs) # pylint:disable=protected-access generated_data = generator_fn(generator_inputs) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 1bfdce9ee94d4d05d5186cd999361662bc0e3f85..e752f0bcccda418b79d4fdabb27807394cbbb425 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -48,7 +48,8 @@ from tensorflow.python.training import training from tensorflow.python.training import training_util -def generator_fn(noise_dict): +def generator_fn(noise_dict, mode): + del mode noise = noise_dict['x'] return layers.fully_connected(noise, noise.shape[1].value) @@ -90,7 +91,6 @@ def mock_head(testcase, expected_generator_inputs, expected_real_data, generator_var_names, set([x.name for x in gan_model.generator_variables])) testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name) - testcase.assertEqual(generator_fn, gan_model.generator_fn) testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data) # TODO(joelshor): Add check on `discriminator_real_outputs`. # TODO(joelshor): Add check on `discriminator_gen_outputs`. 
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 204c646e194319c0e63599da0b2a4909ef270ef3..a21358c50bbdb4a1a929b0c5bc322cec4c9923b5 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -71,7 +71,7 @@ class GANHead(head._Head): # pylint: disable=protected-access def __init__(self, generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, - get_hooks_fn=tfgan_train.get_sequential_train_hooks(), + get_hooks_fn=None, name=None): """`Head` for GAN training. @@ -86,10 +86,12 @@ class GANHead(head._Head): # pylint: disable=protected-access use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. + of hooks. Defaults to `train.get_sequential_train_hooks()` name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ + if get_hooks_fn is None: + get_hooks_fn = tfgan_train.get_sequential_train_hooks() # TODO(joelshor): Validate inputs. if use_loss_summaries in [True, False]: diff --git a/tensorflow/contrib/gan/python/eval/__init__.py b/tensorflow/contrib/gan/python/eval/__init__.py index bb8046187807d0cc584f7174eb9aac578855c110..f86b8513053a45f9830411f7df2c32d1f36a97b2 100644 --- a/tensorflow/contrib/gan/python/eval/__init__.py +++ b/tensorflow/contrib/gan/python/eval/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN evaluation module. 
+ +This module supports techniques such as Inception Score, Frechet Inception +distance, and Sliced Wasserstein distance. +""" # pylint: disable=,wildcard-import,unused-import from __future__ import absolute_import @@ -22,10 +26,12 @@ from __future__ import print_function # Collapse eval into a single namespace. from tensorflow.contrib.gan.python.eval.python import classifier_metrics from tensorflow.contrib.gan.python.eval.python import eval_utils +from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein from tensorflow.contrib.gan.python.eval.python import summaries from tensorflow.contrib.gan.python.eval.python.classifier_metrics import * from tensorflow.contrib.gan.python.eval.python.eval_utils import * +from tensorflow.contrib.gan.python.eval.python.sliced_wasserstein import * from tensorflow.contrib.gan.python.eval.python.summaries import * # pylint: enable=wildcard-import,unused-import @@ -33,7 +39,10 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'classifier_metrics', + 'sliced_wasserstein_distance', 'summaries', 'eval_utils', -] + classifier_metrics.__all__ + summaries.__all__ + eval_utils.__all__ +] + ( + classifier_metrics.__all__ + sliced_wasserstein.__all__ + + summaries.__all__ + eval_utils.__all__) remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index bb65f05b5a17e9a872e41d1dcb05aeb3cd6f6f40..82293b575aefa198a618ae7286ca24ebabd6987d 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -57,8 +57,10 @@ __all__ = [ 'run_inception', 'inception_score', 'classifier_score', + 'classifier_score_from_logits', 'frechet_inception_distance', 'frechet_classifier_distance', + 'frechet_classifier_distance_from_activations', 'INCEPTION_DEFAULT_IMAGE_SIZE', 
] @@ -222,13 +224,13 @@ def run_inception(images, image_size: Required image width and height. See unit tests for the default values. input_tensor: Name of input Tensor. - output_tensor: Name of output Tensor. This function will compute activations - at the specified layer. Examples include INCEPTION_V3_OUTPUT and - INCEPTION_V3_FINAL_POOL which would result in this function computing + output_tensor: Name or list of output Tensors. This function will compute + activations at the specified layer. Examples include INCEPTION_V3_OUTPUT + and INCEPTION_V3_FINAL_POOL which would result in this function computing the final logits or the penultimate pooling layer. Returns: - Logits. + Tensor or Tensors corresponding to computed `output_tensor`. Raises: ValueError: If images are not the correct size. @@ -244,8 +246,14 @@ def run_inception(images, activations = run_image_classifier(images, graph_def, input_tensor, output_tensor) - if array_ops.rank(activations) != 2: - activations = layers.flatten(activations) + if isinstance(activations, list): + for i, activation in enumerate(activations): + if array_ops.rank(activation) != 2: + activations[i] = layers.flatten(activation) + else: + if array_ops.rank(activations) != 2: + activations = layers.flatten(activations) + return activations @@ -257,23 +265,26 @@ def run_image_classifier(tensor, graph_def, input_tensor, tensor: An Input tensor. graph_def: A GraphDef proto. input_tensor: Name of input tensor in graph def. - output_tensor: Name of output tensor in graph def. + output_tensor: A tensor name or list of tensor names in graph def. scope: Name scope for classifier. Returns: - Classifier output. Shape depends on the classifier used, but is often - [batch, classes]. + Classifier output if `output_tensor` is a string, or a list of outputs if + `output_tensor` is a list. Raises: - ValueError: If `image_size` is not `None`, and `tensor` are not the correct - size. 
+ ValueError: If `input_tensor` or `output_tensor` aren't in the graph_def. """ input_map = {input_tensor: tensor} - return_elements = [output_tensor] - classifier_output = importer.import_graph_def( - graph_def, input_map, return_elements, name=scope)[0] + is_singleton = isinstance(output_tensor, str) + if is_singleton: + output_tensor = [output_tensor] + classifier_outputs = importer.import_graph_def( + graph_def, input_map, output_tensor, name=scope) + if is_singleton: + classifier_outputs = classifier_outputs[0] - return classifier_output + return classifier_outputs def classifier_score(images, classifier_fn, num_batches=1): @@ -312,6 +323,30 @@ def classifier_score(images, classifier_fn, num_batches=1): swap_memory=True, name='RunClassifier') logits = array_ops.concat(array_ops.unstack(logits), 0) + + return classifier_score_from_logits(logits) + + +def classifier_score_from_logits(logits): + """Classifier score for evaluating a conditional generative model. + + This is based on the Inception Score, but for an arbitrary classifier. + + This technique is described in detail in https://arxiv.org/abs/1606.03498. In + summary, this function calculates + + exp( E[ KL(p(y|x) || p(y)) ] ) + + which captures how different the network's classification prediction is from + the prior distribution over classes. + + Args: + logits: A 2D Tensor of logits. + + Returns: + The classifier score. A floating-point scalar of the same type as the output + of `logits`. + """ logits.shape.assert_has_rank(2) # Use maximum precision for best results. @@ -436,31 +471,71 @@ def frechet_classifier_distance(real_images, swap_memory=True, name='RunClassifier') - activations_dtype = activations.dtype # Split the activations by the real and generated images. real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0) # Ensure the activations have the right shapes. 
real_a = array_ops.concat(array_ops.unstack(real_a), 0) gen_a = array_ops.concat(array_ops.unstack(gen_a), 0) - if activations_dtype != dtypes.float64: - real_a = math_ops.to_double(real_a) - gen_a = math_ops.to_double(gen_a) - real_a.shape.assert_has_rank(2) - gen_a.shape.assert_has_rank(2) + return frechet_classifier_distance_from_activations(real_a, gen_a) + + +def frechet_classifier_distance_from_activations( + real_activations, generated_activations): + """Classifier distance for evaluating a generative model. + + This is based on the Frechet Inception distance, but for an arbitrary + classifier. + + This technique is described in detail in https://arxiv.org/abs/1706.08500. + Given two Gaussian distributions with means m and m_w and covariance matrices + C and C_w, this function calculates + + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. + + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + + Args: + real_activations: Real images to use to compute Frechet Inception distance. + generated_activations: Generated images to use to compute Frechet Inception + distance. + + Returns: + The Frechet Inception distance. A floating-point scalar of the same type + as the output of the activations.
+ """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) # Compute mean and covariance matrices of activations. - m = math_ops.reduce_mean(real_a, 0) - m_v = math_ops.reduce_mean(gen_a, 0) - num_examples = math_ops.to_double(array_ops.shape(real_a)[0]) + m = math_ops.reduce_mean(real_activations, 0) + m_v = math_ops.reduce_mean(generated_activations, 0) + num_examples = math_ops.to_double(array_ops.shape(real_activations)[0]) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T + real_centered = real_activations - m sigma = math_ops.matmul( - real_a - m, real_a - m, transpose_a=True) / (num_examples - 1) + real_centered, real_centered, transpose_a=True) / (num_examples - 1) + gen_centered = generated_activations - m_v sigma_v = math_ops.matmul( - gen_a - m_v, gen_a - m_v, transpose_a=True) / (num_examples - 1) + gen_centered, gen_centered, transpose_a=True) / (num_examples - 1) # Find the Tr(sqrt(sigma sigma_v)) component of FID sqrt_trace_component = trace_sqrt_product(sigma, sigma_v) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 92e0a995748c1c4c2ddfff0daae59be5a6eaefb4..1e18c699ba93b5f524341c65d0a2db84556b65a2 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -190,6 +190,23 @@ class ClassifierMetricsTest(test.TestCase): # Check that none of the model variables are trainable. 
self.assertListEqual([], variables.trainable_variables()) + def test_run_inception_multiple_outputs(self): + """Test `run_inception` graph construction with multiple outputs.""" + batch_size = 3 + img = array_ops.ones([batch_size, 299, 299, 3]) + logits, pool = _run_with_mock( + classifier_metrics.run_inception, img, + output_tensor=[classifier_metrics.INCEPTION_OUTPUT, + classifier_metrics.INCEPTION_FINAL_POOL]) + + self.assertTrue(isinstance(logits, ops.Tensor)) + self.assertTrue(isinstance(pool, ops.Tensor)) + logits.shape.assert_is_compatible_with([batch_size, 1001]) + pool.shape.assert_is_compatible_with([batch_size, 2048]) + + # Check that none of the model variables are trainable. + self.assertListEqual([], variables.trainable_variables()) + def test_inception_score_graph(self): """Test `inception_score` graph construction.""" score = _run_with_mock(classifier_metrics.inception_score, diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py new file mode 100644 index 0000000000000000000000000000000000000000..523968bed91f1021ae629bf52c405cf5c2d7b917 --- /dev/null +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Model evaluation tools for TFGAN.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.eval.python.sliced_wasserstein_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = sliced_wasserstein_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..9bebcacbe46d85fc4226c4275b71b3ecbde57a97 --- /dev/null +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py @@ -0,0 +1,282 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of Sliced Wasserstein Distance. 
+ +Proposed in https://arxiv.org/abs/1710.10196 and the official Theano +implementation that we used as reference can be found here: +https://github.com/tkarras/progressive_growing_of_gans + +Note: this is not an exact distance but an approximation through random +projections. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import script_ops + +__all__ = ['sliced_wasserstein_distance'] +_GAUSSIAN_FILTER = np.float32([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [ + 6, 24, 36, 24, 6 +], [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]]).reshape([5, 5, 1, 1]) / 256.0 + + +def _laplacian_pyramid(batch, num_levels): + """Compute a Laplacian pyramid. + + Args: + batch: (tensor) The batch of images (batch, height, width, channels). + num_levels: (int) Desired number of hierarchical levels. + Returns: + List of tensors from the highest to lowest resolution. 
+ """ + gaussian_filter = constant_op.constant(_GAUSSIAN_FILTER) + + def spatial_conv(batch, gain): + s = array_ops.shape(batch) + padded = array_ops.pad(batch, [[0, 0], [2, 2], [2, 2], [0, 0]], 'REFLECT') + xt = array_ops.transpose(padded, [0, 3, 1, 2]) + xt = array_ops.reshape(xt, [s[0] * s[3], s[1] + 4, s[2] + 4, 1]) + conv_out = nn_ops.conv2d(xt, gaussian_filter * gain, [1] * 4, 'VALID') + conv_xt = array_ops.reshape(conv_out, [s[0], s[3], s[1], s[2]]) + conv_xt = array_ops.transpose(conv_xt, [0, 2, 3, 1]) + return conv_xt + + def pyr_down(batch): # matches cv2.pyrDown() + return spatial_conv(batch, 1)[:, ::2, ::2] + + def pyr_up(batch): # matches cv2.pyrUp() + s = array_ops.shape(batch) + zeros = array_ops.zeros([3 * s[0], s[1], s[2], s[3]]) + res = array_ops.concat([batch, zeros], 0) + res = array_ops.batch_to_space(res, crops=[[0, 0], [0, 0]], block_size=2) + res = spatial_conv(res, 4) + return res + + pyramid = [math_ops.to_float(batch)] + for _ in range(1, num_levels): + pyramid.append(pyr_down(pyramid[-1])) + pyramid[-2] -= pyr_up(pyramid[-1]) + return pyramid + + +def _batch_to_patches(batch, patches_per_image, patch_size): + """Extract patches from a batch. + + Args: + batch: (tensor) The batch of images (batch, height, width, channels). + patches_per_image: (int) Number of patches to extract per image. + patch_size: (int) Size of the patches (size, size, channels) to extract. + Returns: + Tensor (batch*patches_per_image, patch_size, patch_size, channels) of + patches. + """ + + def py_func_random_patches(batch): + """Numpy wrapper.""" + batch_size, height, width, channels = batch.shape + patch_count = patches_per_image * batch_size + hs = patch_size // 2 + # Randomly pick patches. + patch_id, y, x, chan = np.ogrid[0:patch_count, -hs:hs + 1, -hs:hs + 1, 0:3] + img_id = patch_id // patches_per_image + # pylint: disable=g-no-augmented-assignment + # Need explicit addition for broadcast to work properly. 
+ y = y + np.random.randint(hs, height - hs, size=(patch_count, 1, 1, 1)) + x = x + np.random.randint(hs, width - hs, size=(patch_count, 1, 1, 1)) + # pylint: enable=g-no-augmented-assignment + idx = ((img_id * height + y) * width + x) * channels + chan + patches = batch.flat[idx] + return patches + + patches = script_ops.py_func( + py_func_random_patches, [batch], batch.dtype, stateful=False) + return patches + + +def _normalize_patches(patches): + """Normalize patches by their mean and standard deviation. + + Args: + patches: (tensor) The batch of patches (batch, size, size, channels). + Returns: + Tensor (batch, size, size, channels) of the normalized patches. + """ + patches = array_ops.concat(patches, 0) + mean, variance = nn.moments(patches, [1, 2, 3], keep_dims=True) + patches = (patches - mean) / math_ops.sqrt(variance) + return array_ops.reshape(patches, [array_ops.shape(patches)[0], -1]) + + +def _sort_rows(matrix, num_rows): + """Sort matrix rows by the last column. + + Args: + matrix: a matrix of values (row,col). + num_rows: (int) number of sorted rows to return from the matrix. + Returns: + Tensor (num_rows, col) of the sorted matrix top K rows. + """ + tmatrix = array_ops.transpose(matrix, [1, 0]) + sorted_tmatrix = nn_ops.top_k(tmatrix, num_rows)[0] + return array_ops.transpose(sorted_tmatrix, [1, 0]) + + +def _sliced_wasserstein(a, b, random_sampling_count, random_projection_dim): + """Compute the approximate sliced Wasserstein distance. + + Args: + a: (matrix) Distribution "a" of samples (row, col). + b: (matrix) Distribution "b" of samples (row, col). + random_sampling_count: (int) Number of random projections to average. + random_projection_dim: (int) Dimension of the random projection space. + Returns: + Float containing the approximate distance between "a" and "b". + """ + s = array_ops.shape(a) + means = [] + for _ in range(random_sampling_count): + # Random projection matrix. 
+ proj = random_ops.random_normal( + [array_ops.shape(a)[1], random_projection_dim]) + proj *= math_ops.rsqrt( + math_ops.reduce_sum(math_ops.square(proj), 0, keep_dims=True)) + # Project both distributions and sort them. + proj_a = math_ops.matmul(a, proj) + proj_b = math_ops.matmul(b, proj) + proj_a = _sort_rows(proj_a, s[0]) + proj_b = _sort_rows(proj_b, s[0]) + # Pairwise Wasserstein distance. + wdist = math_ops.reduce_mean(math_ops.abs(proj_a - proj_b)) + means.append(wdist) + return math_ops.reduce_mean(means) + + +def _sliced_wasserstein_svd(a, b): + """Compute the approximate sliced Wasserstein distance using an SVD. + + This is not part of the paper, it's a variant with possibly more accurate + measure. + + Args: + a: (matrix) Distribution "a" of samples (row, col). + b: (matrix) Distribution "b" of samples (row, col). + Returns: + Float containing the approximate distance between "a" and "b". + """ + s = array_ops.shape(a) + # Random projection matrix. + sig, u = linalg_ops.svd(array_ops.concat([a, b], 0))[:2] + proj_a, proj_b = array_ops.split(u * sig, 2, axis=0) + proj_a = _sort_rows(proj_a[:, ::-1], s[0]) + proj_b = _sort_rows(proj_b[:, ::-1], s[0]) + # Pairwise Wasserstein distance. + wdist = math_ops.reduce_mean(math_ops.abs(proj_a - proj_b)) + return wdist + + +def sliced_wasserstein_distance(real_images, + fake_images, + resolution_min=16, + patches_per_image=64, + patch_size=7, + random_sampling_count=1, + random_projection_dim=7 * 7 * 3, + use_svd=False): + """Compute the Wasserstein distance between two distributions of images. + + Note that measure vary with the number of images. Use 8192 images to get + numbers comparable to the ones in the original paper. + + Args: + real_images: (tensor) Real images (batch, height, width, channels). + fake_images: (tensor) Fake images (batch, height, width, channels). + resolution_min: (int) Minimum resolution for the Laplacian pyramid. 
+ patches_per_image: (int) Number of patches to extract per image per + Laplacian level. + patch_size: (int) Width of a square patch. + random_sampling_count: (int) Number of random projections to average. + random_projection_dim: (int) Dimension of the random projection space. + use_svd: experimental method to compute a more accurate distance. + Returns: + List of tuples (distance_real, distance_fake) for each level of the + Laplacian pyramid from the highest resolution to the lowest. + distance_real is the Wasserstein distance between real images + distance_fake is the Wasserstein distance between real and fake images. + Raises: + ValueError: If the input shapes are incorrect. Input tensor dimensions + (batch, height, width, channels) are expected to be known at graph + construction time. In addition height and width must be the same and the + number of colors should be exactly 3. Real and fake images must have the + same size. + """ + height = real_images.shape[1] + real_images.shape.assert_is_compatible_with([None, None, height, 3]) + fake_images.shape.assert_is_compatible_with(real_images.shape) + + # Select resolutions. + resolution_full = int(height) + resolution_min = min(resolution_min, resolution_full) + resolution_max = resolution_full + # Base loss of detail. + resolutions = [ + 2**i + for i in range( + int(np.log2(resolution_max)), + int(np.log2(resolution_min)) - 1, -1) + ] + + # Gather patches for each level of the Laplacian pyramids. 
+ patches_real, patches_fake, patches_test = ( + [[] for _ in resolutions] for _ in range(3)) + for lod, level in enumerate( + _laplacian_pyramid(real_images, len(resolutions))): + patches_real[lod].append( + _batch_to_patches(level, patches_per_image, patch_size)) + patches_test[lod].append( + _batch_to_patches(level, patches_per_image, patch_size)) + + for lod, level in enumerate( + _laplacian_pyramid(fake_images, len(resolutions))): + patches_fake[lod].append( + _batch_to_patches(level, patches_per_image, patch_size)) + + for lod in range(len(resolutions)): + for patches in [patches_real, patches_test, patches_fake]: + patches[lod] = _normalize_patches(patches[lod]) + + # Evaluate scores. + scores = [] + for lod in range(len(resolutions)): + if not use_svd: + scores.append( + (_sliced_wasserstein(patches_real[lod], patches_test[lod], + random_sampling_count, random_projection_dim), + _sliced_wasserstein(patches_real[lod], patches_fake[lod], + random_sampling_count, random_projection_dim))) + else: + scores.append( + (_sliced_wasserstein_svd(patches_real[lod], patches_test[lod]), + _sliced_wasserstein_svd(patches_real[lod], patches_fake[lod]))) + return scores diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b960af28eaa969079b72c7aabcde2ad6cd1f5c68 --- /dev/null +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Sliced Wasserstein Distance.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import ndimage +from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein_impl as swd +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class ClassifierMetricsTest(test.TestCase): + + def test_laplacian_pyramid(self): + # The numpy/scipy code for reference estimation comes from: + # https://github.com/tkarras/progressive_growing_of_gans + gaussian_filter = np.float32([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [ + 6, 24, 36, 24, 6 + ], [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]]) / 256.0 + + def np_pyr_down(minibatch): # matches cv2.pyrDown() + assert minibatch.ndim == 4 + return ndimage.convolve( + minibatch, + gaussian_filter[np.newaxis, np.newaxis, :, :], + mode='mirror')[:, :, ::2, ::2] + + def np_pyr_up(minibatch): # matches cv2.pyrUp() + assert minibatch.ndim == 4 + s = minibatch.shape + res = np.zeros((s[0], s[1], s[2] * 2, s[3] * 2), minibatch.dtype) + res[:, :, ::2, ::2] = minibatch + return ndimage.convolve( + res, + gaussian_filter[np.newaxis, np.newaxis, :, :] * 4.0, + mode='mirror') + + def np_laplacian_pyramid(minibatch, num_levels): + # Note: there's a bug in the original SWD, fixed repeatability. 
+ pyramid = [minibatch.astype('f').copy()] + for _ in range(1, num_levels): + pyramid.append(np_pyr_down(pyramid[-1])) + pyramid[-2] -= np_pyr_up(pyramid[-1]) + return pyramid + + data = np.random.normal(size=[256, 3, 32, 32]).astype('f') + pyramid = np_laplacian_pyramid(data, 3) + data_tf = array_ops.placeholder(dtypes.float32, [256, 32, 32, 3]) + pyramid_tf = swd._laplacian_pyramid(data_tf, 3) + with self.test_session() as sess: + pyramid_tf = sess.run( + pyramid_tf, feed_dict={ + data_tf: data.transpose(0, 2, 3, 1) + }) + for x in range(3): + self.assertAllClose( + pyramid[x].transpose(0, 2, 3, 1), pyramid_tf[x], atol=1e-6) + + def test_sliced_wasserstein_distance(self): + """Test the distance.""" + d1 = random_ops.random_uniform([256, 32, 32, 3]) + d2 = random_ops.random_normal([256, 32, 32, 3]) + wfunc = swd.sliced_wasserstein_distance(d1, d2) + with self.test_session() as sess: + wscores = [sess.run(x) for x in wfunc] + self.assertAllClose( + np.array([0.014, 0.014], 'f'), + np.array([x[0] for x in wscores], 'f'), + rtol=0.1) + self.assertAllClose( + np.array([0.014, 0.020], 'f'), + np.array([x[1] for x in wscores], 'f'), + rtol=0.1) + + def test_sliced_wasserstein_distance_svd(self): + """Test the distance.""" + d1 = random_ops.random_uniform([256, 32, 32, 3]) + d2 = random_ops.random_normal([256, 32, 32, 3]) + wfunc = swd.sliced_wasserstein_distance(d1, d2, use_svd=True) + with self.test_session() as sess: + wscores = [sess.run(x) for x in wfunc] + self.assertAllClose( + np.array([0.013, 0.013], 'f'), + np.array([x[0] for x in wscores], 'f'), + rtol=0.15) + self.assertAllClose( + np.array([0.014, 0.019], 'f'), + np.array([x[1] for x in wscores], 'f'), + rtol=0.15) + + def test_swd_mismatched(self): + """Test the inputs mismatched shapes are detected.""" + d1 = random_ops.random_uniform([256, 32, 32, 3]) + d2 = random_ops.random_normal([256, 32, 31, 3]) + d3 = random_ops.random_normal([256, 31, 32, 3]) + d4 = random_ops.random_normal([255, 32, 32, 3]) + with 
self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d2) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d3) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d4) + + def test_swd_not_rgb(self): + """Test that only RGB is supported.""" + d1 = random_ops.random_uniform([256, 32, 32, 1]) + d2 = random_ops.random_normal([256, 32, 32, 1]) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/features/__init__.py b/tensorflow/contrib/gan/python/features/__init__.py index 6d0972f8db418d6fcf517cc6f7e96093ae08a9e4..4816daf760143af9f1502873b123ffad8e5ec8ce 100644 --- a/tensorflow/contrib/gan/python/features/__init__.py +++ b/tensorflow/contrib/gan/python/features/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN features module. + +This module includes support for virtual batch normalization, buffer replay, +conditioning, etc. 
+""" from __future__ import absolute_import from __future__ import division @@ -22,10 +26,12 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.gan.python.features.python import clip_weights from tensorflow.contrib.gan.python.features.python import conditioning_utils +from tensorflow.contrib.gan.python.features.python import random_tensor_pool from tensorflow.contrib.gan.python.features.python import virtual_batchnorm from tensorflow.contrib.gan.python.features.python.clip_weights import * from tensorflow.contrib.gan.python.features.python.conditioning_utils import * +from tensorflow.contrib.gan.python.features.python.random_tensor_pool import * from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import * # pylint: enable=unused-import,wildcard-import @@ -33,5 +39,6 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = clip_weights.__all__ _allowed_symbols += conditioning_utils.__all__ +_allowed_symbols += random_tensor_pool.__all__ _allowed_symbols += virtual_batchnorm.__all__ remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py index 030e37ec679ec58e3b534fd3644ffe1d23173404..2b7bb5f14e7f3d1b3f913d3426efaaae19079ffb 100644 --- a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py +++ b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for tfgan.python.features.clip_weights.""" +"""Tests for features.clip_weights.""" from __future__ import absolute_import from __future__ import division @@ -31,17 +31,18 @@ class ClipWeightsTest(test.TestCase): """Tests for `discriminator_weight_clip`.""" def setUp(self): + super(ClipWeightsTest, self).setUp() self.variables = [variables.Variable(2.0)] self.tuple = collections.namedtuple( 'VarTuple', ['discriminator_variables'])(self.variables) def _test_weight_clipping_helper(self, use_tuple): - loss = self.variables[0] * 2.0 + loss = self.variables[0] opt = training.GradientDescentOptimizer(1.0) if use_tuple: - opt_clip = clip_weights.weight_clip(opt, self.variables, 0.1) + opt_clip = clip_weights.clip_variables(opt, self.variables, 0.1) else: - opt_clip = clip_weights.discriminator_weight_clip(opt, self.tuple, 0.1) + opt_clip = clip_weights.clip_discriminator_weights(opt, self.tuple, 0.1) train_op1 = opt.minimize(loss, var_list=self.variables) train_op2 = opt_clip.minimize(loss, var_list=self.variables) @@ -72,10 +73,14 @@ class ClipWeightsTest(test.TestCase): clip_weights.clip_discriminator_weights(opt, self.tuple, weight_clip=-1) else: with self.assertRaisesRegexp(ValueError, 'must be positive'): - clip_weights.clip_weights(opt, self.variables, weight_clip=-1) + clip_weights.clip_variables(opt, self.variables, weight_clip=-1) def test_incorrect_weight_clip_value_argsonly(self): self._test_incorrect_weight_clip_value_helper(False) def test_incorrect_weight_clip_value_tuple(self): self._test_incorrect_weight_clip_value_helper(True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..ca904971fa8cb0440d3e0c9060f13cc214c9eaad --- /dev/null +++ 
b/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py @@ -0,0 +1,35 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A tensor pool stores values from an input tensor and returns a stored one. + +See the following papers for more details. +1) `Learning from simulated and unsupervised images through adversarial + training` (https://arxiv.org/abs/1612.07828). +2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial + Networks` (https://arxiv.org/abs/1703.10593). 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.features.python import random_tensor_pool_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = random_tensor_pool_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..9d10db0f5a3d09dc4dd7d8b1c97c16c29808547c --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py @@ -0,0 +1,135 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A tensor pool stores values from an input tensor and returns a stored one. + +We use this to keep a history of values created by a generator, such that +a discriminator can randomly be trained on some older samples, not just the +current one. 
This can help to not let the discriminator get too far ahead of the +generator and also to keep the system from oscillating, if the discriminator +forgets too fast what past samples from the generator looked like. + +See the following papers for more details. +1) `Learning from simulated and unsupervised images through adversarial + training` (https://arxiv.org/abs/1612.07828). +2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial + Networks` (https://arxiv.org/abs/1703.10593). +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import random_ops + +__all__ = [ + 'tensor_pool', +] + + +def _to_tuple(x): + if isinstance(x, (list, tuple)): + return tuple(x) + return (x,) + + +def tensor_pool(input_values, + pool_size=50, + pooling_probability=0.5, + name='tensor_pool'): + """Queue storing input values and returning random previously stored ones. + + Every time the returned `output_value` is evaluated, `input_value` is + evaluated and its value either directly returned (with + `1-pooling_probability`) or stored in the pool and a random one of the samples + currently in the pool is popped and returned. As long as the pool is not fully + filled, the input_value is always directly returned, as well as stored in the + pool. Note during inference / testing, it may be appropriate to set + `pool_size` = 0 or `pooling_probability` = 0. + + Args: + input_values: A `Tensor`, or a list or tuple of `Tensor`s from which to read + values to be pooled. + pool_size: An integer specifying the maximum size of the pool. Defaults to + 50. 
+ pooling_probability: A float `Tensor` specifying the probability of getting + a value from the pool, as opposed to just the current input. + name: A string prefix for the name scope for all tensorflow ops. + + Returns: + A `Tensor`, or a list or tuple of `Tensor`s (according to the type of + `input_values`) which is with given probability either the `input_values` or + a randomly chosen sample that was previously inserted in the pool. + + Raises: + ValueError: If `pool_size` is negative. + """ + pool_size = int(pool_size) + if pool_size < 0: + raise ValueError('`pool_size` is negative.') + elif pool_size == 0: + return input_values + + original_input_values = input_values + input_values = _to_tuple(input_values) + + with ops.name_scope( + '{}_pool_queue'.format(name), + values=input_values + (pooling_probability,)): + pool_queue = data_flow_ops.RandomShuffleQueue( + capacity=pool_size, + min_after_dequeue=0, + dtypes=[v.dtype for v in input_values], + shapes=None) + + # In pseudo code this code does the following: + # if not pool_full: + # enqueue(input_values) + # return input_values + # else + # dequeue_values = dequeue_random_sample() + # enqueue(input_values) + # if rand() < pooling_probability: + # return dequeue_values + # else + # return input_values + + def _get_input_value_pooled(): + enqueue_op = pool_queue.enqueue(input_values) + with ops.control_dependencies([enqueue_op]): + return tuple(array_ops.identity(v) for v in input_values) + + def _get_random_pool_value_and_enqueue_input(): + dequeue_values = _to_tuple(pool_queue.dequeue()) + with ops.control_dependencies(dequeue_values): + enqueue_op = pool_queue.enqueue(input_values) + with ops.control_dependencies([enqueue_op]): + prob = random_ops.random_uniform( + (), dtype=dtypes.float32) < pooling_probability + return control_flow_ops.cond(prob, lambda: dequeue_values, + lambda: input_values) + + output_values = _to_tuple(control_flow_ops.cond( + pool_queue.size() < pool_size, 
_get_input_value_pooled, + _get_random_pool_value_and_enqueue_input)) + + if isinstance(original_input_values, list): + return list(output_values) + elif isinstance(original_input_values, tuple): + return output_values + return output_values[0] diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cef3a87ab34f9754099073eefcb3f1b1c97a3762 --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py @@ -0,0 +1,110 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tf.contrib.gan.python.features.random_tensor_pool.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import tensor_pool +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class TensorPoolTest(test.TestCase): + + def test_pool_unknown_input_shape(self): + """Checks that `input_value` can have unknown shape.""" + input_value = array_ops.placeholder( + dtype=dtypes.int32, shape=[None, None, 3]) + output_value = tensor_pool(input_value, pool_size=10) + + with self.test_session(use_gpu=True) as session: + for i in range(10): + session.run(output_value, {input_value: [[[i] * 3]]}) + session.run(output_value, {input_value: [[[i] * 3] * 2]}) + session.run(output_value, {input_value: [[[i] * 3] * 5] * 2}) + + def test_pool_sequence(self): + """Checks that values are pooled and returned maximally twice.""" + input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) + output_value = tensor_pool(input_value, pool_size=10) + + with self.test_session(use_gpu=True) as session: + outs = [] + for i in range(50): + out = session.run(output_value, {input_value: i}) + outs.append(out) + self.assertLessEqual(out, i) + + _, counts = np.unique(outs, return_counts=True) + # Check that each value is returned maximally twice. 
+ self.assertTrue((counts <= 2).all()) + + def test_never_pool(self): + """Checks that setting `pooling_probability` to zero works.""" + input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) + output_value = tensor_pool( + input_value, pool_size=10, pooling_probability=0.0) + + with self.test_session(use_gpu=True) as session: + for i in range(50): + out = session.run(output_value, {input_value: i}) + self.assertEqual(out, i) + + def test_pooling_probability(self): + """Checks that `pooling_probability` works.""" + input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) + pool_size = 10 + pooling_probability = 0.2 + output_value = tensor_pool( + input_value, + pool_size=pool_size, + pooling_probability=pooling_probability) + + with self.test_session(use_gpu=True) as session: + not_pooled = 0 + total = 1000 + for i in range(total): + out = session.run(output_value, {input_value: i}) + if out == i: + not_pooled += 1 + self.assertAllClose( + (not_pooled - pool_size) / (total - pool_size), + 1 - pooling_probability, + atol=0.03) + + def test_input_values_tuple(self): + """Checks that `input_values` can be a tuple.""" + input_values = (array_ops.placeholder(dtype=dtypes.int32, shape=[]), + array_ops.placeholder(dtype=dtypes.int32, shape=[])) + output_values = tensor_pool(input_values, pool_size=3) + self.assertEqual(len(output_values), len(input_values)) + + with self.test_session(use_gpu=True) as session: + for i in range(10): + outs = session.run(output_values, { + input_values[0]: i, + input_values[1]: i + 1 + }) + self.assertEqual(len(outs), len(input_values)) + self.assertEqual(outs[1] - outs[0], 1) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/losses/__init__.py b/tensorflow/contrib/gan/python/losses/__init__.py index 290ff867a1e443f20a63e27fd97f53fed8a6cc11..d9bf8ebfdf65dfc76e4569dcaf26e0e51c7fc107 100644 --- a/tensorflow/contrib/gan/python/losses/__init__.py +++ 
b/tensorflow/contrib/gan/python/losses/__init__.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN losses and penalties. + +Losses can be used with individual arguments or with GANModel tuples. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 48f5e8e47dbcd5d32c23806b967a0d1e7403d2f7..3d4e315ebd0bd52b3b5e3e4a8655df8bfe9cebe8 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -79,6 +79,7 @@ class InfoGANModel( collections.namedtuple('InfoGANModel', GANModel._fields + ( 'structured_generator_inputs', 'predicted_distributions', + 'discriminator_and_aux_fn', ))): """An InfoGANModel contains all the pieces needed for InfoGAN training. @@ -91,6 +92,8 @@ class InfoGANModel( predicted_distributions: A list of tf.Distributions. Predicted by the recognizer, and used to evaluate the likelihood of the structured noise. List length should match `structured_generator_inputs`. + discriminator_and_aux_fn: The original discriminator function that returns + a tuple of (logits, `predicted_distributions`). 
""" diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index ad2d5eb86cdab89273efbd4ddce45f6657b54406..edd0113977ff4ddc672b0ec134be1a48c621b579 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -215,7 +215,8 @@ def infogan_model( disc_scope, lambda x, y: discriminator_fn(x, y)[0], # conform to non-InfoGAN API structured_generator_inputs, - predicted_distributions) + predicted_distributions, + discriminator_fn) def acgan_model( @@ -326,6 +327,53 @@ def _use_aux_loss(aux_loss_weight): return False +def _tensor_pool_adjusted_model(model, tensor_pool_fn): + """Adjusts model using `tensor_pool_fn`. + + Args: + model: A GANModel tuple. + tensor_pool_fn: A function that takes (generated_data, generator_inputs), + stores them in an internal pool and returns a previously stored + (generated_data, generator_inputs) with some probability. For example + tfgan.features.tensor_pool. + + Returns: + A new GANModel tuple where discriminator outputs are adjusted by taking + pooled generator outputs as inputs. Returns the original model if + `tensor_pool_fn` is None. + + Raises: + ValueError: If tensor pool does not suport the `model`. 
+ """ + if tensor_pool_fn is None: + return model + + pooled_generated_data, pooled_generator_inputs = tensor_pool_fn( + (model.generated_data, model.generator_inputs)) + + if isinstance(model, namedtuples.GANModel): + dis_gen_outputs = model.discriminator_fn(pooled_generated_data, + pooled_generator_inputs) + return model._replace(discriminator_gen_outputs=dis_gen_outputs) + elif isinstance(model, namedtuples.ACGANModel): + (dis_pooled_gen_outputs, + dis_pooled_gen_classification_logits) = model.discriminator_fn( + pooled_generated_data, pooled_generator_inputs) + return model._replace( + discriminator_gen_outputs=dis_pooled_gen_outputs, + discriminator_gen_classification_logits= + dis_pooled_gen_classification_logits) + elif isinstance(model, namedtuples.InfoGANModel): + (dis_pooled_gen_outputs, + pooled_predicted_distributions) = model.discriminator_and_aux_fn( + pooled_generated_data, pooled_generator_inputs) + return model._replace( + discriminator_gen_outputs=dis_pooled_gen_outputs, + predicted_distributions=pooled_predicted_distributions) + else: + raise ValueError('Tensor pool does not support `model`: %s.' % type(model)) + + def gan_loss( # GANModel. model, @@ -338,6 +386,7 @@ def gan_loss( mutual_information_penalty_weight=None, aux_cond_generator_weight=None, aux_cond_discriminator_weight=None, + tensor_pool_fn=None, # Options. add_summaries=True): """Returns losses necessary to train generator and discriminator. @@ -363,6 +412,10 @@ def gan_loss( https://arxiv.org/abs/1610.09585 aux_cond_discriminator_weight: If not None: add a classification loss as in https://arxiv.org/abs/1610.09585 + tensor_pool_fn: A function that takes (generated_data, generator_inputs), + stores them in an internal pool and returns previously stored + (generated_data, generator_inputs). For example + `tfgan.features.tensor_pool`. Defaults to None (not using tensor pool). add_summaries: Whether or not to add summaries for the losses. 
Returns: @@ -402,7 +455,9 @@ def gan_loss( # Create standard losses. gen_loss = generator_loss_fn(model, add_summaries=add_summaries) - dis_loss = discriminator_loss_fn(model, add_summaries=add_summaries) + dis_loss = discriminator_loss_fn( + _tensor_pool_adjusted_model(model, tensor_pool_fn), + add_summaries=add_summaries) # Add optional extra losses. if _use_aux_loss(gradient_penalty_weight): @@ -422,7 +477,7 @@ def gan_loss( ac_disc_loss = tfgan_losses.acgan_discriminator_loss( model, add_summaries=add_summaries) dis_loss += aux_cond_discriminator_weight * ac_disc_loss - # Gathers auxilliary losses. + # Gathers auxiliary losses. if model.generator_scope: gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name) else: diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 6b27b6926102b6e5a7ff134ceed75c23459a6534..519d101e07f4f28d684017b86102fce8fa7677ef 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.contrib.framework.python.ops import variables as variables_lib from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python import train +from tensorflow.contrib.gan.python.features.python import random_tensor_pool from tensorflow.contrib.slim.python.slim import learning as slim_learning from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -145,14 +146,16 @@ def get_infogan_model(): return namedtuples.InfoGANModel( *get_gan_model(), structured_generator_inputs=[constant_op.constant(0)], - predicted_distributions=[categorical.Categorical([1.0])]) + predicted_distributions=[categorical.Categorical([1.0])], + discriminator_and_aux_fn=infogan_discriminator_model) def get_callable_infogan_model(): return namedtuples.InfoGANModel( *get_callable_gan_model(), structured_generator_inputs=[constant_op.constant(0)], - 
predicted_distributions=[categorical.Categorical([1.0])]) + predicted_distributions=[categorical.Categorical([1.0])], + discriminator_and_aux_fn=infogan_discriminator_model) def create_infogan_model(): @@ -409,6 +412,51 @@ class GANLossTest(test.TestCase): def test_callable_acgan(self): self._test_acgan_helper(create_callable_acgan_model) + # Test tensor pool. + def _test_tensor_pool_helper(self, create_gan_model_fn): + model = create_gan_model_fn() + if isinstance(model, namedtuples.InfoGANModel): + + def tensor_pool_fn_impl(input_values): + generated_data, generator_inputs = input_values + output_values = random_tensor_pool.tensor_pool( + [generated_data] + generator_inputs, pool_size=5) + return output_values[0], output_values[1:] + + tensor_pool_fn = tensor_pool_fn_impl + else: + + def tensor_pool_fn_impl(input_values): + return random_tensor_pool.tensor_pool(input_values, pool_size=5) + + tensor_pool_fn = tensor_pool_fn_impl + loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) + self.assertTrue(isinstance(loss, namedtuples.GANLoss)) + + # Check values. 
+ with self.test_session(use_gpu=True) as sess: + variables.global_variables_initializer().run() + for _ in range(10): + sess.run([loss.generator_loss, loss.discriminator_loss]) + + def test_tensor_pool_gan(self): + self._test_tensor_pool_helper(create_gan_model) + + def test_tensor_pool_callable_gan(self): + self._test_tensor_pool_helper(create_callable_gan_model) + + def test_tensor_pool_infogan(self): + self._test_tensor_pool_helper(create_infogan_model) + + def test_tensor_pool_callable_infogan(self): + self._test_tensor_pool_helper(create_callable_infogan_model) + + def test_tensor_pool_acgan(self): + self._test_tensor_pool_helper(create_acgan_model) + + def test_tensor_pool_callable_acgan(self): + self._test_tensor_pool_helper(create_callable_acgan_model) + def test_doesnt_crash_when_in_nested_scope(self): with variable_scope.variable_scope('outer_scope'): gan_model = train.gan_model( diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index 157e97d237021d95c935a6be66aa57842b97125c..54502cfc6eecb9d064ffde9773e97d893a24133a 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -9,6 +9,7 @@ package(default_visibility = ["//visibility:public"]) load( "//tensorflow:tensorflow.bzl", + "tf_cc_test", "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", @@ -106,10 +107,33 @@ tf_custom_op_library( name = "python/ops/_distort_image_ops.so", srcs = [ "kernels/adjust_hsv_in_yiq_op.cc", + "kernels/adjust_hsv_in_yiq_op.h", "ops/distort_image_ops.cc", ], + gpu_srcs = [ + "kernels/adjust_hsv_in_yiq_op_gpu.cu.cc", + "kernels/adjust_hsv_in_yiq_op.h", + ], deps = [ - "@protobuf_archive//:protobuf", + "//tensorflow/core/kernels:gpu_util_hdrs", + ], +) + +tf_cc_test( + name = "adjust_hsv_in_yiq_op_test", + size = "small", + srcs = [ + "kernels/adjust_hsv_in_yiq_op.h", + "kernels/adjust_hsv_in_yiq_op_test.cc", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:test", + 
"//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:ops_util", + "//third_party/eigen3", ], ) @@ -122,19 +146,6 @@ tf_gen_op_wrapper_py( deps = [":distort_image_ops_op_lib"], ) -cc_library( - name = "distort_image_ops_cc", - srcs = [ - "kernels/adjust_hsv_in_yiq_op.cc", - ], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//third_party/eigen3", - ], - alwayslink = 1, -) - py_library( name = "distort_image_py", srcs = [ diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc index f4962ed69dc68d4bad06ef29d7a167e0ba8ae044..478b716d88321101c971789f36c0ff8ecd3f418e 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc @@ -12,14 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/work_sharder.h" @@ -36,10 +37,10 @@ class AdjustHsvInYiqOpBase : public OpKernel { struct ComputeOptions { const Tensor* input = nullptr; + Tensor* output = nullptr; const Tensor* delta_h = nullptr; const Tensor* scale_s = nullptr; const Tensor* scale_v = nullptr; - Tensor* output = nullptr; int64 channel_count = 0; }; @@ -65,7 +66,7 @@ class AdjustHsvInYiqOpBase : public OpKernel { scale_v.shape().DebugString())); auto channels = input.dim_size(input.dims() - 1); OP_REQUIRES( - context, channels == 3, + context, channels == kChannelSize, errors::InvalidArgument("input must have 3 channels but instead has ", channels, " channels.")); @@ -101,53 +102,21 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { const Tensor* input = options.input; Tensor* output = options.output; const int64 channel_count = options.channel_count; - static const int kChannelSize = 3; auto input_data = input->shaped({channel_count, kChannelSize}); const float delta_h = options.delta_h->scalar()(); const float scale_s = options.scale_s->scalar()(); const float scale_v = options.scale_v->scalar()(); auto output_data = output->shaped({channel_count, kChannelSize}); + float tranformation_matrix[kChannelSize * kChannelSize] = {0}; + internal::compute_tranformation_matrix( + delta_h, scale_s, scale_v, tranformation_matrix); const int kCostPerChannel = 
10; const DeviceBase::CpuWorkerThreads& worker_threads = *context->device()->tensorflow_cpu_worker_threads(); Shard(worker_threads.num_threads, worker_threads.workers, channel_count, kCostPerChannel, - [channel_count, &input_data, &output_data, delta_h, scale_s, scale_v]( + [channel_count, &input_data, &output_data, &tranformation_matrix]( int64 start_channel, int64 end_channel) { - // Using approximate linear transfomation described in: - // https://beesbuzz.biz/code/hsv_color_transforms.php - /** Get the constants from sympy - from sympy import Matrix - from sympy.abc import u, w - # Projection matrix to YIQ. http://en.wikipedia.org/wiki/YIQ - tyiq = Matrix([[0.299, 0.587, 0.114], - [0.596, -0.274, -0.322], - [0.211, -0.523, 0.312]]) - # Hue rotation matrix in YIQ space. - hue_proj = Matrix(3,3, [v, 0, 0, 0, vsu, -vsw, 0, vsw, vsu]) - m = tyiq.inv() * hue_proj * tyiq - **/ - // TODO(huangyp): directly compute the projection matrix from tyiq. - static const float t[kChannelSize][kChannelSize][kChannelSize] = { - {{.299, .701, .16862179492229}, - {.587, -.587, .329804745287403}, - {.114, -.114, -0.498426540209694}}, - {{.299, -.299, -.327963394172371}, - {.587, .413, .0346106879248821}, - {.114, -.114, .293352706247489}}, - {{.299, -.299, 1.24646136576682}, - {.587, -.587, -1.04322888291964}, - {.114, .886, -.203232482847173}}}; - float m[kChannelSize][kChannelSize] = {{0.}}; - float su = scale_s * std::cos(delta_h); - float sw = scale_s * std::sin(delta_h); - for (int q_index = 0; q_index < kChannelSize; q_index++) { - for (int p_index = 0; p_index < kChannelSize; p_index++) { - m[q_index][p_index] = scale_v * (t[q_index][p_index][0] + - t[q_index][p_index][1] * su + - t[q_index][p_index][2] * sw); - } - } // Applying projection matrix to input RGB vectors. 
const float* p = input_data.data() + start_channel * kChannelSize; float* q = output_data.data() + start_channel * kChannelSize; @@ -155,7 +124,9 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { for (int q_index = 0; q_index < kChannelSize; q_index++) { q[q_index] = 0; for (int p_index = 0; p_index < kChannelSize; p_index++) { - q[q_index] += m[q_index][p_index] * p[p_index]; + q[q_index] += + p[p_index] * + tranformation_matrix[q_index + kChannelSize * p_index]; } } p += kChannelSize; @@ -165,8 +136,33 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { } }; -REGISTER_KERNEL_BUILDER(Name("AdjustHsvInYiq").Device(DEVICE_CPU), - AdjustHsvInYiqOp); +REGISTER_KERNEL_BUILDER( + Name("AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint("T"), + AdjustHsvInYiqOp); + +#if GOOGLE_CUDA +template <> +class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { + public: + explicit AdjustHsvInYiqOp(OpKernelConstruction* context) + : AdjustHsvInYiqOpBase(context) {} + + void DoCompute(OpKernelContext* ctx, const ComputeOptions& options) override { + const int64 number_of_elements = options.input->NumElements(); + if (number_of_elements <= 0) { + return; + } + const float* delta_h = options.delta_h->flat().data(); + const float* scale_s = options.scale_s->flat().data(); + const float* scale_v = options.scale_v->flat().data(); + functor::AdjustHsvInYiqGPU()(ctx, options.channel_count, options.input, + delta_h, scale_s, scale_v, options.output); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint("T"), + AdjustHsvInYiqOp); +#endif -// TODO(huangyp): add the GPU kernel } // namespace tensorflow diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h new file mode 100644 index 0000000000000000000000000000000000000000..194ae2ba47456cac66c01989a78ab4ce607d1295 --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h @@ -0,0 +1,87 @@ +/* Copyright 
2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +static constexpr int kChannelSize = 3; + +namespace internal { + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void compute_tranformation_matrix( + const float delta_h, const float scale_s, const float scale_v, + float* matrix) { + static_assert(MATRIX_SIZE == kChannelSize * kChannelSize, + "Size of matrix should be 9."); + // Projection matrix from RGB to YIQ. Numbers from wikipedia + // https://en.wikipedia.org/wiki/YIQ + Eigen::Matrix3f yiq; + /* clang-format off */ + yiq << 0.299, 0.587, 0.114, + 0.596, -0.274, -0.322, + 0.211, -0.523, 0.312; + Eigen::Matrix3f yiq_inverse; + yiq_inverse << 1, 0.95617069, 0.62143257, + 1, -0.2726886, -0.64681324, + 1, -1.103744, 1.70062309; + /* clang-format on */ + // Construct hsv linear transformation matrix in YIQ space. 
+ // https://beesbuzz.biz/code/hsv_color_transforms.php + float vsu = scale_v * scale_s * std::cos(delta_h); + float vsw = scale_v * scale_s * std::sin(delta_h); + Eigen::Matrix3f hsv_transform; + /* clang-format off */ + hsv_transform << scale_v, 0, 0, + 0, vsu, -vsw, + 0, vsw, vsu; + /* clang-format on */ + // Compute final transformation matrix = inverse_yiq * hsv_transform * yiq + Eigen::Map> eigen_matrix(matrix); + eigen_matrix = yiq_inverse * hsv_transform * yiq; +} +} // namespace internal + +#if GOOGLE_CUDA +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +struct AdjustHsvInYiqGPU { + void operator()(OpKernelContext* ctx, int channel_count, + const Tensor* const input, const float* const delta_h, + const float* const scale_s, const float* const scale_v, + Tensor* const output); +}; + +} // namespace functor + +#endif // GOOGLE_CUDA + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..b71ff9cd507faac66b3a33d3c02ec9b5901d814a --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" +#include "tensorflow/core/kernels/gpu_utils.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { + +namespace internal { + +__global__ void compute_tranformation_matrix_cuda(const float* const delta_h, + const float* const scale_s, + const float* const scale_v, + float* const matrix, + const int matrix_size) { + if (matrix_size == kChannelSize * kChannelSize) { + compute_tranformation_matrix( + *delta_h, *scale_s, *scale_v, matrix); + } +} +} // namespace internal + +namespace functor { + +void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count, + const Tensor* const input, + const float* const delta_h, + const float* const scale_s, + const float* const scale_v, + Tensor* const output) { + const uint64 m = channel_count; + const uint64 k = kChannelSize; + const uint64 n = kChannelSize; + auto* cu_stream = ctx->eigen_device().stream(); + OP_REQUIRES(ctx, cu_stream, errors::Internal("No GPU stream available.")); + Tensor tranformation_matrix; + OP_REQUIRES_OK(ctx, ctx->allocate_temp( + DT_FLOAT, TensorShape({kChannelSize * kChannelSize}), + &tranformation_matrix)); + // TODO(huangyp): It takes about 3.5 us to compute tranformation_matrix + // with one thread. Improve its performance if necessary. + internal::compute_tranformation_matrix_cuda<<<1, 1, 0, cu_stream>>>( + delta_h, scale_s, scale_v, tranformation_matrix.flat().data(), + tranformation_matrix.flat().size()); + // Call cuBlas C = A * B directly. 
+ auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + auto a_ptr = + AsDeviceMemory(input->flat().data(), input->flat().size()); + auto b_ptr = AsDeviceMemory(tranformation_matrix.flat().data(), + tranformation_matrix.flat().size()); + auto c_ptr = AsDeviceMemory(output->flat().data(), + output->flat().size()); + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); + // TODO(huangyp): share/use autotune cublas algorithms in Matmul.op. + bool blas_launch_status = + stream + ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n, + a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); + } +} +} // namespace functor +} // namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4cbbd277840133c9419f9ce3d945b7d099679dc0 --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc @@ -0,0 +1,48 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class AdjustHsvInYiqOpTest : public OpsTestBase { + protected: +}; + +TEST_F(AdjustHsvInYiqOpTest, IdentiyTransformMatrix) { + Tensor matrix(allocator(), DT_FLOAT, TensorShape({9})); + internal::compute_tranformation_matrix<9>(0.0, 1.0, 1.0, + matrix.flat().data()); + Tensor expected(allocator(), DT_FLOAT, TensorShape({9})); + test::FillValues(&expected, {1, 0, 0, 0, 1, 0, 0, 0, 1}); + test::ExpectClose(matrix, expected); +} + +TEST_F(AdjustHsvInYiqOpTest, ScaleValueTransformMatrix) { + float scale_v = 2.3; + Tensor matrix(allocator(), DT_FLOAT, TensorShape({9})); + internal::compute_tranformation_matrix<9>(0.0, 1.0, scale_v, + matrix.flat().data()); + Tensor expected(allocator(), DT_FLOAT, TensorShape({9})); + test::FillValues(&expected, + {scale_v, 0, 0, 0, scale_v, 0, 0, 0, scale_v}); + test::ExpectClose(matrix, expected); +} + +} // end namespace tensorflow diff --git a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc index 2b6799213827537f77deda4e052bb7ec16f46343..f8b56ab1c5400694b3aa8d4a0c19c7769aa8cbce 100755 --- a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc +++ b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc @@ -40,7 +40,7 @@ REGISTER_OP("SingleImageRandomDotStereograms") .Doc(R"doc( Outputs a single image random dot stereogram for export via encode_PNG/JPG OP. 
-Given the 2-D tensor 'depth_values' with encoded Z values, this operation will +Given the 2-D tensor 'depth_values' with encoded Z values, this operation will encode 3-D data into a 2-D image. The output of this Op is suitable for the encode_PNG/JPG ops. Be careful with image compression as this may corrupt the encode 3-D data witin the image. @@ -68,14 +68,14 @@ with open('picture_out.png', 'wb') as f: f.write(png) ``` -depth_values: Z values of data to encode into 'output_data_window' window, +depth_values: Z values of data to encode into 'output_data_window' window, lower values are further away {0.0 floor(far), 1.0 ceiling(near) after normalization}, must be 2-D tensor hidden_surface_removal: Activate hidden surface removal convergence_dots_size: Black dot size in pixels to help view converge image, drawn on bottom of image dots_per_inch: Output device in dots/inch eye_separation: Separation between eyes in inches mu: Depth of field, Fraction of viewing distance (eg. 1/3 = .3333) -normalize: Normalize input data to [0.0, 1.0] +normalize: Normalize input data to [0.0, 1.0] normalize_max: Fix MAX value for Normalization - if < MIN, autoscale normalize_min: Fix MIN value for Normalization - if > MAX, autoscale border_level: Value of border depth 0.0 {far} to 1.0 {near} diff --git a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py index b85f19d29b79defa10493bdbaa4a1b237cb2a9ee..a495b58b7f6481d4cdedf73f23615d0390eb6a45 100644 --- a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py @@ -172,7 +172,7 @@ class AdjustValueInYiqTest(test_util.TensorFlowTestCase): raise AssertionError('Invalid test style: %s' % (test_style)) y_np = self._adjust_value_in_yiq_np(x_np, scale) y_tf = self._adjust_value_in_yiq_tf(x_np, scale) - self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-5) + 
self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4) def test_invalid_shapes(self): x_np = np.random.rand(2, 3) * 255. @@ -237,7 +237,7 @@ class AdjustSaturationInYiqTest(test_util.TensorFlowTestCase): raise AssertionError('Invalid test style: %s' % (test_style)) y_baseline = self._adjust_saturation_in_yiq_np(x_np, scale) y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale) - self.assertAllClose(y_tf, y_baseline, rtol=2e-5, atol=1e-5) + self.assertAllClose(y_tf, y_baseline, rtol=2e-4, atol=1e-4) def test_invalid_shapes(self): x_np = np.random.rand(2, 3) * 255. @@ -291,6 +291,9 @@ class AdjustHueInYiqBenchmark(test.Benchmark): def benchmark_adjust_hue_in_yiqCpuAll(self): self._benchmark_adjust_hue_in_yiq('/cpu:0', None) + def benchmark_adjust_hue_in_yiq_gpu_all(self): + self._benchmark_adjust_hue_in_yiq(test.gpu_device_name(), None) + class AdjustSaturationInYiqBenchmark(test.Benchmark): @@ -333,6 +336,9 @@ class AdjustSaturationInYiqBenchmark(test.Benchmark): def benchmark_adjust_saturation_in_yiq_cpu_all(self): self._benchmark_adjust_saturation_in_yiq('/cpu:0', None) + def benchmark_adjust_saturation_in_yiq_gpu_all(self): + self._benchmark_adjust_saturation_in_yiq(test.gpu_device_name(), None) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index 011ddeaa9a1eebaa507c9e0d33f9546ff3497166..faedee6f87772016561671bacd87f88657eafffb 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -224,7 +224,8 @@ def transform(images, transforms, interpolation="NEAREST", name=None): `(x, y)` to a transformed *input* point `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to - the transform mapping input points to output points. + the transform mapping input points to output points. 
Note that gradients + are not backpropagated into transformation parameters. interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". Returns: diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py index 5cccf26028ca6bf269dbc67a33075351edecb407..bb766e59d2cee648042cc08be466796d9233ad66 100755 --- a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py +++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py @@ -68,7 +68,7 @@ def single_image_random_dot_stereograms( ``` Args: - depth_values: A `Tensor`. Must be one of the following types: + depth_values: A `Tensor`. Must be one of the following types: `float64`, `float32`, `int64`, `int32`. Z values of data to encode into 'output_data_window' window, lower further away {0.0 floor(far), 1.0 ceiling(near) after norm}, must be 2-D tensor @@ -84,17 +84,17 @@ def single_image_random_dot_stereograms( mu: An optional `float`. Defaults to `0.3333`. Depth of field, Fraction of viewing distance (eg. 1/3 = 0.3333) normalize: An optional `bool`. Defaults to `True`. - Normalize input data to [0.0, 1.0] + Normalize input data to [0.0, 1.0] normalize_max: An optional `float`. Defaults to `-100`. Fix MAX value for Normalization (0.0) - if < MIN, autoscale normalize_min: An optional `float`. Defaults to `100`. Fix MIN value for Normalization (0.0) - if > MAX, autoscale border_level: An optional `float`. Defaults to `0`. - Value of bord in depth 0.0 {far} to 1.0 {near} + Value of bord in depth 0.0 {far} to 1.0 {near} number_colors: An optional `int`. Defaults to `256`. 2 (Black & White), 256 (grayscale), and Numbers > 256 (Full Color) are supported - output_image_shape: An optional `tf.TensorShape` or list of `ints`. + output_image_shape: An optional `tf.TensorShape` or list of `ints`. Defaults to shape `[1024, 768, 1]`. 
Defines output shape of returned image in '[X,Y, Channels]' 1-grayscale, 3 color; channels will be updated to 3 if number_colors > 256 diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 5d86373a232d55cd281d06cfc0606f4224d8f669..95fba59e3c96ae3c69e0b154740785b0d2bcb3c9 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -16,6 +16,7 @@ py_test( "//tensorflow/contrib/kfac/python/ops:utils", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", @@ -33,6 +34,7 @@ py_test( "//tensorflow/contrib/kfac/python/ops:fisher_factors", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", @@ -68,6 +70,7 @@ py_test( srcs = ["layer_collection_test.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", "//tensorflow/contrib/kfac/python/ops:fisher_factors", "//tensorflow/contrib/kfac/python/ops:layer_collection", "//tensorflow/python:array_ops", @@ -75,6 +78,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:random_seed", "//tensorflow/python:variable_scope", @@ -88,7 +92,6 @@ py_test( deps = [ "//tensorflow/contrib/kfac/python/ops:kfac_optimizer", "//tensorflow/contrib/kfac/python/ops:layer_collection", - "//tensorflow/contrib/kfac/python/ops:loss_functions", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", @@ -139,6 +142,7 @@ py_test( "//tensorflow/python:client_testlib", 
"//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", + "//tensorflow/python:random_ops", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index b52a7b52a7efd4292ad514c5a744c4da07082142..9b28c45c7263208d21b1514ae5f05b7e81e315a3 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.kfac.python.ops import estimator from tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -33,6 +34,30 @@ from tensorflow.python.platform import test _ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] +class DeviceContextGeneratorTest(test.TestCase): + + def testNoDevice(self): + device_context_generator = estimator._DeviceContextGenerator(None) + with ops.device("/device:CPU:0"): # This is what will be used + with device_context_generator(): # Does nothing + a = constant_op.constant([2.0], name="a") + self.assertEqual("/device:CPU:0", a.op.device) + + def testTwoDevices(self): + device_context_generator = estimator._DeviceContextGenerator( + ["/device:GPU:0", "/device:GPU:1"]) + with ops.device("/device:CPU:0"): # Will be over-ridden by the inner scopes + with device_context_generator(): + a = constant_op.constant([2.0], name="a") + with device_context_generator(): + b = constant_op.constant([2.0], name="b") + with device_context_generator(): + c = constant_op.constant([2.0], name="c") + self.assertEqual("/device:GPU:0", a.op.device) + self.assertEqual("/device:GPU:1", b.op.device) + 
self.assertEqual("/device:GPU:0", c.op.device) + + class EstimatorTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index 5f2b5c6cace9cd18f4cc5590ff55a9b39680a381..2d9b28185ce0db32d5cd7d84737fdf96e2c98851 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -40,6 +40,21 @@ def _make_psd(dim): return array_ops.constant(mat) +class UtilsTest(test.TestCase): + + def testComputePiTracenorm(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + left_factor = array_ops.diag([1., 2., 0., 1.]) + right_factor = array_ops.ones([2., 2.]) + + # pi is the sqrt of the left trace norm divided by the right trace norm + pi = fb._compute_pi_tracenorm(left_factor, right_factor) + + pi_val = sess.run(pi) + self.assertEqual(1., pi_val) + + class FullFBTest(test.TestCase): def testFullFBInitSingleTensor(self): @@ -301,8 +316,7 @@ class FullyConnectedDiagonalFB(test.TestCase): multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( - self.runFisherBlockOps(self.w, - np.split(self.inputs, 2), + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), np.split(self.outputs, 2), np.split(self.output_grads, 2))) @@ -584,8 +598,7 @@ class ConvDiagonalFBTest(test.TestCase): multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( - self.runFisherBlockOps(self.w, - np.split(self.inputs, 2), + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), np.split(self.outputs, 2), np.split(self.output_grads, 2))) @@ -608,8 +621,9 @@ class 
ConvDiagonalFBTest(test.TestCase): self.kernel_size, self.kernel_size, self.input_channels + 1, self.output_channels ]) - expected_result = (expected_result[:, :, 0:-1, :], np.reshape( - expected_result[:, :, -1, :], [self.output_channels])) + expected_result = (expected_result[:, :, 0:-1, :], + np.reshape(expected_result[:, :, -1, :], + [self.output_channels])) self.assertEqual(len(result), 2) self.assertAllClose(expected_result[0], result[0]) @@ -692,8 +706,8 @@ class ConvKFCBasicFBTest(test.TestCase): sess.run(block._input_factor.make_inverse_update_ops()) sess.run(block._output_factor.make_inverse_update_ops()) - vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), np.arange( - 2, 4).reshape(2, 1).astype(np.float32)) + vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), + np.arange(2, 4).reshape(2, 1).astype(np.float32)) output = block.multiply_inverse((array_ops.constant(vector[0]), array_ops.constant(vector[1]))) @@ -776,11 +790,50 @@ class ConvKFCBasicFBTest(test.TestCase): self.assertAllClose(output_flat, explicit) +class FullyConnectedSeriesFBTest(test.TestCase): + + def testFullyConnectedSeriesFBInit(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([1., 2.]) + outputs = array_ops.constant([3., 4.]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), inputs=[inputs], outputs=[outputs]) + self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + + def testInstantiateFactorsHasBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + inputs=[inputs], + outputs=[outputs], + has_bias=True) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def testInstantiateFactorsNoBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = 
array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + inputs=[inputs], + outputs=[outputs], + has_bias=False) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def as_tensors(tensor_or_tuple): """Converts a potentially nested tuple of np.array to Tensors.""" if isinstance(tensor_or_tuple, (tuple, list)): return tuple(as_tensors(t) for t in tensor_or_tuple) return ops.convert_to_tensor(tensor_or_tuple) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index fbb3d219139a4bc05253841a89e73645ef37dddd..70e56db055078bd4399b03e4d4a877e34249cc5e 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -22,6 +22,7 @@ import numpy as np import numpy.random as npr from tensorflow.contrib.kfac.python.ops import fisher_factors as ff +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as tf_ops from tensorflow.python.framework import random_seed @@ -32,6 +33,25 @@ from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import test +class MaybeColocateTest(test.TestCase): + + def testFalse(self): + with tf_ops.Graph().as_default(): + a = constant_op.constant([2.0], name='a') + with ff._maybe_colocate_with(a, False): + b = constant_op.constant(3.0, name='b') + self.assertEqual([b'loc:@a'], a.op.colocation_groups()) + self.assertEqual([b'loc:@b'], b.op.colocation_groups()) + + def testTrue(self): + with tf_ops.Graph().as_default(): + a = constant_op.constant([2.0], name='a') + with ff._maybe_colocate_with(a, True): + b = constant_op.constant(3.0, name='b') + self.assertEqual([b'loc:@a'], 
a.op.colocation_groups()) + self.assertEqual([b'loc:@a'], b.op.colocation_groups()) + + class FisherFactorTestingDummy(ff.FisherFactor): """Dummy class to test the non-abstract methods on ff.FisherFactor.""" @@ -47,12 +67,19 @@ class FisherFactorTestingDummy(ff.FisherFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return dtypes.float32 + def _compute_new_cov(self): raise NotImplementedError def instantiate_covariance(self): pass + def make_inverse_update_ops(self): + return [] + class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. @@ -74,6 +101,10 @@ class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return dtypes.float32 + def _compute_new_cov(self): raise NotImplementedError @@ -101,7 +132,7 @@ class NumericalUtilsTest(test.TestCase): normalizer = 10. x = npr.randn(100, 3) - cov = ff._compute_cov(array_ops.constant(x), normalizer) + cov = ff._compute_cov(array_ops.constant(x), normalizer=normalizer) np_cov = np.dot(x.T, x) / normalizer self.assertAllClose(sess.run(cov), np_cov) @@ -247,13 +278,13 @@ class InverseProvidingFactorTest(test.TestCase): for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): factor.register_damped_inverse(1. / i) ops = factor.make_inverse_update_ops() - self.assertEqual(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD, len(ops)) + self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) new_invs = [] + sess.run(ops) for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): # The inverse op will assign the damped inverse of cov to the inv var. - sess.run(ops[i - 1]) new_invs.append(sess.run(factor._inverses_by_damping[1. / i])) # We want to see that the new invs are all different from each other. 
for i in range(len(new_invs)): @@ -311,6 +342,16 @@ class FullFactorTest(test.TestCase): factor = ff.FullFactor((tensor,), 32) self.assertEqual([6, 6], factor.get_cov().get_shape().as_list()) + def testFullFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.FullFactor((tensor,), 32) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 6], cov.get_shape().as_list()) + def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -331,6 +372,16 @@ class NaiveDiagonalFactorTest(test.TestCase): factor = ff.NaiveDiagonalFactor((tensor,), 32) self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) + def testNaiveDiagonalFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.NaiveDiagonalFactor((tensor,), 32) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 1], cov.get_shape().as_list()) + def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -344,18 +395,25 @@ class NaiveDiagonalFactorTest(test.TestCase): class FullyConnectedKroneckerFactorTest(test.TestCase): - def _testFullyConnectedKroneckerFactorInit(self, has_bias, final_shape): + def _testFullyConnectedKroneckerFactorInit(self, + has_bias, + final_shape, + dtype=dtypes.float32_ref): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias) - self.assertEqual(final_shape, factor.get_cov().get_shape().as_list()) 
+ cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual(final_shape, cov.get_shape().as_list()) def testFullyConnectedKroneckerFactorInitNoBias(self): - self._testFullyConnectedKroneckerFactorInit(False, [3, 3]) + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype) def testFullyConnectedKroneckerFactorInitWithBias(self): - self._testFullyConnectedKroneckerFactorInit(True, [4, 4]) + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype) def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: @@ -398,6 +456,18 @@ class ConvInputKroneckerFactorTest(test.TestCase): self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], factor.get_cov().get_shape().as_list()) + def testConvInputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.ConvInputKroneckerFactor( + tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], + cov.get_shape().as_list()) + def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -433,6 +503,16 @@ class ConvOutputKroneckerFactorTest(test.TestCase): factor = ff.ConvOutputKroneckerFactor((tensor,)) self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) + def testConvOutputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c') + factor = ff.ConvOutputKroneckerFactor((tensor,)) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + 
self.assertEqual([5, 5], cov.get_shape().as_list()) + def testConvOutputKroneckerFactorInitNotEnoughDims(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) @@ -451,5 +531,49 @@ class ConvOutputKroneckerFactorTest(test.TestCase): self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov) +class FullyConnectedMultiKFTest(test.TestCase): + + def testFullyConnectedMultiKFInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + self.assertEqual([3, 3], factor.get_cov().get_shape().as_list()) + + def testFullyConnectedMultiKFInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([3, 3], cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOpWithBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) + + def testMakeCovarianceUpdateOpNoBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,)) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = 
sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py index 524e8338fde9bb20586b15c33ba2055e852baa01..b8ccbeadd0a9d69edb41fef50e3edb090457adf2 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.kfac.python.ops import fisher_blocks from tensorflow.contrib.kfac.python.ops import fisher_factors from tensorflow.contrib.kfac.python.ops import layer_collection from tensorflow.python.framework import dtypes @@ -25,6 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -105,8 +107,10 @@ class LayerCollectionTest(test.TestCase): array_ops.constant(4), [1, 1, 1, 1], 'SAME', array_ops.ones((1, 1, 1, 1)), array_ops.constant(3)) lc.register_conv2d( - array_ops.constant(4), [1, 1, 1, 1], 'SAME', - array_ops.ones((1, 1, 1, 1)), array_ops.constant(3), + array_ops.constant(4), [1, 1, 1, 1], + 'SAME', + array_ops.ones((1, 1, 1, 1)), + array_ops.constant(3), approx=layer_collection.APPROX_DIAGONAL_NAME) lc.register_generic( array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) @@ -122,10 +126,11 @@ class LayerCollectionTest(test.TestCase): random_seed.set_random_seed(200) lc = layer_collection.LayerCollection() key = array_ops.constant(1) - 
lc.register_fully_connected(key, - array_ops.constant(2), array_ops.constant(3)) - with self.assertRaises(ValueError): + lc.register_fully_connected(key, array_ops.constant(2), + array_ops.constant(3)) + with self.assertRaises(ValueError) as cm: lc.register_generic(key, 16) + self.assertIn('already in LayerCollection', str(cm.exception)) def testRegisterSingleParamNotRegistered(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -140,16 +145,18 @@ class LayerCollectionTest(test.TestCase): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {x: '1'} - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as cm: lc.register_block(x, 'foo') + self.assertIn('already in LayerCollection', str(cm.exception)) def testRegisterSingleParamRegisteredInTuple(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, y): '1'} - lc.register_block(x, 'foo') - self.assertEqual(set(['1']), set(lc.get_blocks())) + with self.assertRaises(ValueError) as cm: + lc.register_block(x, 'foo') + self.assertIn('was already registered', str(cm.exception)) def testRegisterTupleParamNotRegistered(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -169,8 +176,9 @@ class LayerCollectionTest(test.TestCase): lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, y): '1'} - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as cm: lc.register_block((x, y), 'foo') + self.assertIn('already in LayerCollection', str(cm.exception)) def testRegisterTupleParamRegisteredInSuperset(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -179,8 +187,9 @@ class LayerCollectionTest(test.TestCase): lc = layer_collection.LayerCollection() 
lc.fisher_blocks = {(x, y, z): '1'} - lc.register_block((x, y), 'foo') - self.assertEqual(set(['1']), set(lc.get_blocks())) + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), 'foo') + self.assertIn('was already registered', str(cm.exception)) def testRegisterTupleParamSomeRegistered(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -189,10 +198,9 @@ class LayerCollectionTest(test.TestCase): lc = layer_collection.LayerCollection() lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')} - lc.register_block((x, y), MockFisherBlock('foo')) - self.assertEqual( - set([MockFisherBlock('2'), MockFisherBlock('foo')]), - set(lc.get_blocks())) + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), MockFisherBlock('foo')) + self.assertIn('was already registered', str(cm.exception)) def testRegisterTupleVarSomeRegisteredInOtherTuples(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -202,8 +210,9 @@ class LayerCollectionTest(test.TestCase): lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, z): '1', (z, w): '2'} - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as cm: lc.register_block((x, y), 'foo') + self.assertIn('was already registered', str(cm.exception)) def testRegisterCategoricalPredictiveDistribution(self): with ops.Graph().as_default(), self.test_session() as sess: @@ -423,6 +432,23 @@ class LayerCollectionTest(test.TestCase): self.ensureLayerReuseWorks(register_fn) + def testReuseWithInvalidRegistration(self): + """Invalid registrations shouldn't overwrite existing blocks.""" + with ops.Graph().as_default(): + inputs = array_ops.ones([2, 5, 5, 10]) + outputs = array_ops.zeros([2, 5, 5, 3]) + w = variable_scope.get_variable('w', [1, 1, 10, 3]) + b = variable_scope.get_variable('b', [3]) + lc = layer_collection.LayerCollection() + lc.register_fully_connected(w, inputs, outputs) + 
self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + with self.assertRaises(KeyError): + lc.register_fully_connected((w, b), inputs, outputs, reuse=True) + self.assertNotIn((w, b), lc.fisher_blocks) + self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + lc.register_fully_connected(w, inputs, outputs, reuse=True) + self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 2) + def testMakeOrGetFactor(self): with ops.Graph().as_default(): random_seed.set_random_seed(200) @@ -464,6 +490,66 @@ class LayerCollectionTest(test.TestCase): use_count_map = lc.get_use_count_map() self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map) + def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self): + x = variable_scope.get_variable('x', shape=()) + y = variable_scope.get_variable('y', shape=()) + z = variable_scope.get_variable('z', shape=()) + lc = layer_collection.LayerCollection() + lc.define_linked_parameters((x, y)) + + with self.assertRaises(ValueError): + lc.define_linked_parameters((x, z)) + + def testIdentifySubsetPreviouslyRegisteredTensor(self): + x = variable_scope.get_variable('x', shape=()) + y = variable_scope.get_variable('y', shape=()) + lc = layer_collection.LayerCollection() + lc.define_linked_parameters((x, y)) + + with self.assertRaises(ValueError): + lc.define_linked_parameters(x) + + def testSpecifyApproximation(self): + w_0 = variable_scope.get_variable('w_0', [10, 10]) + w_1 = variable_scope.get_variable('w_1', [10, 10]) + + b_0 = variable_scope.get_variable('b_0', [10]) + b_1 = variable_scope.get_variable('b_1', [10]) + + x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) + x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) + + pre_bias_0 = math_ops.matmul(x_0, w_0) + pre_bias_1 = math_ops.matmul(x_1, w_1) + + # Build the fully connected layers in the graph. 
+ pre_bias_0 + b_0 # pylint: disable=pointless-statement + pre_bias_1 + b_1 # pylint: disable=pointless-statement + + lc = layer_collection.LayerCollection() + lc.define_linked_parameters( + w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME) + lc.define_linked_parameters( + w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME) + lc.define_linked_parameters( + b_0, approximation=layer_collection.APPROX_FULL_NAME) + lc.define_linked_parameters( + b_1, approximation=layer_collection.APPROX_FULL_NAME) + + lc.register_fully_connected(w_0, x_0, pre_bias_0) + lc.register_fully_connected( + w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME) + self.assertIsInstance(lc.fisher_blocks[w_0], + fisher_blocks.FullyConnectedDiagonalFB) + self.assertIsInstance(lc.fisher_blocks[w_1], + fisher_blocks.FullyConnectedKFACBasicFB) + + lc.register_generic(b_0, batch_size=1) + lc.register_generic( + b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME) + self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB) + self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py index 87339cb059802ec8944d5d1ae4557ee34550cd60..39ce3e9337157c8206107bc40c489e44019743ab 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -24,6 +24,7 @@ from tensorflow.contrib.kfac.python.ops import loss_functions from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -96,6 +97,22 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): # difficult to say if the output is correct or 
not... neg_log_prob = sess.run(neg_log_prob) + def testMultiMinibatchRegistration(self): + """Ensure this loss function supports registering multiple minibatches.""" + with ops.Graph().as_default(): + tower_logits = [] + loss = None + num_towers = 5 + for _ in range(num_towers): + logits = random_ops.random_uniform(shape=[2, 3]) + tower_logits.append(logits) + if loss is None: + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) + else: + loss.register_additional_minibatch(logits) + self.assertListEqual(loss.input_minibatches, tower_logits) + self.assertEqual(loss.num_registered_minibatches, num_towers) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py index 55fe38e3e9aab2dbd70a45cdc8fa0c208b036db0..d255a6e7160386d8eb6fca00765eea8a318f4eaa 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -222,18 +222,6 @@ class UtilsTest(test.TestCase): self.assertAllClose(b, np.array([4., 5.])) self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]])) - def testComputePi(self): - with ops.Graph().as_default(), self.test_session() as sess: - random_seed.set_random_seed(200) - left_factor = array_ops.diag([1., 2., 0., 1.]) - right_factor = array_ops.ones([2., 2.]) - - # pi is the sqrt of the left trace norm divided by the right trace norm - pi = utils.compute_pi(left_factor, right_factor) - - pi_val = sess.run(pi) - self.assertEqual(1., pi_val) - def testPosDefInvCholesky(self): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index de4b8920b849dbf2117657de6e7c26f94f4d0363..3d731c7bc206d6f168e9b8f29b66bf4f1dbe8542 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ 
-38,6 +38,7 @@ py_library( ":utils", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:special_math_ops", @@ -171,6 +172,7 @@ py_library( deps = [ ":utils", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", "//tensorflow/python:util", "//third_party/py/numpy", diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index ce4e776324bbde1b8f214d89daa876032d8a21ff..5e1680967c184bf19f2a2578219db07a48264dc9 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -18,16 +18,53 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math +import contextlib +import itertools import numpy as np from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.util import nest +class _DeviceContextGenerator(object): + """Class for generating device contexts in a round-robin fashion.""" + + def __init__(self, devices): + """Creates a _DeviceContextGenerator object. + + Example usage: + + ```python + dcg = _DeviceContextGenerator(['/gpu:0', 'gpu:1']) + with dcg(): + # All operations in this context will be placed on GPU 0 + ... + with dcg(): + # All operations in this context will be placed on GPU 1 + ... + ``` + + Args: + devices: An iterable of device strings (or None). Successive calls to + __call__ will give contexts which place devices on these devices in + a round-robin fashion. 
+ """ + self._cycle = None if devices is None else itertools.cycle(devices) + + @contextlib.contextmanager + def __call__(self): + """Returns a context manager specifying the default device.""" + if self._cycle is None: + yield + else: + with tf_ops.device(next(self._cycle)): + yield + + class FisherEstimator(object): """Fisher estimator class supporting various approximations of the Fisher.""" @@ -36,7 +73,10 @@ class FisherEstimator(object): cov_ema_decay, damping, layer_collection, - estimation_mode="gradients"): + estimation_mode="gradients", + colocate_gradients_with_ops=False, + cov_devices=None, + inv_devices=None): """Create a FisherEstimator object. Args: @@ -54,7 +94,7 @@ class FisherEstimator(object): blocks, kronecker factors, and losses associated with the graph. estimation_mode: The type of estimator to use for the Fishers. Can be - 'gradients', 'empirical', 'curvature_propagation', or 'exact'. + 'gradients', 'empirical', 'curvature_prop', or 'exact'. (Default: 'gradients'). 'gradients' is the basic estimation approach from the original K-FAC paper. 'empirical' computes the 'empirical' Fisher information matrix (which uses the data's distribution for the @@ -69,6 +109,14 @@ class FisherEstimator(object): for each coordinate of the output instead of using 1/-1 vectors. It is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking. + colocate_gradients_with_ops: Whether we should request gradients be + colocated with their respective ops. + cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. 
Raises: ValueError: If no losses have been registered with layer_collection. @@ -79,13 +127,19 @@ class FisherEstimator(object): self._estimation_mode = estimation_mode self._layers = layer_collection self._layers.create_subgraph() - self._check_registration(variables) + self._layers.check_registration(variables) self._gradient_fns = { "gradients": self._get_grads_lists_gradients, "empirical": self._get_grads_lists_empirical, "curvature_prop": self._get_grads_lists_curvature_prop, "exact": self._get_grads_lists_exact } + self._colocate_gradients_with_ops = colocate_gradients_with_ops + self._cov_device_context_generator = _DeviceContextGenerator(cov_devices) + if inv_devices == cov_devices: + self._inv_device_context_generator = self._cov_device_context_generator + else: + self._inv_device_context_generator = _DeviceContextGenerator(inv_devices) setup = self._setup(cov_ema_decay) self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup @@ -148,49 +202,6 @@ class FisherEstimator(object): return self._apply_transformation(vecs_and_vars, lambda fb, vec: fb.multiply(vec)) - def _check_registration(self, variables): - """Checks that all variable uses have been registered properly. - - Args: - variables: List of variables. - - Raises: - ValueError: If any registered variables are not included in the list. - ValueError: If any variable in the list is not registered. - ValueError: If any variable in the list is registered with the wrong - number of "uses" in the subgraph recorded (vs the number of times that - variable is actually used in the subgraph). - """ - # Note that overlapping parameters (i.e. those that share variables) will - # be caught by layer_collection.LayerParametersDict during registration. 
- - reg_use_map = self._layers.get_use_count_map() - - error_messages = [] - - for var in variables: - total_uses = self._layers.subgraph.variable_uses(var) - reg_uses = reg_use_map[var] - - if reg_uses == 0: - error_messages.append("Variable {} not registered.".format(var)) - elif (not math.isinf(reg_uses)) and reg_uses != total_uses: - error_messages.append( - "Variable {} registered with wrong number of uses ({} " - "registrations vs {} uses).".format(var, reg_uses, total_uses)) - - num_get_vars = len(reg_use_map) - - if num_get_vars > len(variables): - error_messages.append("{} registered variables were not included in list." - .format(num_get_vars - len(variables))) - - if error_messages: - error_messages = [ - "Found the following errors with variable registration:" - ] + error_messages - raise ValueError("\n\t".join(error_messages)) - def _setup(self, cov_ema_decay): """Sets up the various operations. @@ -219,8 +230,13 @@ class FisherEstimator(object): raise ValueError("Unrecognized value {} for estimation_mode.".format( self._estimation_mode)) + # TODO(b/68033310): This loop round-robins the "concat" operations which + # gather the inputs for the cov_updates. In future, we might do these + # computations locally then communicate the results, which would require a + # modification to this code. 
for grads_list, fb in zip(grads_lists, fisher_blocks_list): - fb.instantiate_factors(grads_list, self.damping) + with self._cov_device_context_generator(): + fb.instantiate_factors(grads_list, self.damping) cov_updates = [ factor.make_covariance_update_op(cov_ema_decay) @@ -233,18 +249,23 @@ class FisherEstimator(object): def _get_all_inverse_update_ops(self): for factor in self._layers.get_factors(): - for op in factor.make_inverse_update_ops(): - yield op + with self._inv_device_context_generator(): + for op in factor.make_inverse_update_ops(): + yield op def _get_grads_lists_gradients(self, tensors): - grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(), - nest.flatten(tensors)) + grads_flat = gradients_impl.gradients( + self._layers.total_sampled_loss(), + nest.flatten(tensors), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) def _get_grads_lists_empirical(self, tensors): - grads_flat = gradients_impl.gradients(self._layers.total_loss(), - nest.flatten(tensors)) + grads_flat = gradients_impl.gradients( + self._layers.total_loss(), + nest.flatten(tensors), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) @@ -262,11 +283,13 @@ class FisherEstimator(object): grads_flat = gradients_impl.gradients( nest.flatten(loss_inputs), nest.flatten(tensors), - grad_ys=nest.flatten(transformed_random_signs)) + grad_ys=nest.flatten(transformed_random_signs), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) def _get_grads_lists_exact(self, tensors): + """No docstring required.""" # Loop over all coordinates of all losses. 
grads_all = [] for loss in self._layers.losses: @@ -274,6 +297,9 @@ class FisherEstimator(object): transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( index) grads_flat = gradients_impl.gradients( - loss.inputs, nest.flatten(tensors), grad_ys=transformed_one_hot) + loss.inputs, + nest.flatten(tensors), + grad_ys=transformed_one_hot, + colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) return zip(*grads_all) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index a6fdf01fe7d06a1719aef1f3c329a5587add651a..1ccb9e040f2bb6bcfd217886918abd40e3cc1cfb 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -38,6 +38,7 @@ from __future__ import division from __future__ import print_function import abc +import enum # pylint: disable=g-bad-import-order import six @@ -52,14 +53,54 @@ from tensorflow.python.ops import math_ops # damping /= num_replications ** NORMALIZE_DAMPING_POWER NORMALIZE_DAMPING_POWER = 1.0 +# Methods for adjusting damping for FisherBlocks. See +# _compute_pi_adjusted_damping() for details. +PI_OFF_NAME = "off" +PI_TRACENORM_NAME = "tracenorm" +PI_TYPE = PI_TRACENORM_NAME -def set_global_constants(normalize_damping_power=None): + +def set_global_constants(normalize_damping_power=None, pi_type=None): """Sets various global constants used by the classes in this module.""" global NORMALIZE_DAMPING_POWER + global PI_TYPE if normalize_damping_power is not None: NORMALIZE_DAMPING_POWER = normalize_damping_power + if pi_type is not None: + PI_TYPE = pi_type + + +def _compute_pi_tracenorm(left_cov, right_cov): + """Computes the scalar constant pi for Tikhonov regularization/damping. + + pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) + See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. 
+ + Args: + left_cov: The left Kronecker factor "covariance". + right_cov: The right Kronecker factor "covariance". + + Returns: + The computed scalar constant pi for these Kronecker Factors (as a Tensor). + """ + # Instead of dividing by the dim of the norm, we multiply by the dim of the + # other norm. This works out the same in the ratio. + left_norm = math_ops.trace(left_cov) * right_cov.shape.as_list()[0] + right_norm = math_ops.trace(right_cov) * left_cov.shape.as_list()[0] + return math_ops.sqrt(left_norm / right_norm) + + +def _compute_pi_adjusted_damping(left_cov, right_cov, damping): + + if PI_TYPE == PI_TRACENORM_NAME: + pi = _compute_pi_tracenorm(left_cov, right_cov) + return (damping * pi, damping / pi) + + elif PI_TYPE == PI_OFF_NAME: + return (damping, damping) + @six.add_metaclass(abc.ABCMeta) class FisherBlock(object): @@ -153,7 +194,7 @@ class FullFB(FisherBlock): self._factor.register_damped_inverse(damping) def multiply_inverse(self, vector): - inverse = self._factor.get_inverse(self._damping) + inverse = self._factor.get_damped_inverse(self._damping) out_flat = math_ops.matmul(inverse, utils.tensors_to_column(vector)) return utils.column_to_tensors(vector, out_flat) @@ -411,7 +452,7 @@ class ConvDiagonalFB(FisherBlock): (self._strides[1] * self._strides[2])) if NORMALIZE_DAMPING_POWER: - damping /= self._num_locations ** NORMALIZE_DAMPING_POWER + damping /= self._num_locations**NORMALIZE_DAMPING_POWER self._damping = damping self._factor = self._layer_collection.make_or_get_factor( @@ -465,11 +506,10 @@ class KroneckerProductFB(FisherBlock): Args: damping: The base damping factor (float or Tensor) for the damped inverse. 
""" - pi = utils.compute_pi(self._input_factor.get_cov(), - self._output_factor.get_cov()) - - self._input_damping = math_ops.sqrt(damping) * pi - self._output_damping = math_ops.sqrt(damping) / pi + self._input_damping, self._output_damping = _compute_pi_adjusted_damping( + self._input_factor.get_cov(), + self._output_factor.get_cov(), + damping**0.5) self._input_factor.register_damped_inverse(self._input_damping) self._output_factor.register_damped_inverse(self._output_damping) @@ -487,8 +527,9 @@ class KroneckerProductFB(FisherBlock): return 1.0 def multiply_inverse(self, vector): - left_factor_inv = self._input_factor.get_inverse(self._input_damping) - right_factor_inv = self._output_factor.get_inverse(self._output_damping) + left_factor_inv = self._input_factor.get_damped_inverse(self._input_damping) + right_factor_inv = self._output_factor.get_damped_inverse( + self._output_damping) reshaped_vector = utils.layer_params_to_mat2d(vector) reshaped_out = math_ops.matmul(left_factor_inv, math_ops.matmul(reshaped_vector, @@ -720,3 +761,260 @@ def _concat_along_batch_dim(tensor_list): def _num_conv_locations(input_shape, strides): """Returns the number of locations a Conv kernel is applied to.""" return input_shape[1] * input_shape[2] // (strides[1] * strides[2]) + + +class FullyConnectedMultiIndepFB(KroneckerProductFB): + """FisherBlock for fully-connected layers that share parameters. + """ + + def __init__(self, layer_collection, inputs, outputs, has_bias=False): + """Creates a FullyConnectedMultiIndepFB block. + + Args: + layer_collection: LayerCollection instance. + inputs: list or tuple of Tensors. Each Tensor has shape [batch_size, + inputs_size]. + outputs: list or tuple of Tensors. Each Tensor has shape [batch_size, + outputs_size]. + has_bias: bool. If True, estimates Fisher with respect to a bias + parameter as well as the layer's parameters. 
+ """ + + assert len(inputs) == len(outputs) + # We need to make sure inputs and outputs are tuples and not lists so that + # they get hashed by layer_collection.make_or_get_factor properly. + self._inputs = tuple(inputs) + self._outputs = tuple(outputs) + self._has_bias = has_bias + self._num_uses = len(inputs) + + super(FullyConnectedMultiIndepFB, self).__init__(layer_collection) + + @property + def num_registered_minibatches(self): + # TODO(b/69411207): Add support for registering additional minibatches. + return 1 + + def instantiate_factors(self, grads_list, damping): + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, + ((self._inputs,), self._has_bias)) + + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list,)) + + if NORMALIZE_DAMPING_POWER: + damping /= self._num_uses**NORMALIZE_DAMPING_POWER + + self._register_damped_input_and_output_inverses(damping) + + @property + def _renorm_coeff(self): + return self._num_uses + + def tensors_to_compute_grads(self): + return self._outputs + + def num_inputs(self): + return len(self._inputs) + + +class SeriesFBApproximation(enum.IntEnum): + """See FullyConnectedSeriesFB.__init__ for description and usage.""" + option1 = 1 + option2 = 2 + + +class FullyConnectedSeriesFB(FisherBlock): + """FisherBlock for fully-connected layers that share parameters across time. + + See the following preprint for details: + https://openreview.net/pdf?id=HyMTkQZAb + + See the end of the appendix of the paper for a pseudo-code of the + algorithm being implemented by multiply_inverse here. Note that we are + using pre-computed versions of certain matrix-matrix products to speed + things up. This is explicitly explained wherever it is done. + """ + + def __init__(self, + layer_collection, + inputs, + outputs, + has_bias=False, + option=SeriesFBApproximation.option2): + """Constructs a new `FullyConnectedSeriesFB`. 
+ + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + inputs: List of tensors of shape [batch_size, input_size]. + Inputs to the layer. + outputs: List of tensors of shape [batch_size, input_size]. + Outputs of the layer (before activations). + has_bias: Whether the layer includes a bias parameter. + option: A `SeriesFBApproximation` specifying the simplifying assumption + to be used in this block. `option1` approximates the cross-covariance + over time as a symmetric matrix, while `option2` makes + the assumption that training sequences are infinitely long. See section + 3.5 of the paper for more details. + """ + + assert len(inputs) == len(outputs) + # We need to make sure inputs and outputs are tuples and not lists so that + # they get hashed by layer_collection.make_or_get_factor properly. + self._inputs = tuple(inputs) + self._outputs = tuple(outputs) + self._has_bias = has_bias + self._num_timesteps = len(inputs) + self._option = option + + super(FullyConnectedSeriesFB, self).__init__(layer_collection) + + @property + def num_registered_minibatches(self): + # TODO(b/69411207): Add support for registering additional minibatches. 
+ return 1 + + def instantiate_factors(self, grads_list, damping): + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, ((self._inputs,), self._has_bias)) + + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list,)) + + if NORMALIZE_DAMPING_POWER: + damping /= self._num_timesteps**NORMALIZE_DAMPING_POWER + + self._damping_input, self._damping_output = _compute_pi_adjusted_damping( + self._input_factor.get_cov(), + self._output_factor.get_cov(), + damping**0.5) + + if self._option == SeriesFBApproximation.option1: + self._input_factor.register_option1quants(self._damping_input) + self._output_factor.register_option1quants(self._damping_output) + elif self._option == SeriesFBApproximation.option2: + self._input_factor.register_option2quants(self._damping_input) + self._output_factor.register_option2quants(self._damping_output) + else: + raise ValueError( + "Unrecognized FullyConnectedSeriesFB approximation: {}".format( + self._option)) + + def multiply_inverse(self, vector): + # pylint: disable=invalid-name + + Z = utils.layer_params_to_mat2d(vector) + + # Derivations were done for "batch_dim==1" case so we need to convert to + # that orientation: + Z = array_ops.transpose(Z) + + if self._option == SeriesFBApproximation.option1: + + # Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G. + L_A, psi_A = self._input_factor.get_option1quants(self._damping_input) + L_G, psi_G = self._output_factor.get_option1quants(self._damping_output) + + def gamma(x): + # We are assuming that each case has the same number of time-steps. + # If this stops being the case one shouldn't simply replace this T + # with its average value. Instead, one needs to go back to the + # definition of the gamma function from the paper. 
+ T = self._num_timesteps + return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T)) + + # Y = gamma( psi_G*psi_A^T ) (computed element-wise) + # Even though Y is Z-independent we are recomputing it from the psi's + # each since Y depends on both A and G quantities, and it is relatively + # cheap to compute. + Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A) + + # Z = L_G^T * Z * L_A + # This is equivalent to the following computation from the original + # pseudo-code: + # Z = G0^(-1/2) * Z * A0^(-1/2) + # Z = U_G^T * Z * U_A + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True) + + # Z = Z .* Y + Z *= Y + + # Z = L_G * Z * L_A^T + # This is equivalent to the following computation from the original + # pseudo-code: + # Z = U_G * Z * U_A^T + # Z = G0^(-1/2) * Z * A0^(-1/2) + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True)) + + elif self._option == SeriesFBApproximation.option2: + + # Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1), + # and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G. + P_A, K_A, mu_A = self._input_factor.get_option2quants(self._damping_input) + P_G, K_G, mu_G = self._output_factor.get_option2quants( + self._damping_output) + + # Our approach differs superficially from the pseudo-code in the paper + # in order to reduce the total number of matrix-matrix multiplies. 
+ # In particular, the first three computations in the pseudo code are + # Z = G0^(-1/2) * Z * A0^(-1/2) + # Z = Z - hPsi_G^T * Z * hPsi_A + # Z = E_G^T * Z * E_A + # Noting that hPsi = C0^(-1/2) * C1 * C0^(-1/2), so that + # C0^(-1/2) * hPsi = C0^(-1) * C1 * C0^(-1/2) = P^T * C0^(-1/2) + # the entire computation can be written as + # Z = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) + # - hPsi_G^T * G0^(-1/2) * Z * A0^(-1/2) * hPsi_A) * E_A + # = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) + # - G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2)) * E_A + # = E_G^T * G0^(-1/2) * Z * A0^(-1/2) * E_A + # - E_G^T* G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2) * E_A + # = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A + # This final expression is computed by the following two lines: + # Z = Z - P_G * Z * P_A^T + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True)) + # Z = K_G^T * Z * K_A + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True) + + # Z = Z ./ (1*1^T - mu_G*mu_A^T) + # Be careful with the outer product. We don't want to accidentally + # make it an inner-product instead. + tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A + # Prevent some numerical issues by setting any 0.0 eigs to 1.0 + tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype) + Z /= tmp + + # We now perform the transpose/reverse version of the operations + # derived above, whose derivation from the original pseudo-code is + # analgous. + # Z = K_G * Z * K_A^T + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True)) + + # Z = Z - P_G^T * Z * P_A + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True) + + # Z = normalize (1/E[T]) * Z + # Note that this normalization is done because we compute the statistics + # by averaging, not summing, over time. (And the gradient is presumably + # summed over time, not averaged, and thus their scales are different.) 
+ Z /= math_ops.cast(self._num_timesteps, Z.dtype) + + # Convert back to the "batch_dim==0" orientation. + Z = array_ops.transpose(Z) + + return utils.mat2d_to_layer_params(vector, Z) + + # pylint: enable=invalid-name + + def multiply(self, vector): + raise NotImplementedError + + def tensors_to_compute_grads(self): + return self._outputs + + def num_inputs(self): + return len(self._inputs) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index 4e36813369e69de1d6f13ddb00566bda912244f6..5a6d1a93ff217c3922f45a047b4d548086ac5258 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import abc +import contextlib import numpy as np import six @@ -26,6 +27,8 @@ import six from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import special_math_ops @@ -50,7 +53,22 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 EIGENVALUE_CLIPPING_THRESHOLD = 0.0 -def set_global_constants(init_covariances_at_zero=None, zero_debias=None, +@contextlib.contextmanager +def _maybe_colocate_with(op, colocate_cov_ops_with_inputs): + """Context to colocate with `op` if `colocate_cov_ops_with_inputs`.""" + if colocate_cov_ops_with_inputs: + if isinstance(op, (list, tuple)): + with tf_ops.colocate_with(op[0]): + yield + else: + with tf_ops.colocate_with(op): + yield + else: + yield + + +def set_global_constants(init_covariances_at_zero=None, + zero_debias=None, eigenvalue_decomposition_threshold=None, eigenvalue_clipping_threshold=None): """Sets various global constants used 
by the classes in this module.""" @@ -85,7 +103,7 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di return array_ops.ones(shape, dtype) -def _compute_cov(tensor, normalizer=None): +def _compute_cov(tensor, tensor_right=None, normalizer=None): """Compute the empirical second moment of the rows of a 2D Tensor. This function is meant to be applied to random matrices for which the true row @@ -93,6 +111,8 @@ def _compute_cov(tensor, normalizer=None): Args: tensor: A 2D Tensor. + tensor_right: An optional 2D Tensor. If provided, this function computes + the matrix product tensor^T * tensor_right instead of tensor^T * tensor. normalizer: optional scalar for the estimator (by default, the normalizer is the number of rows of tensor). @@ -101,9 +121,14 @@ def _compute_cov(tensor, normalizer=None): """ if normalizer is None: normalizer = array_ops.shape(tensor)[0] - cov = (math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( - normalizer, tensor.dtype)) - return (cov + array_ops.transpose(cov)) / math_ops.cast(2, cov.dtype) + if tensor_right is None: + cov = ( + math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( + normalizer, tensor.dtype)) + return (cov + array_ops.transpose(cov)) / math_ops.cast(2.0, cov.dtype) + else: + return (math_ops.matmul(tensor, tensor_right, transpose_a=True) / + math_ops.cast(normalizer, tensor.dtype)) def _append_homog(tensor): @@ -119,7 +144,7 @@ def _append_homog(tensor): rank = len(tensor.shape.as_list()) shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0) ones = array_ops.ones(shape, dtype=tensor.dtype) - return array_ops.concat([tensor, ones], axis=rank-1) + return array_ops.concat([tensor, ones], axis=rank - 1) def scope_string_from_params(params): @@ -157,8 +182,8 @@ def scope_string_from_params(params): elif isinstance(param, (tf_ops.Tensor, variables.Variable)): name_parts.append(scope_string_from_name(param)) else: - raise ValueError( - "Encountered an 
unsupported param type {}".format(type(param))) + raise ValueError("Encountered an unsupported param type {}".format( + type(param))) return "_".join(name_parts) @@ -209,6 +234,10 @@ class FisherFactor(object): """ pass + @abc.abstractproperty + def _dtype(self): + pass + @property def _cov_initializer(self): return covariance_initializer @@ -220,7 +249,8 @@ class FisherFactor(object): "cov", initializer=self._cov_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) @abc.abstractmethod def _compute_new_cov(self, idx=0): @@ -240,9 +270,10 @@ class FisherFactor(object): return moving_averages.assign_moving_average( self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + @abc.abstractmethod def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" - return [] + pass def get_cov(self): return self._cov @@ -257,6 +288,13 @@ class InverseProvidingFactor(FisherFactor): _cov_shape properties. """ + # TODO(b/69108481): This class (and its subclasses) should be refactored to + # serve the matrix quantities it computes as both (potentially stale) + # variables, updated by the inverse update ops, and fresh values stored in + # tensors that recomputed once every session.run() call. Currently matpower + # and damp_inverse have the former behavior, while eigendecomposition has + # the latter. + def __init__(self): self._inverses_by_damping = {} self._matpower_by_exp_and_damping = {} @@ -267,6 +305,10 @@ class InverseProvidingFactor(FisherFactor): def register_damped_inverse(self, damping): """Registers a damped inverse needed by a FisherBlock. + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_inverse. + Args: damping: The damping value (float or Tensor) for this factor. 
""" @@ -277,12 +319,17 @@ class InverseProvidingFactor(FisherFactor): "inv_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) self._inverses_by_damping[damping] = inv def register_matpower(self, exp, damping): """Registers a matrix power needed by a FisherBlock. + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_matpower. + Args: exp: The exponent (float or Tensor) to raise the matrix to. damping: The damping value (float or Tensor). @@ -295,57 +342,78 @@ class InverseProvidingFactor(FisherFactor): "matpower_exp{}_damp{}".format(exp_string, damping_string), initializer=inverse_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) self._matpower_by_exp_and_damping[(exp, damping)] = matpower def register_eigendecomp(self): - """Registers that an eigendecomposition is needed by a FisherBlock.""" + """Registers an eigendecomposition. + + Unlike register_damp_inverse and register_matpower this doesn't create + any variables or inverse ops. Instead it merely makes tensors containing + the eigendecomposition available to anyone that wants them. They will be + recomputed (once) for each session.run() call (when they needed by some op). 
+ """ if not self._eigendecomp: - self._eigendecomp = linalg_ops.self_adjoint_eig(self._cov) + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) + + # The matrix self._cov is positive semidefinite by construction, but the + # numerical eigenvalues could be negative due to numerical errors, so here + # we clip them to be at least FLAGS.eigenvalue_clipping_threshold + clipped_eigenvalues = math_ops.maximum(eigenvalues, + EIGENVALUE_CLIPPING_THRESHOLD) + self._eigendecomp = (clipped_eigenvalues, eigenvectors) def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" - ops = super(InverseProvidingFactor, self).make_inverse_update_ops() + ops = [] num_inverses = len(self._inverses_by_damping) matrix_power_registered = bool(self._matpower_by_exp_and_damping) - use_eig = (self._eigendecomp or matrix_power_registered or - num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) + use_eig = ( + self._eigendecomp or matrix_power_registered or + num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) if use_eig: self.register_eigendecomp() # ensures self._eigendecomp is set eigenvalues, eigenvectors = self._eigendecomp # pylint: disable=unpacking-non-sequence - # The matrix self._cov is positive semidefinite by construction, but the - # numerical eigenvalues could be negative due to numerical errors, so here - # we clip them to be at least EIGENVALUE_CLIPPING_THRESHOLD. 
- clipped_eigenvalues = math_ops.maximum(eigenvalues, - EIGENVALUE_CLIPPING_THRESHOLD) - for damping, inv in self._inverses_by_damping.items(): ops.append( inv.assign( - math_ops.matmul(eigenvectors / (clipped_eigenvalues + damping), + math_ops.matmul(eigenvectors / (eigenvalues + damping), array_ops.transpose(eigenvectors)))) for (exp, damping), matpower in self._matpower_by_exp_and_damping.items(): ops.append( matpower.assign( - math_ops.matmul(eigenvectors * (clipped_eigenvalues + damping)** - exp, array_ops.transpose(eigenvectors)))) + math_ops.matmul(eigenvectors * + (eigenvalues + damping)**exp, + array_ops.transpose(eigenvectors)))) + # These ops share computation and should be run on a single device. + ops = [control_flow_ops.group(*ops)] else: for damping, inv in self._inverses_by_damping.items(): ops.append(inv.assign(utils.posdef_inv(self._cov, damping))) return ops - def get_inverse(self, damping): + def get_damped_inverse(self, damping): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). return self._inverses_by_damping[damping] def get_matpower(self, exp, damping): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). return self._matpower_by_exp_and_damping[(exp, damping)] def get_eigendecomp(self): + # Unlike get_inverse and get_matpower this doesn't retrieve a stored + # variable, but instead always computes a fresh version from the current + # value of get_cov(). return self._eigendecomp @@ -356,12 +424,21 @@ class FullFactor(InverseProvidingFactor): to any type of parameter in principle, but has very high variance. 
""" - def __init__(self, params_grads, batch_size): + def __init__(self, + params_grads, + batch_size, + colocate_cov_ops_with_inputs=False): self._batch_size = batch_size + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs self._orig_params_grads_name = scope_string_from_params( [params_grads, self._batch_size]) - self._params_grads_flat = tuple( - utils.tensors_to_column(params_grad) for params_grad in params_grads) + params_grads_flat = [] + for params_grad in params_grads: + with _maybe_colocate_with(params_grad, + self._colocate_cov_ops_with_inputs): + col = utils.tensors_to_column(params_grad) + params_grads_flat.append(col) + self._params_grads_flat = tuple(params_grads_flat) super(FullFactor, self).__init__() @property @@ -377,11 +454,17 @@ class FullFactor(InverseProvidingFactor): def _num_sources(self): return len(self._params_grads_flat) + @property + def _dtype(self): + return self._params_grads_flat[0].dtype + def _compute_new_cov(self, idx=0): # This will be a very basic rank 1 estimate - return ((self._params_grads_flat[idx] * array_ops.transpose( - self._params_grads_flat[idx])) / math_ops.cast( - self._batch_size, self._params_grads_flat[idx].dtype)) + with _maybe_colocate_with(self._params_grads_flat[idx], + self._colocate_cov_ops_with_inputs): + return ((self._params_grads_flat[idx] * array_ops.transpose( + self._params_grads_flat[idx])) / math_ops.cast( + self._batch_size, self._params_grads_flat[idx].dtype)) class DiagonalFactor(FisherFactor): @@ -394,6 +477,9 @@ class DiagonalFactor(FisherFactor): def _cov_initializer(self): return diagonal_covariance_initializer + def make_inverse_update_ops(self): + return [] + class NaiveDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approximation of any type of param's Fisher. @@ -402,10 +488,19 @@ class NaiveDiagonalFactor(DiagonalFactor): to any type of parameter in principle, but has very high variance. 
""" - def __init__(self, params_grads, batch_size): + def __init__(self, + params_grads, + batch_size, + colocate_cov_ops_with_inputs=False): self._batch_size = batch_size - self._params_grads = tuple( - utils.tensors_to_column(params_grad) for params_grad in params_grads) + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs + params_grads_flat = [] + for params_grad in params_grads: + with _maybe_colocate_with(params_grad, + self._colocate_cov_ops_with_inputs): + col = utils.tensors_to_column(params_grad) + params_grads_flat.append(col) + self._params_grads = tuple(params_grads_flat) self._orig_params_grads_name = scope_string_from_params( [self._params_grads, self._batch_size]) super(NaiveDiagonalFactor, self).__init__() @@ -422,9 +517,15 @@ class NaiveDiagonalFactor(DiagonalFactor): def _num_sources(self): return len(self._params_grads) + @property + def _dtype(self): + return self._params_grads[0].dtype + def _compute_new_cov(self, idx=0): - return (math_ops.square(self._params_grads[idx]) / math_ops.cast( - self._batch_size, self._params_grads[idx].dtype)) + with _maybe_colocate_with(self._params_grads[idx], + self._colocate_cov_ops_with_inputs): + return (math_ops.square(self._params_grads[idx]) / math_ops.cast( + self._batch_size, self._params_grads[idx].dtype)) class FullyConnectedDiagonalFactor(DiagonalFactor): @@ -440,7 +541,11 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): # TODO(jamesmartens): add units tests for this class - def __init__(self, inputs, outputs_grads, has_bias=False): + def __init__(self, + inputs, + outputs_grads, + has_bias=False, + colocate_cov_ops_with_inputs=False): """Instantiate FullyConnectedDiagonalFactor. Args: @@ -449,18 +554,22 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): outputs_grads: List of Tensors of shape [batch_size, output_size]. Gradient of loss with respect to layer's preactivations. has_bias: bool. If True, append '1' to each input. 
+ colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ self._outputs_grads = outputs_grads + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs self._batch_size = array_ops.shape(inputs)[0] - self._orig_tensors_name = scope_string_from_params((inputs,) + - tuple(outputs_grads)) + self._orig_tensors_name = scope_string_from_params( + (inputs,) + tuple(outputs_grads)) # Note that we precompute the required operations on the inputs since the # inputs don't change with the 'idx' argument to _compute_new_cov. (Only # the target entry of _outputs_grads changes with idx.) - if has_bias: - inputs = _append_homog(inputs) - self._squared_inputs = math_ops.square(inputs) + with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs): + if has_bias: + inputs = _append_homog(inputs) + self._squared_inputs = math_ops.square(inputs) super(FullyConnectedDiagonalFactor, self).__init__() @@ -476,17 +585,23 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + def _compute_new_cov(self, idx=0): # The well-known special formula that uses the fact that the entry-wise # square of an outer product is the outer-product of the entry-wise squares. # The gradient is the outer product of the input and the output gradients, # so we just square both and then take their outer-product. 
- new_cov = math_ops.matmul( - self._squared_inputs, - math_ops.square(self._outputs_grads[idx]), - transpose_a=True) - new_cov /= math_ops.cast(self._batch_size, new_cov.dtype) - return new_cov + with _maybe_colocate_with(self._squared_inputs, + self._colocate_cov_ops_with_inputs): + new_cov = math_ops.matmul( + self._squared_inputs, + math_ops.square(self._outputs_grads[idx]), + transpose_a=True) + new_cov /= math_ops.cast(self._batch_size, new_cov.dtype) + return new_cov class ConvDiagonalFactor(DiagonalFactor): @@ -494,8 +609,14 @@ class ConvDiagonalFactor(DiagonalFactor): # TODO(jamesmartens): add units tests for this class - def __init__(self, inputs, outputs_grads, filter_shape, strides, padding, - has_bias=False): + def __init__(self, + inputs, + outputs_grads, + filter_shape, + strides, + padding, + has_bias=False, + colocate_cov_ops_with_inputs=False): """Creates a ConvDiagonalFactor object. Args: @@ -510,29 +631,36 @@ class ConvDiagonalFactor(DiagonalFactor): padding: The padding in this layer (1-D of Tensor length 4). has_bias: Python bool. If True, the layer is assumed to have a bias parameter in addition to its filter parameter. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ self._filter_shape = filter_shape self._has_bias = has_bias self._outputs_grads = outputs_grads + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs - self._orig_tensors_name = scope_string_from_name((inputs,) - + tuple(outputs_grads)) + self._orig_tensors_name = scope_string_from_name( + (inputs,) + tuple(outputs_grads)) # Note that we precompute the required operations on the inputs since the # inputs don't change with the 'idx' argument to _compute_new_cov. (Only # the target entry of _outputs_grads changes with idx.) 
- filter_height, filter_width, _, _ = self._filter_shape - patches = array_ops.extract_image_patches( - inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=strides, - rates=[1, 1, 1, 1], - padding=padding) + with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs): + filter_height, filter_width, _, _ = self._filter_shape - if has_bias: - patches = _append_homog(patches) + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. + patches = array_ops.extract_image_patches( + inputs, + ksizes=[1, filter_height, filter_width, 1], + strides=strides, + rates=[1, 1, 1, 1], + padding=padding) + + if has_bias: + patches = _append_homog(patches) - self._patches = patches + self._patches = patches super(ConvDiagonalFactor, self).__init__() @@ -543,21 +671,29 @@ class ConvDiagonalFactor(DiagonalFactor): @property def _cov_shape(self): filter_height, filter_width, in_channels, out_channels = self._filter_shape - return [filter_height * filter_width * in_channels + self._has_bias, - out_channels] + return [ + filter_height * filter_width * in_channels + self._has_bias, + out_channels + ] @property def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + def _compute_new_cov(self, idx=0): - outputs_grad = self._outputs_grads[idx] - batch_size = array_ops.shape(self._patches)[0] + with _maybe_colocate_with(self._outputs_grads[idx], + self._colocate_cov_ops_with_inputs): + outputs_grad = self._outputs_grads[idx] + batch_size = array_ops.shape(self._patches)[0] - new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad) - new_cov /= math_ops.cast(batch_size, new_cov.dtype) + new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) - return new_cov + return new_cov def _convdiag_sum_of_squares(self, patches, outputs_grad): # This computes the sum of the squares of the 
per-training-case "gradients". @@ -572,19 +708,24 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): """Kronecker factor for the input or output side of a fully-connected layer. """ - def __init__(self, tensors, has_bias=False): + def __init__(self, + tensors, + has_bias=False, + colocate_cov_ops_with_inputs=False): """Instantiate FullyConnectedKroneckerFactor. Args: tensors: List of Tensors of shape [batch_size, n]. Represents either a layer's inputs or its output's gradients. - has_bias: bool. If True, assume this factor is for the layer's inputs and - append '1' to each row. + has_bias: bool. If True, append '1' to each row. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ # The tensor argument is either a tensor of input activations or a tensor of # output pre-activation gradients. self._has_bias = has_bias self._tensors = tensors + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(FullyConnectedKroneckerFactor, self).__init__() @property @@ -601,11 +742,17 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._tensors) + @property + def _dtype(self): + return self._tensors[0].dtype + def _compute_new_cov(self, idx=0): - tensor = self._tensors[idx] - if self._has_bias: - tensor = _append_homog(tensor) - return _compute_cov(tensor) + with _maybe_colocate_with(self._tensors[idx], + self._colocate_cov_ops_with_inputs): + tensor = self._tensors[idx] + if self._has_bias: + tensor = _append_homog(tensor) + return _compute_cov(tensor) class ConvInputKroneckerFactor(InverseProvidingFactor): @@ -618,7 +765,13 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): Section 3.1 Estimating the factors. 
""" - def __init__(self, inputs, filter_shape, strides, padding, has_bias=False): + def __init__(self, + inputs, + filter_shape, + strides, + padding, + has_bias=False, + colocate_cov_ops_with_inputs=False): """Initializes ConvInputKroneckerFactor. Args: @@ -630,12 +783,15 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): width_stride, in_channel_stride]. padding: str. Padding method for layer. "SAME" or "VALID". has_bias: bool. If True, append 1 to in_channel. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ self._filter_shape = filter_shape self._strides = strides self._padding = padding self._has_bias = has_bias self._inputs = inputs + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(ConvInputKroneckerFactor, self).__init__() @property @@ -655,26 +811,34 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return self._inputs.dtype + def _compute_new_cov(self, idx=0): if idx != 0: raise ValueError("ConvInputKroneckerFactor only supports idx = 0") # TODO(jamesmartens): factor this patches stuff out into a utility function - filter_height, filter_width, in_channels, _ = self._filter_shape - patches = array_ops.extract_image_patches( - self._inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=self._strides, - rates=[1, 1, 1, 1], - padding=self._padding) + with _maybe_colocate_with(self._inputs, self._colocate_cov_ops_with_inputs): + filter_height, filter_width, in_channels, _ = self._filter_shape - flatten_size = (filter_height * filter_width * in_channels) - patches_flat = array_ops.reshape(patches, [-1, flatten_size]) + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. 
+ patches = array_ops.extract_image_patches( + self._inputs, + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=[1, 1, 1, 1], + padding=self._padding) - if self._has_bias: - patches_flat = _append_homog(patches_flat) + flatten_size = (filter_height * filter_width * in_channels) + patches_flat = array_ops.reshape(patches, [-1, flatten_size]) - return _compute_cov(patches_flat) + if self._has_bias: + patches_flat = _append_homog(patches_flat) + + return _compute_cov(patches_flat) class ConvOutputKroneckerFactor(InverseProvidingFactor): @@ -688,15 +852,18 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): Section 3.1 Estimating the factors. """ - def __init__(self, outputs_grads): + def __init__(self, outputs_grads, colocate_cov_ops_with_inputs=False): """Initializes ConvOutputKroneckerFactor. Args: outputs_grads: list of Tensors. Each Tensor is of shape - [batch_size, height, width, out_channels]. + [batch_size, height, width, out_channels]. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. 
""" self._out_channels = outputs_grads[0].shape.as_list()[3] self._outputs_grads = outputs_grads + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(ConvOutputKroneckerFactor, self).__init__() @property @@ -712,7 +879,286 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + def _compute_new_cov(self, idx=0): - reshaped_tensor = array_ops.reshape(self._outputs_grads[idx], - [-1, self._out_channels]) - return _compute_cov(reshaped_tensor) + with _maybe_colocate_with(self._outputs_grads[idx], + self._colocate_cov_ops_with_inputs): + reshaped_tensor = array_ops.reshape(self._outputs_grads[idx], + [-1, self._out_channels]) + return _compute_cov(reshaped_tensor) + + +class FullyConnectedMultiKF(InverseProvidingFactor): + """Kronecker factor for a fully connected recurrent layer.""" + + def __init__(self, + tensor_lists, + has_bias=False, + colocate_cov_ops_with_inputs=False): + """Constructs a new `FullyConnectedMultiKF`. + + Args: + tensor_lists: List of lists of Tensors of shape [batch_size, n]. + has_bias: bool. If True, '1' is appended to each row. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. 
+ """ + + self._orig_tensors_name = scope_string_from_params(tensor_lists) + self._batch_size = array_ops.shape(tensor_lists[0][0])[0] + self._num_timesteps = len(tensor_lists[0]) + + tensors = tuple( + array_ops.concat(tensor_list, 0) for tensor_list in tensor_lists) + if has_bias: + tensors = tuple(_append_homog(tensor) for tensor in tensors) + self._tensors = tensors + + self._cov_dt1 = None + self._option1quants_by_damping = {} + self._option2quants_by_damping = {} + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs + + super(FullyConnectedMultiKF, self).__init__() + + @property + def _var_scope(self): + return "ff_fc_multi/" + self._orig_tensors_name + + @property + def _num_sources(self): + return len(self._tensors) + + @property + def _dtype(self): + return self._tensors[0].dtype + + def make_covariance_update_op(self, ema_decay): + with _maybe_colocate_with(self._tensors, + self._colocate_cov_ops_with_inputs): + op = super(FullyConnectedMultiKF, + self).make_covariance_update_op(ema_decay) + + if self._cov_dt1 is not None: + new_cov_dt1 = math_ops.add_n( + tuple( + self._compute_new_cov_dt1(idx) + for idx in range(self._num_sources))) + op2 = moving_averages.assign_moving_average( + self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) + + # TODO(b/69112164): + # It's important that _cov and _cov_dt1 remain consistent with each + # other while the inverse ops are happening. How can we ensure this? + # We will need to add explicit synchronization for this to + # work with asynchronous training. 
+ op = control_flow_ops.group(op, op2) + + return op + + def _compute_new_cov(self, idx=0): + tensor = self._tensors[idx] + normalizer = self._num_timesteps * self._batch_size + return _compute_cov(tensor, normalizer=normalizer) + + def _compute_new_cov_dt1(self, idx=0): + tensor = self._tensors[idx] + normalizer = self._num_timesteps * self._batch_size + tensor_present = tensor[:-self._batch_size, :] + tensor_future = tensor[self._batch_size:, :] + return _compute_cov( + tensor_future, tensor_right=tensor_present, normalizer=normalizer) + + @property + def _cov_shape(self): + size = self._tensors[0].shape[1] + return [size, size] + + @property + def _vec_shape(self): + size = self._tensors[0].shape[1] + return [size] + + def get_option1quants(self, damping): + return self._option1quants_by_damping[damping] + + def get_option2quants(self, damping): + return self._option2quants_by_damping[damping] + + def get_cov_dt1(self): + assert self._cov_dt1 is not None + return self._cov_dt1 + + def register_cov_dt1(self): + """Create a variable representing temporal cross-covariance. + + (This is technically the second moment, not covariance, since it's + not mean subtracted.) + """ + if self._cov_dt1 is None: + with variable_scope.variable_scope(self._var_scope): + self._cov_dt1 = variable_scope.get_variable( + "cov_dt1", + initializer=init_ops.zeros_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + + def register_option1quants(self, damping): + + self.register_eigendecomp() + self.register_cov_dt1() + + if damping not in self._option1quants_by_damping: + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. 
+ damping_string = scalar_or_tensor_to_string(damping) + with variable_scope.variable_scope(self._var_scope): + Lmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Lmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + psi = variable_scope.get_variable( + "psi_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + self._option1quants_by_damping[damping] = (Lmat, psi) + + def register_option2quants(self, damping): + + self.register_eigendecomp() + self.register_cov_dt1() + + if damping not in self._option2quants_by_damping: + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. + damping_string = scalar_or_tensor_to_string(damping) + with variable_scope.variable_scope(self._var_scope): + Pmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Lmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + Kmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Kmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + mu = variable_scope.get_variable( + "mu_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + self._option2quants_by_damping[damping] = (Pmat, Kmat, mu) + + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + # TODO(b/69918258): Add correctness tests for this method. 
+ # pylint: disable=invalid-name + + ops = super(FullyConnectedMultiKF, self).make_inverse_update_ops() + + if (len(self._option1quants_by_damping) + + len(self._option2quants_by_damping)): + + # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from + # the pseudo-code in the original paper. Because the computations for + # the A and G case are essentially the same they can both be performed by + # the same class (this one). + + C1 = self.get_cov_dt1() + + # Get the eigendecomposition of C0 (= self.get_cov()) + eigen_e, eigen_V = self.get_eigendecomp() + + # TODO(b/69678661): Note, there is an implicit assumption here that C1 + # and C0 (as represented here by its eigen-decomp) are consistent. This + # could fail to be the case if self._cov and self._cov_dt1 are not updated + # consistently, or are somehow read between or during the cov updates. + # Can this possibly happen? Is there a way to prevent it? + + for damping, (Lmat_var, + psi_var) in self._option1quants_by_damping.items(): + + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # The following line imposes the symmetry assumed by "Option 1" on C1. + # Strangely the code can work okay with this line commented out, + # depending on how psd_eig is defined. I'm not sure why.
+ C1 = (C1 + array_ops.transpose(C1)) / 2.0 + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means \hat{Psi}) + hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, C1), invsqrtC0) + + # Compute the decomposition U*diag(psi)*U^T = hPsi + psi, U = utils.posdef_eig(hPsi) + + # L = C0^(-1/2) * U + Lmat = math_ops.matmul(invsqrtC0, U) + + ops.append(Lmat_var.assign(Lmat)) + ops.append(psi_var.assign(psi)) + + for damping, (Pmat_var, Kmat_var, + mu_var) in self._option2quants_by_damping.items(): + + # compute C0^(-1/2) + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # Compute the product C0^(-1/2) * C1 + invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1) + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means \hat{Psi}) + hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0) + + # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi + # Note that we use the notation mu instead of "m" for the eigenvalues. + # Instead of computing the product hPsi^T * hPsi and then doing an + # eigen-decomposition of this we just compute the SVD of hPsi and then + # square the singular values to get the eigenvalues. For a justification + # of this approach, see: + # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition + sqrtmu, _, E = linalg_ops.svd(hPsi) + mu = math_ops.square(sqrtmu) + + # Mathematically, the eigenvalues should not exceed 1.0, but + # due to numerical issues, or possible issues with inconsistent + # values of C1 and (the eigen-decomposition of) C0 they might. So + # we enforce this condition.
+ mu = math_ops.minimum(mu, 1.0) + + # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1) + Pmat = math_ops.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True) + + # K = C_0^(-1/2) * E + Kmat = math_ops.matmul(invsqrtC0, E) + + ops.append(Pmat_var.assign(Pmat)) + ops.append(Kmat_var.assign(Kmat)) + ops.append(mu_var.assign(mu)) + + return [control_flow_ops.group(*ops)] + + # pylint: enable=invalid-name diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 4eabb59b3e4e59c1c9ad4e3c1102efacb52dd478..ca42afe6fb2f5c7d7de8b5b087dc11be30a75d5e 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -26,7 +26,9 @@ from __future__ import print_function from collections import defaultdict from collections import OrderedDict +from functools import partial +import math import six from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb @@ -35,20 +37,51 @@ from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest - # Names for various approximations that can be requested for Fisher blocks. 
APPROX_KRONECKER_NAME = "kron" APPROX_DIAGONAL_NAME = "diagonal" APPROX_FULL_NAME = "full" +_GENERIC_APPROX_TO_BLOCK_TYPES = { + APPROX_FULL_NAME: fb.FullFB, + APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB, +} + +_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, + APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, +} + +_CONV2D_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, + APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, +} + +APPROX_KRONECKER_INDEP_NAME = "kron_indep" +APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" +APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" + +_FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB, + APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB, + option=1), + APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB, + option=2) +} + # Possible value for 'reuse' keyword argument. Sets 'reuse' to # tf.get_variable_scope().reuse. VARIABLE_SCOPE = "VARIABLE_SCOPE" -# TODO(jamesmartens): need to add find_canonical_output back into this somewhere + +def ensure_sequence(obj): + """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" + if isinstance(obj, (tuple, list)): + return obj + else: + return (obj,) class LayerParametersDict(OrderedDict): @@ -107,12 +140,23 @@ class LayerCollection(object): sum. """ - def __init__(self, graph=None, name="LayerCollection"): + def __init__(self, + graph=None, + colocate_cov_ops_with_inputs=False, + name="LayerCollection"): self.fisher_blocks = LayerParametersDict() self.fisher_factors = OrderedDict() + self._linked_parameters = dict( + ) # dict mapping sets of variables to optionally specified approximations. 
self._graph = graph or ops.get_default_graph() self._loss_dict = {} # {str: LossFunction} self._subgraph = None + self._default_generic_approximation = APPROX_FULL_NAME + self._default_fully_connected_approximation = APPROX_KRONECKER_NAME + self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME + self._default_fully_connected_multi_approximation = ( + APPROX_KRONECKER_SERIES_2_NAME) + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs with variable_scope.variable_scope(None, default_name=name) as scope: self._var_scope = scope.name @@ -122,42 +166,92 @@ class LayerCollection(object): """LossFunctions registered with this LayerCollection.""" return list(self._loss_dict.values()) + @property + def registered_variables(self): + """A tuple of all of the variables currently registered.""" + tuple_of_tuples = (ensure_sequence(key) for key, block + in six.iteritems(self.fisher_blocks)) + flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_) + return flat_tuple + + @property + def linked_parameters(self): + """Groups of parameters with an optionally specified approximation. + + Linked parameters can be added using `define_linked_parameters`. + If an approximation is specified, then this approximation will be used + when registering a layer with exactly these parameters, unless an + approximation is specified when calling the registration function. + + Returns: + A `dict` mapping tuples of parameters to an optional string. 
+ """ + return self._linked_parameters + + @property + def default_generic_approximation(self): + return self._default_generic_approximation + + def set_default_generic_approximation(self, value): + if value not in _GENERIC_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for generic variables.".format( + value)) + self._default_generic_approximation = value + + @property + def default_fully_connected_approximation(self): + return self._default_fully_connected_approximation + + def set_default_fully_connected_approximation(self, value): + if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for fully connected layers.".format( + value)) + self._default_fully_connected_approximation = value + + @property + def default_conv2d_approximation(self): + return self._default_convolution_2d_approximation + + def set_default_conv2d_approximation(self, value): + if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for 2d convolutional layers.".format( + value)) + self._default_convolution_2d_approximation = value + + @property + def default_fully_connected_multi_approximation(self): + return self._default_fully_connected_multi_approximation + + def set_default_fully_connected_multi_approximation(self, value): + if value not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: + raise ValueError("{} is not a valid approximation for a fully-connected " + "multi layer.".format(value)) + self._default_fully_connected_multi_approximation = value + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): """Validates and registers the layer_key associated with the fisher_block. - Validation consists of checking whether the key was already registered or - if any of the elements of layer_key (if it's a tuple) were already - registered as part of another tuple (throws an error if so). 
If any of the - elements were registered by themselves, or as part of tuples that are - subsets of this layer_key, those registrations are first removed. - - If the layer_key is a subset of an existing registration, registration of - the new, smaller layer_key is skipped. - - e.g. If registrations include {'a': foo, ('b', 'c'): bar}, then - - register_layer('a', baz) -> ValueError - - register_layer(('b', 'c', 'd'), baz) -> - {'a': foo, ('b', 'c', 'd'): baz} - - register_layer('b', baz) -> - {'a': foo, ('b', 'c'): bar} (No change) - - register_layer(('a', 'd'), baz) -> - {('a', 'd'): baz, ('b', 'c'): bar} - - register_layer(('b', 'd'), baz) -> ValueError - Args: - layer_key: The key to check for in existing registrations and to register - if valid. - fisher_block: The associated fisher block. - reuse: Method to use for inserting new FisherBlocks. One of True, False, - or VARIABLE_SCOPE. + layer_key: A variable or tuple of variables. The key to check for in + existing registrations and to register if valid. + fisher_block: The associated `FisherBlock`. + reuse: Method to use for inserting new `FisherBlock`s. One of True, False, + or 'VARIABLE_SCOPE'. Raises: - ValueError: If the layer_key was already registered, or if a subset of the - layer_key has already been registered as part of a different tuple. + ValueError: If `layer_key` was already registered and reuse is `False`, + if `layer_key` was registered with a different block type, or if + `layer_key` shares any variables with but is not equal to a previously + registered key. + KeyError: If `reuse` is `True` but `layer_key` was not previously + registered. Returns: - FisherBlock registered under 'layer_key'. May or may not be the same as - 'fisher_block'. + The `FisherBlock` registered under `layer_key`. If `layer_key` was already + registered, this will be the previously registered `FisherBlock`. 
""" if reuse is VARIABLE_SCOPE: reuse = variable_scope.get_variable_scope().reuse @@ -177,109 +271,84 @@ class LayerCollection(object): # Insert fisher_block into self.fisher_blocks. if layer_key in self.fisher_blocks: raise ValueError("Duplicate registration: {}".format(layer_key)) - if isinstance(layer_key, (tuple, list)): - return self._register_block_with_sequence_key(layer_key, fisher_block) - else: - return self._register_block_with_nonsequence_key(layer_key, fisher_block) - - def _register_block_with_sequence_key(self, layer_key, fisher_block): - """Validates and registers the layer_key if it's a sequence.""" - # Find all keys that are either supersets or subsets of 'layer_key'. - inclusions = { - fisher_elt - for layer_elt in layer_key for fisher_elt in self.fisher_blocks - if self._equal_or_subset(layer_elt, fisher_elt) + # Raise an error if any variable in layer_key has been registered in any + # other blocks. + variable_to_block = { + var: (params, block) + for (params, block) in self.fisher_blocks.items() + for var in ensure_sequence(params) } - - if not inclusions: - self.fisher_blocks[layer_key] = fisher_block - return fisher_block - - result_key = None - for key in inclusions: - fisher_block_key = key if isinstance(key, (tuple, list)) else (key,) - in_existing_only = set(fisher_block_key) - set(layer_key) - in_new_only = set(layer_key) - set(fisher_block_key) - - if in_existing_only and in_new_only: - # Existing and new key have an intersection but neither is a subset of - # the other. This is an error. + for variable in ensure_sequence(layer_key): + if variable in variable_to_block: + prev_key, prev_block = variable_to_block[variable] raise ValueError( - "Inconsistent registration, expected new key to be a subset or " - "superset of the existing key: existing is {}, new is {}".format( - key, layer_key)) - elif in_existing_only and not in_new_only: - # Existing key is strict superset of new key. Return existing - # FisherBlock. 
- logging.warning("Graph Registration Warning: tried to register " - "a subset ({}) of an already registered tuple " - "({}), skipping".format(layer_key, fisher_block_key)) - assert result_key is None - result_key = key - elif in_new_only and not in_existing_only: - # Existing key is a strict subset of new key. Replace existing - # FisherBlock with new one. - # - # TODO(b/68715045): This is dangerous. If there are existing - # registrations for a minibatch from elsewhere in the graph, they won't - # be re-registered with this new FisherBlock. The type of FisherBlock - # could also change here. - logging.warning( - "Replacing existing FisherBlock for key {} with new FisherBlock " - "for key {}. {} registered minibatches from the existing " - "FisherBlock will not be migrated.".format( - key, layer_key, - self.fisher_blocks[key].num_registered_minibatches)) - self.fisher_blocks.pop(key) - self.fisher_blocks[layer_key] = fisher_block - assert result_key is None - result_key = layer_key - elif not in_new_only and not in_existing_only: - # Existing and new are identical. Reuse the old FisherBlock. - # - # TODO(b/68715045): This is dangerous. If the new FisherBlock has - # existing registered minibatches, they will not be migrated to the - # existing FisherBlock. - assert result_key is None - result_key = key - else: - raise ValueError("Unexpected layer key conflict: {} vs. 
{}".format( - layer_key, key)) - - return self.fisher_blocks[result_key] - - def _register_block_with_nonsequence_key(self, layer_key, fisher_block): - """Validates and registers the layer_key if it's not a sequence.""" - inclusions = { - fisher_elt - for fisher_elt in self.fisher_blocks - if self._equal_or_subset(layer_key, fisher_elt) - } - - if not inclusions: - self.fisher_blocks[layer_key] = fisher_block - else: - logging.warning("Graph Registration Warning: tried to register " - "variable ({}) but a containing tuple was already " - "registered ({}), skipping".format(layer_key, inclusions)) - + "Attempted to register layer_key {} with block {}, but variable {}" + " was already registered in key {} with block {}.".format( + layer_key, fisher_block, variable, prev_key, prev_block)) + self.fisher_blocks[layer_key] = fisher_block return fisher_block - def _equal_or_subset(self, elt1, elt2): - """Checks if the elements are equal or one is contained in the other.""" - return (elt1 == elt2 or (isinstance(elt1, - (tuple, list)) and elt2 in elt1) or - (isinstance(elt2, (tuple, list)) and elt1 in elt2)) - def get_use_count_map(self): """Returns a dict of variables to their number of registrations.""" + # TODO(b/70283403): Reimplement this in the old way, where each + # registration function would be responsible for incrementing the count. + # Also, this version has a bug: it won't do the right thing for generic + # registration for parameters that are shared. i.e. it won't set the use + # count to infinity. 
vars_to_uses = defaultdict(int) for key, block in six.iteritems(self.fisher_blocks): - key = key if isinstance(key, (tuple, list)) else (key,) + n = ( + block.num_inputs()*block.num_registered_minibatches if isinstance( + block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB)) + else block.num_registered_minibatches) + key = ensure_sequence(key) for k in key: - vars_to_uses[k] += block.num_registered_minibatches + vars_to_uses[k] += n return vars_to_uses + def check_registration(self, variables): + """Checks that all variable uses have been registered properly. + + Args: + variables: List of variables. + + Raises: + ValueError: If any registered variables are not included in the list. + ValueError: If any variable in the list is not registered. + ValueError: If any variable in the list is registered with the wrong + number of "uses" in the subgraph recorded (vs the number of times that + variable is actually used in the subgraph). + """ + # Note that overlapping parameters (i.e. those that share variables) will + # be caught by layer_collection.LayerParametersDict during registration. + + reg_use_map = self.get_use_count_map() + + error_messages = [] + + for var in variables: + total_uses = self.subgraph.variable_uses(var) + reg_uses = reg_use_map[var] + + if reg_uses == 0: + error_messages.append("Variable {} not registered.".format(var)) + elif (not math.isinf(reg_uses)) and reg_uses != total_uses: + error_messages.append( + "Variable {} registered with wrong number of uses ({} " + "registrations vs {} uses).".format(var, reg_uses, total_uses)) + + num_get_vars = len(reg_use_map) + + if num_get_vars > len(variables): + error_messages.append("{} registered variables were not included in list." 
+ .format(num_get_vars - len(variables))) + + if error_messages: + error_messages = [ + "Found the following errors with variable registration:" + ] + error_messages + raise ValueError("\n\t".join(error_messages)) + def get_blocks(self): return self.fisher_blocks.values() @@ -294,6 +363,49 @@ class LayerCollection(object): def subgraph(self): return self._subgraph + def define_linked_parameters(self, params, approximation=None): + """Identify a set of parameters that should be grouped together. + + During automatic graph scanning, any matches containing variables that have + been identified as part of a linked group will be filtered out unless + the match parameters are exactly equal to the ones specified in the linked + group. + + Args: + params: A variable, or a tuple or list of variables. The variables + to be linked. + approximation: Optional string specifying the type of approximation to use + for these variables. If unspecified, this layer collection's default + approximation for the layer type will be used. + + Raises: + ValueError: If the parameters were already registered in a layer or + identified as part of an incompatible group. + """ + params = frozenset(ensure_sequence(params)) + + # Check if any of the variables in 'params' is already in + # 'self.fisher_blocks.keys()'. + for registered_params, fisher_block in self.fisher_blocks.items(): + registered_params_set = set(ensure_sequence(registered_params)) + for variable in params: + if (variable in registered_params_set and + params != registered_params_set): + raise ValueError( + "Can't link parameters {}, variable {} was already registered in " + "group {} with layer {}".format(params, variable, + registered_params, fisher_block)) + + # Check if any of the variables in 'params' is already in + # 'self.linked_parameters'. 
+ for variable in params: + for other_linked_params in self.linked_parameters: + if variable in other_linked_params: + raise ValueError("Can't link parameters {}, variable {} was already " + "linked in group {}.".format(params, variable, + other_linked_params)) + self._linked_parameters[params] = approximation + def create_subgraph(self): if not self.losses: raise ValueError("Must have at least one registered loss.") @@ -307,11 +419,19 @@ class LayerCollection(object): return math_ops.add_n( tuple(loss.evaluate_on_sample() for loss in self.losses)) + def _get_linked_approx(self, params): + """If params were linked, return their specified approximation.""" + params_set = frozenset(ensure_sequence(params)) + if params_set in self.linked_parameters: + return self.linked_parameters[params_set] + else: + return None + def register_fully_connected(self, params, inputs, outputs, - approx=APPROX_KRONECKER_NAME, + approx=None, reuse=VARIABLE_SCOPE): """Registers a fully connnected layer. @@ -320,11 +440,11 @@ class LayerCollection(object): this layer. Weight matrix should have shape [input_size, output_size]. Bias should have shape [output_size]. inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. - outputs: Tensor of shape [batch_size, output_size]. Preactivations + outputs: Tensor of shape [batch_size, output_size]. Outputs produced by layer. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + approx: str. One of "kron" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. Raises: @@ -332,15 +452,15 @@ class LayerCollection(object): KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. 
""" - approx_to_block_types = { - APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, - APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, - } + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_fully_connected_approximation - if approx not in approx_to_block_types: + if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: raise ValueError("Bad value {} for approx.".format(approx)) - block_type = approx_to_block_types[approx] + block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx] has_bias = isinstance(params, (tuple, list)) block = self.register_block(params, block_type(self, has_bias), reuse=reuse) @@ -352,7 +472,7 @@ class LayerCollection(object): padding, inputs, outputs, - approx=APPROX_KRONECKER_NAME, + approx=None, reuse=VARIABLE_SCOPE): """Registers a convolutional layer. @@ -366,10 +486,10 @@ class LayerCollection(object): inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs to layer. outputs: Tensor of shape [batch_size, height, width, out_channels]. - Preactivations produced by layer. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + Output produced by layer. + approx: str. One of "kron" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. Raises: @@ -377,15 +497,16 @@ class LayerCollection(object): KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. 
""" - approx_to_block_types = { - APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, - APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, - } - if approx not in approx_to_block_types: + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_conv2d_approximation + + if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES: raise ValueError("Bad value {} for approx.".format(approx)) - block_type = approx_to_block_types[approx] + block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx] block = self.register_block( params, block_type(self, params, strides, padding), reuse=reuse) block.register_additional_minibatch(inputs, outputs) @@ -393,19 +514,16 @@ class LayerCollection(object): def register_generic(self, params, batch_size, - approx=APPROX_DIAGONAL_NAME, + approx=None, reuse=VARIABLE_SCOPE): """Registers a generic layer. Args: - params: Tensor or 2-tuple of Tensors corresponding to weight and bias of - this layer. Weight matrix should have shape [kernel_height, - kernel_width, in_channels, out_channels]. Bias should have shape - [out_channels]. + params: Tensor or tuple of Tensors corresponding to the parameters. batch_size: 0-D Tensor. Size of the minibatch. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + approx: str. One of "full" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. Raises: @@ -413,18 +531,60 @@ class LayerCollection(object): KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. 
""" - approx_to_block_types = { - APPROX_FULL_NAME: fb.FullFB, - APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB, - } - if approx not in approx_to_block_types: + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_generic_approximation + + if approx not in _GENERIC_APPROX_TO_BLOCK_TYPES: raise ValueError("Bad value {} for approx.".format(approx)) - block_type = approx_to_block_types[approx] + block_type = _GENERIC_APPROX_TO_BLOCK_TYPES[approx] block = self.register_block(params, block_type(self, params), reuse=reuse) block.register_additional_minibatch(batch_size) + def register_fully_connected_multi(self, params, inputs, outputs, + approx=None): + """Register fully connected layers with shared parameters. + + This can handle general fully-connected layers with shared parameters, but + has specialized approximations to deal with the case where there is a + meaningful linear order to the share instances (such as in an RNN). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [input_size, output_size]. + Bias should have shape [output_size]. + inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs + to layer. In the case of RNNs, one Tensor per time step. + outputs: A list of tensors, the same length as 'inputs', each of shape + [batch_size, output_size]. Outputs produced by layer. In the case of + RNNs, one Tensor per time step. + approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2". + + Raises: + ValueError: For improper value to 'approx'. + """ + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_fully_connected_multi_approximation + has_bias = isinstance(params, (tuple, list)) + + # TODO(b/70283649): something along the lines of find_canonical_output + # should be added back in here (and for the other block types, arguably). 
+ + if approx not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: + raise ValueError("Bad value {} for approx.".format(approx)) + block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx] + + # For now we don't support multiple minibatches for this type of layer, so + # we set reuse=False + self.register_block(params, + block_type(self, inputs, outputs, has_bias=has_bias), + reuse=False) + def register_categorical_predictive_distribution(self, logits, seed=None, @@ -448,10 +608,10 @@ class LayerCollection(object): tf.get_variable_scope().reuse. Raises: - ValueError: If reuse=True and name != None. - ValueError: If reuse=True and seed != None. - KeyError: If reuse=True and no existing LossFunction with 'name' found. - KeyError: If reuse=False and existing LossFunction with 'name' found. + ValueError: If reuse == True and name == None. + ValueError: If reuse == True and seed != None. + KeyError: If reuse == True and no existing LossFunction with 'name' found. + KeyError: If reuse == False and existing LossFunction with 'name' found. """ name = name or self._graph.unique_name( "register_categorical_predictive_distribution") @@ -560,11 +720,14 @@ class LayerCollection(object): try: hash(args) except TypeError: - raise TypeError(( - "Unable to use (cls, args) = ({}, {}) as a key in " - "LayerCollection.fisher_factors. The pair cannot be hashed." - ).format(cls, args)) - - with variable_scope.variable_scope(self._var_scope): - return utils.setdefault(self.fisher_factors, (cls, args), - lambda: cls(*args)) + raise TypeError( + ("Unable to use (cls, args) = ({}, {}) as a key in " + "LayerCollection.fisher_factors. 
The pair cannot be hashed.").format( + cls, args)) + + key = cls, args + if key not in self.fisher_factors: + colo = self._colocate_cov_ops_with_inputs + with variable_scope.variable_scope(self._var_scope): + self.fisher_factors[key] = cls(*args, colocate_cov_ops_with_inputs=colo) + return self.fisher_factors[key] diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index 3cfde7f9ababab73980e93ea1dd65be1b559712b..e2e5bc3ffea3e52087c24802948bc8260e3b199a 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -56,6 +56,30 @@ class LossFunction(object): """The inputs to the loss function (excluding the targets).""" pass + @property + def input_minibatches(self): + """A `list` of inputs to the loss function, separated by minibatch. + + Typically there will be one minibatch per tower in a multi-tower setup. + Returns a list consisting of `self.inputs` by default; `LossFunction`s + supporting registering multiple minibatches should override this method. + + Returns: + A `list` of `Tensor`s representing + """ + return [self.inputs] + + @property + def num_registered_minibatches(self): + """Number of minibatches registered for this LossFunction. + + Typically equal to the number of towers in a multi-tower setup. + + Returns: + An `int` representing the number of registered minibatches. + """ + return len(self.input_minibatches) + def evaluate(self): """Evaluate the loss function on the targets.""" if self.targets is not None: @@ -75,7 +99,6 @@ class LossFunction(object): Returns: log probability of each target, summed across all targets. 
""" - pass @abc.abstractmethod @@ -415,8 +438,8 @@ class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype), axis=-1) output_slice = self._var**-0.5 * ones_slice - return insert_slice_in_zeros(output_slice, 1, - int(self._mean.shape[1]), index[0]) + return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]), + index[0]) @property def fisher_factor_inner_shape(self): @@ -474,24 +497,23 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): @property def _fisher_mean(self): - return 1./self._variance + return 1. / self._variance @property def _fisher_mean_factor(self): - return 1./self._scale + return 1. / self._scale @property def _fisher_var(self): - return 1./(2*math_ops.square(self._variance)) + return 1. / (2 * math_ops.square(self._variance)) @property def _fisher_var_factor(self): - return 1./(math_ops.sqrt(2.)*self._variance) + return 1. / (math_ops.sqrt(2.) * self._variance) def multiply_fisher(self, vecs): mean_vec, var_vec = vecs - return (self._fisher_mean * mean_vec, - self._fisher_var * var_vec) + return (self._fisher_mean * mean_vec, self._fisher_var * var_vec) def multiply_fisher_factor(self, vecs): mean_vec, var_vec = self._split(vecs) @@ -511,8 +533,8 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): # Index corresponds to mean parameter. 
mean_slice = self._fisher_mean_factor[:, index] mean_slice = array_ops.expand_dims(mean_slice, axis=-1) - mean_output = insert_slice_in_zeros(mean_slice, 1, - int(self._mean.shape[1]), index) + mean_output = insert_slice_in_zeros(mean_slice, 1, int( + self._mean.shape[1]), index) var_output = array_ops.zeros_like(mean_output) else: index -= int(self._mean.shape[-1]) @@ -527,13 +549,17 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): @property def fisher_factor_inner_shape(self): - return array_ops.concat([array_ops.shape(self._mean)[:-1], - 2*array_ops.shape(self._mean)[-1:]], axis=0) + return array_ops.concat( + [ + array_ops.shape(self._mean)[:-1], + 2 * array_ops.shape(self._mean)[-1:] + ], + axis=0) @property def fisher_factor_inner_static_shape(self): shape = self._mean.shape.as_list() - return tensor_shape.TensorShape(shape[-1:] + [2*shape[-1]]) + return tensor_shape.TensorShape(shape[-1:] + [2 * shape[-1]]) def multiply_hessian(self, vector): raise NotImplementedError() @@ -605,6 +631,10 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, def _logits(self): return array_ops.concat(self._logits_components, axis=0) + @property + def input_minibatches(self): + return self._logits_components + @property def targets(self): if all(target is None for target in self._targets_components): @@ -710,8 +740,8 @@ class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss, assert len(index) == 1, "Length of index was {}".format(len(index)) probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1) output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice)) - return insert_slice_in_zeros(output_slice, 1, - int(self._logits.shape[1]), index[0]) + return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]), + index[0]) @property def fisher_factor_inner_shape(self): diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index 
bfa15e0948c96477d9a79dece985bc4b6dafab6f..ecf7f3e4e5ab7d9c151f760fdab733bc3830e37b 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -35,16 +35,20 @@ from tensorflow.python.training import gradient_descent class KfacOptimizer(gradient_descent.GradientDescentOptimizer): """The KFAC Optimizer (https://arxiv.org/abs/1503.05671).""" - def __init__( - self, - learning_rate, - cov_ema_decay, - damping, - layer_collection, - momentum=0., - momentum_type="regular", - norm_constraint=None, - name="KFAC",): + def __init__(self, + learning_rate, + cov_ema_decay, + damping, + layer_collection, + var_list=None, + momentum=0., + momentum_type="regular", + norm_constraint=None, + name="KFAC", + estimation_mode="gradients", + colocate_gradients_with_ops=False, + cov_devices=None, + inv_devices=None): """Initializes the KFAC optimizer with the given settings. Args: @@ -63,6 +67,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): blocks, kronecker factors, and losses associated with the graph. The layer_collection cannot be modified after KfacOptimizer's initialization. + var_list: Optional list or tuple of variables to train. Defaults to the + list of variables collected in the graph under the key + `GraphKeys.TRAINABLE_VARIABLES`. momentum: The momentum value for this optimizer. Only applies when momentum_type is 'regular' or 'adam'. (Default: 0) momentum_type: The type of momentum to use in this optimizer, one of @@ -72,6 +79,18 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): specified value. May only be used with momentum type 'regular'. (Default: None) name: The name for this optimizer. (Default: 'KFAC') + estimation_mode: The type of estimator to use for the Fishers. Can be + 'gradients', 'empirical', 'curvature_propagation', or 'exact'. + (Default: 'gradients'). See the doc-string for FisherEstimator for + more a more detailed description of these options. 
+ colocate_gradients_with_ops: Whether we should request gradients we + compute in the estimator be colocated with their respective ops. + cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. Raises: ValueError: If the momentum type is unsupported. @@ -81,12 +100,19 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): or 'adam'. """ - # We may consider determining the set of variables some other way, but for - # now it's just all the trainable variables. - variables = tf_variables.trainable_variables() + variables = var_list + if variables is None: + variables = tf_variables.trainable_variables() - self._fisher_est = est.FisherEstimator(variables, cov_ema_decay, damping, - layer_collection) + self._fisher_est = est.FisherEstimator( + variables, + cov_ema_decay, + damping, + layer_collection, + estimation_mode=estimation_mode, + colocate_gradients_with_ops=colocate_gradients_with_ops, + cov_devices=cov_devices, + inv_devices=inv_devices) momentum_type = momentum_type.lower() legal_momentum_types = ["regular", "adam", "qmodel"] @@ -101,7 +127,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): raise ValueError("Momentum must be unspecified if using a momentum_type " "other than 'regular' or 'adam'.") - self._momentum = ops.convert_to_tensor(momentum, name="momentum") + self._momentum = momentum self._momentum_type = momentum_type self._norm_constraint = norm_constraint @@ -125,16 +151,24 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): return self._fisher_est.damping def minimize(self, *args, **kwargs): - - if "var_list" not in kwargs: - kwargs["var_list"] = 
tf_variables.trainable_variables() - + kwargs["var_list"] = kwargs.get("var_list") or self.variables if set(kwargs["var_list"]) != set(self.variables): raise ValueError("var_list doesn't match with set of Fisher-estimating " "variables.") - return super(KfacOptimizer, self).minimize(*args, **kwargs) + def compute_gradients(self, *args, **kwargs): + # args[1] could be our var_list + if len(args) > 1: + var_list = args[1] + else: + kwargs["var_list"] = kwargs.get("var_list") or self.variables + var_list = kwargs["var_list"] + if set(var_list) != set(self.variables): + raise ValueError("var_list doesn't match with set of Fisher-estimating " + "variables.") + return super(KfacOptimizer, self).compute_gradients(*args, **kwargs) + def apply_gradients(self, grads_and_vars, *args, **kwargs): """Applies gradients to variables. @@ -291,14 +325,17 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._batch_size, dtype=fft_precon_grads[0].dtype) # compute the entries of the 2x2 matrix - m_11 = (_inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size - + self.damping * _inner_product_list(precon_grads, precon_grads)) + m_11 = ( + _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size + + self.damping * _inner_product_list(precon_grads, precon_grads)) - m_21 = (_inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size - + self.damping * _inner_product_list(prev_updates, precon_grads)) + m_21 = ( + _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size + + self.damping * _inner_product_list(prev_updates, precon_grads)) - m_22 = (_inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size - + self.damping * _inner_product_list(prev_updates, prev_updates)) + m_22 = ( + _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size + + self.damping * _inner_product_list(prev_updates, prev_updates)) def non_zero_prevupd_case(): r"""Computes optimal (alpha, mu) given non-zero previous update. 
@@ -384,8 +421,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): grads = list(grad for (grad, _) in grads_and_vars) variables = list(var for (_, var) in grads_and_vars) # previous updates are the negative velocities (up to scaling by LR) - prev_updates = list(-self._zeros_slot(var, "velocity", self._name) - for var in variables) + prev_updates = list( + -self._zeros_slot(var, "velocity", self._name) for var in variables) # Compute optimal velocity update parameters according to quadratic model alpha, mu, _ = self._compute_qmodel_hyperparams( diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index 0fd7f5147739f0f46d2ab6a1c284c6dc75f53cc2..cec018e406bc51c07f5cafcc2c38efe7e9601618 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -28,9 +28,9 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops - # Method used for inverting matrices. POSDEF_INV_METHOD = "cholesky" +POSDEF_EIG_METHOD = "self_adjoint" def set_global_constants(posdef_inv_method=None): @@ -64,13 +64,6 @@ class SequenceDict(object): return list(self._dict.items()) -def setdefault(dct, key, thunk): - """Like dict.setdefault but delays evaluation of the value to be set.""" - if key not in dct: - dct[key] = thunk() - return dct[key] - - def tensors_to_column(tensors): """Converts a tensor or list of tensors to a column vector. @@ -169,33 +162,11 @@ def mat2d_to_layer_params(vector_template, mat2d): return array_ops.reshape(mat2d, vector_template.shape) -def compute_pi(left_factor, right_factor): - """Computes the scalar constant pi for Tikhonov regularization/damping. - - pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) - See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. - - Args: - left_factor: The left Kronecker factor Tensor. 
- right_factor: The right Kronecker factor Tensor. - - Returns: - The computed scalar constant pi for these Kronecker Factors (as a Tensor). - """ - # Instead of dividing by the dim of the norm, we multiply by the dim of the - # other norm. This works out the same in the ratio. - left_norm = math_ops.trace(left_factor) * right_factor.get_shape().as_list()[ - 0] - right_norm = math_ops.trace(right_factor) * left_factor.get_shape().as_list()[ - 0] - return math_ops.sqrt(left_norm / right_norm) - - def posdef_inv(tensor, damping): """Computes the inverse of tensor + damping * identity.""" identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) damping = math_ops.cast(damping, dtype=tensor.dtype) - return posdef_inv_funcs[POSDEF_INV_METHOD](tensor, identity, damping) + return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping) def posdef_inv_matrix_inverse(tensor, identity, damping): @@ -209,9 +180,44 @@ def posdef_inv_cholesky(tensor, identity, damping): return linalg_ops.cholesky_solve(chol, identity) -posdef_inv_funcs = { +def posdef_inv_eig(tensor, identity, damping): + """Computes inverse(tensor + damping * identity) with eigendecomposition.""" + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig( + tensor + damping * identity) + return math_ops.matmul( + eigenvectors / eigenvalues, eigenvectors, transpose_b=True) + + +posdef_inv_functions = { "matrix_inverse": posdef_inv_matrix_inverse, "cholesky": posdef_inv_cholesky, + "eig": posdef_inv_eig, +} + + +def posdef_eig(mat): + """Computes the eigendecomposition of a positive semidefinite matrix.""" + return posdef_eig_functions[POSDEF_EIG_METHOD](mat) + + +def posdef_eig_svd(mat): + """Computes the singular values and left singular vectors of a matrix.""" + evals, evecs, _ = linalg_ops.svd(mat) + + return evals, evecs + + +def posdef_eig_self_adjoint(mat): + """Computes eigendecomposition using self_adjoint_eig.""" + evals, evecs = linalg_ops.self_adjoint_eig(mat) + evals = 
math_ops.abs(evals) # Should be equivalent to svd approach. + + return evals, evecs + + +posdef_eig_functions = { + "self_adjoint": posdef_eig_self_adjoint, + "svd": posdef_eig_svd, } @@ -268,8 +274,8 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): # generated by the first gradients_impl.gradients call. us = [array_ops.zeros_like(y) + float("nan") for y in ys] - dydxs = gradients_impl.gradients(ys, xs, grad_ys=us, - stop_gradients=stop_gradients) + dydxs = gradients_impl.gradients( + ys, xs, grad_ys=us, stop_gradients=stop_gradients) # Deal with strange types that gradients_impl.gradients returns but can't # deal with. @@ -285,3 +291,6 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs) return dysdx + +# TODO(b/69623235): Add a function for finding tensors that share gradients +# to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index ddbb4485ce6967082f1844c6d798c078f1cc303b..8903c90fbce6a890aa419d89b3b79d75f69509fc 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -25,13 +25,11 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ "SequenceDict", - "setdefault", "tensors_to_column", "column_to_tensors", "kronecker_product", "layer_params_to_mat2d", "mat2d_to_layer_params", - "compute_pi", "posdef_inv", "posdef_inv_matrix_inverse", "posdef_inv_cholesky", diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 2f1f283811b6cb9e8bfb52ab2052afac1de700cb..852d06e1e3cc8f8deecd15b7436cd4e4a393ad66 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -61,6 +61,7 @@ tf_custom_op_py_library( "python/layers/normalization.py", "python/layers/optimizers.py", "python/layers/regularizers.py", + 
"python/layers/rev_block_lib.py", "python/layers/summaries.py", "python/layers/target_column.py", "python/layers/utils.py", @@ -376,6 +377,20 @@ py_test( ], ) +py_test( + name = "rev_block_lib_test", + size = "small", + srcs = ["python/layers/rev_block_lib_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":layers_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:init_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index d309ba958ded86afdc1e4bba2ff471a5181cda4e..6c624929f20503054e0258aad8a843f4a201be64 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -42,6 +42,9 @@ See the @{$python/contrib.layers} guide. @@relu @@relu6 @@repeat +@@recompute_grad +@@RevBlock +@@rev_block @@safe_embedding_lookup_sparse @@scale_gradient @@separable_conv2d diff --git a/tensorflow/contrib/layers/python/layers/__init__.py b/tensorflow/contrib/layers/python/layers/__init__.py index 03337f9a5d11784316124442125bb498c4ce9603..f1ae2de68be33880a6fc09957f4d857973902b26 100644 --- a/tensorflow/contrib/layers/python/layers/__init__.py +++ b/tensorflow/contrib/layers/python/layers/__init__.py @@ -28,6 +28,7 @@ from tensorflow.contrib.layers.python.layers.layers import * from tensorflow.contrib.layers.python.layers.normalization import * from tensorflow.contrib.layers.python.layers.optimizers import * from tensorflow.contrib.layers.python.layers.regularizers import * +from tensorflow.contrib.layers.python.layers.rev_block_lib import * from tensorflow.contrib.layers.python.layers.summaries import * from tensorflow.contrib.layers.python.layers.target_column import * from tensorflow.contrib.layers.python.ops.bucketization_op import * diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py 
b/tensorflow/contrib/layers/python/layers/feature_column.py index 226d933d85d91600e36ffb84212703e10455bfbb..092d418c3f232b364e2c6b4d25a4162626ba17f0 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -521,7 +521,7 @@ def sparse_column_with_integerized_feature(column_name, Args: column_name: A string defining sparse column name. - bucket_size: An int that is > 1. The number of buckets. It should be bigger + bucket_size: An int that is >= 1. The number of buckets. It should be bigger than maximum feature. In other words features in this column should be an int64 in range [0, bucket_size) combiner: A string specifying how to reduce if the sparse column is @@ -539,7 +539,7 @@ def sparse_column_with_integerized_feature(column_name, An integerized _SparseColumn definition. Raises: - ValueError: bucket_size is not greater than 1. + ValueError: bucket_size is less than 1. ValueError: dtype is not integer. """ return _SparseColumnIntegerized( diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index fa0047f05d893f6543ddb1680824a32469e13293..78affea44cbfb92523063968dbc1be98841854db 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -97,10 +97,13 @@ def _input_from_feature_columns(columns_to_tensors, trainable, scope, output_rank, - default_name): + default_name, + cols_to_outs=None): """Implementation of `input_from(_sequence)_feature_columns`.""" columns_to_tensors = columns_to_tensors.copy() check_feature_columns(feature_columns) + if cols_to_outs is not None and not isinstance(cols_to_outs, dict): + raise ValueError('cols_to_outs must be a dict unless None') with variable_scope.variable_scope(scope, default_name=default_name, values=columns_to_tensors.values()): @@ -144,6 +147,8 @@ def 
_input_from_feature_columns(columns_to_tensors, except ValueError as e: raise ValueError('Error creating input layer for column: {}.\n' '{}, {}'.format(column.name, e, ee)) + if cols_to_outs is not None: + cols_to_outs[column] = output_tensors[-1] return array_ops.concat(output_tensors, output_rank - 1) @@ -151,7 +156,8 @@ def input_from_feature_columns(columns_to_tensors, feature_columns, weight_collections=None, trainable=True, - scope=None): + scope=None, + cols_to_outs=None): """A tf.contrib.layers style input layer builder based on FeatureColumns. Generally a single example in training data is described with feature columns. @@ -196,6 +202,8 @@ def input_from_feature_columns(columns_to_tensors, trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). scope: Optional scope for variable_scope. + cols_to_outs: Optional dict from feature column to output tensor, + which is concatenated into the returned tensor. Returns: A Tensor which can be consumed by hidden layers in the neural network. @@ -209,7 +217,8 @@ def input_from_feature_columns(columns_to_tensors, trainable, scope, output_rank=2, - default_name='input_from_feature_columns') + default_name='input_from_feature_columns', + cols_to_outs=cols_to_outs) @experimental diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index fbfa0e32de55edab3c90189ddfe05ab826ac9167..e6bbd86ab722c4e853a59f816bed8a8ac1fe9ede 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -607,6 +607,31 @@ class CreateInputLayersForDNNsTest(test.TestCase): # Verify cross compatibility: Core builder output should equal to contrib. 
self.assertAllEqual(output.eval().shape, output_core.eval().shape) + def testAllDNNColumnsWithColumnwiseOutputs(self): + sparse_column = feature_column.sparse_column_with_keys( + "ids", ["a", "b", "c", "unseen"]) + real_valued_column = feature_column.real_valued_column("income", 2) + one_hot_column = feature_column.one_hot_column(sparse_column) + embedding_column = feature_column.embedding_column(sparse_column, 10) + features = { + "ids": + sparse_tensor.SparseTensor( + values=["c", "b", "a"], + indices=[[0, 0], [1, 0], [2, 0]], + dense_shape=[3, 1]), + "income": + constant_op.constant([[20.3, 10], [110.3, 0.4], [-3.0, 30.4]]), + } + columns = [one_hot_column, embedding_column, real_valued_column] + cols_to_outs = {} + feature_column_ops.input_from_feature_columns( + features, columns, cols_to_outs=cols_to_outs) + with self.test_session(): + variables_lib.global_variables_initializer().run() + lookup_ops.tables_initializer().run() + for column in columns: + self.assertTrue(column in cols_to_outs) + def testRealValuedColumn(self): real_valued = feature_column.real_valued_column("price") features = {"price": constant_op.constant([[20.], [110], [-3]])} diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py index b12a882d9ae88f7cf4f920cfa5872e5de1c67290..51610f21b24f1d40f26630cc1e69ca723d130639 100644 --- a/tensorflow/contrib/layers/python/layers/initializers.py +++ b/tensorflow/contrib/layers/python/layers/initializers.py @@ -79,7 +79,8 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False, ``` * To get [Delving Deep into Rectifiers]( - http://arxiv.org/pdf/1502.01852v1.pdf), use (Default):
+ http://arxiv.org/pdf/1502.01852v1.pdf) (also known as the "MSRA + initialization"), use (Default):
`factor=2.0 mode='FAN_IN' uniform=False` * To get [Convolutional Architecture for Fast Feature Embedding]( http://arxiv.org/abs/1408.5093), use:
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index dab5a5297c4a310f7ba0e26dda1d0335e81e567e..0d25a09852544a7eb1ed5eb9c2f3402d9064d91a 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -309,7 +309,6 @@ def _fused_batch_norm(inputs, new_shape = [-1, channels, 1, 1] inputs = array_ops.reshape(inputs, new_shape) inputs_shape = inputs.get_shape() - dtype = inputs.dtype.base_dtype if data_format == DATA_FORMAT_NHWC: params_shape = inputs_shape[-1:] else: @@ -1403,7 +1402,8 @@ def dropout(inputs, noise_shape=None, is_training=True, outputs_collections=None, - scope=None): + scope=None, + seed=None): """Returns a dropout op applied to the input. With probability `keep_prob`, outputs the input element scaled up by @@ -1421,6 +1421,8 @@ def dropout(inputs, Otherwise, inputs is returned. outputs_collections: Collection to add the outputs. scope: Optional scope for name_scope. + seed: A Python integer. Used to create random seeds. See + @{tf.set_random_seed} for behavior. Returns: A tensor representing the output of the operation. @@ -1430,6 +1432,7 @@ def dropout(inputs, inputs = ops.convert_to_tensor(inputs) layer = core_layers.Dropout(rate=1 - keep_prob, noise_shape=noise_shape, + seed=seed, name=sc.name, _scope=sc) outputs = layer.apply(inputs, training=is_training) @@ -2558,7 +2561,10 @@ def separable_convolution2d( regularizer=weights_regularizer, trainable=trainable, collections=weights_collections) - strides = [1, stride_h, stride_w, 1] + strides = [1, 1, stride_h, + stride_w] if data_format.startswith('NC') else [ + 1, stride_h, stride_w, 1 + ] outputs = nn.depthwise_conv2d(inputs, depthwise_weights, strides, padding, rate=utils.two_element_tuple(rate), @@ -2648,51 +2654,52 @@ def spatial_softmax(features, ValueError: If unexpected data_format specified. ValueError: If num_channels dimension is unspecified. 
""" - shape = array_ops.shape(features) - static_shape = features.shape - if data_format == DATA_FORMAT_NHWC: - height, width, num_channels = shape[1], shape[2], static_shape[3] - elif data_format == DATA_FORMAT_NCHW: - num_channels, height, width = static_shape[1], shape[2], shape[3] - else: - raise ValueError('data_format has to be either NCHW or NHWC.') - if num_channels.value is None: - raise ValueError('The num_channels dimension of the inputs to ' - '`spatial_softmax` should be defined. Found `None`.') - - with ops.name_scope(name, 'spatial_softmax', [features]) as name: - # Create tensors for x and y coordinate values, scaled to range [-1, 1]. - pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height), - math_ops.lin_space(-1., 1., num=width), - indexing='ij') - pos_x = array_ops.reshape(pos_x, [height * width]) - pos_y = array_ops.reshape(pos_y, [height * width]) - if temperature is None: - temperature_collections = utils.get_variable_collections( - variables_collections, 'temperature') - temperature = variables.model_variable( - 'temperature', - shape=(), - dtype=dtypes.float32, - initializer=init_ops.ones_initializer(), - collections=temperature_collections, - trainable=trainable) - if data_format == 'NCHW': - features = array_ops.reshape(features, [-1, height * width]) + with variable_scope.variable_scope(name, 'spatial_softmax'): + shape = array_ops.shape(features) + static_shape = features.shape + if data_format == DATA_FORMAT_NHWC: + height, width, num_channels = shape[1], shape[2], static_shape[3] + elif data_format == DATA_FORMAT_NCHW: + num_channels, height, width = static_shape[1], shape[2], shape[3] else: - features = array_ops.reshape( - array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width]) - - softmax_attention = nn.softmax(features/temperature) - expected_x = math_ops.reduce_sum( - pos_x * softmax_attention, [1], keep_dims=True) - expected_y = math_ops.reduce_sum( - pos_y * softmax_attention, [1], keep_dims=True) 
- expected_xy = array_ops.concat([expected_x, expected_y], 1) - feature_keypoints = array_ops.reshape( - expected_xy, [-1, num_channels.value * 2]) - feature_keypoints.set_shape([None, num_channels.value * 2]) - return feature_keypoints + raise ValueError('data_format has to be either NCHW or NHWC.') + if num_channels.value is None: + raise ValueError('The num_channels dimension of the inputs to ' + '`spatial_softmax` should be defined. Found `None`.') + + with ops.name_scope('spatial_softmax_op', 'spatial_softmax_op', [features]): + # Create tensors for x and y coordinate values, scaled to range [-1, 1]. + pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height), + math_ops.lin_space(-1., 1., num=width), + indexing='ij') + pos_x = array_ops.reshape(pos_x, [height * width]) + pos_y = array_ops.reshape(pos_y, [height * width]) + if temperature is None: + temperature_collections = utils.get_variable_collections( + variables_collections, 'temperature') + temperature = variables.model_variable( + 'temperature', + shape=(), + dtype=dtypes.float32, + initializer=init_ops.ones_initializer(), + collections=temperature_collections, + trainable=trainable) + if data_format == 'NCHW': + features = array_ops.reshape(features, [-1, height * width]) + else: + features = array_ops.reshape( + array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width]) + + softmax_attention = nn.softmax(features/temperature) + expected_x = math_ops.reduce_sum( + pos_x * softmax_attention, [1], keep_dims=True) + expected_y = math_ops.reduce_sum( + pos_y * softmax_attention, [1], keep_dims=True) + expected_xy = array_ops.concat([expected_x, expected_y], 1) + feature_keypoints = array_ops.reshape( + expected_xy, [-1, num_channels.value * 2]) + feature_keypoints.set_shape([None, num_channels.value * 2]) + return feature_keypoints def stack(inputs, layer, stack_args, **kwargs): diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py 
b/tensorflow/contrib/layers/python/layers/layers_test.py index 7ccd9d886879f163ba73c7a8f96d0d8962dd8486..ae64b75d939ce0ffab300b01d3cfcb67a9d0da1c 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1345,11 +1345,20 @@ class DropoutTest(test.TestCase): num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) output = _layers.dropout(images) num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) - sess.run(variables_lib.global_variables_initializer()) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertLess(num_elem, num_elem_initial / 2 + 0.1) self.assertGreater(num_elem, num_elem_initial / 2 - 0.1) + def testDropoutSeed(self): + """Test that providing the same seed produces the same result.""" + height, width = 10, 10 + with self.test_session() as sess: + images = random_ops.random_uniform( + (5, height, width, 3), seed=1, name='images') + output1 = _layers.dropout(images, seed=1) + output2 = _layers.dropout(images, seed=1) + self.assertAllEqual(*sess.run([output1, output2])) + def testCreateDropoutNoTraining(self): height, width = 3, 3 with self.test_session() as sess: @@ -1358,7 +1367,6 @@ class DropoutTest(test.TestCase): num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) output = _layers.dropout(images, is_training=False) num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) - sess.run(variables_lib.global_variables_initializer()) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertEqual(num_elem, num_elem_initial) outputs, inputs = sess.run([output, images]) @@ -1771,7 +1779,8 @@ class BatchNormTest(test.TestCase): dtype = dtypes.float32 height, width = 3, 3 with self.test_session(): - images = np.random.uniform(size=(5, height, width, 3)).astype(dtype.as_numpy_dtype) + images = np.random.uniform(size=(5, height, width, 3)).astype( + dtype.as_numpy_dtype) output = 
_layers.batch_norm(images, fused=fused) expected_name = ('BatchNorm/FusedBatchNorm' if fused else 'BatchNorm/batchnorm') @@ -2657,18 +2666,18 @@ class BatchNormTest(test.TestCase): # Test case for 11673 with self.test_session() as sess: a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10)) - b_32 = _layers.batch_norm(a_32, center=False, data_format='NCHW', - zero_debias_moving_mean=True) + _layers.batch_norm( + a_32, center=False, data_format='NCHW', zero_debias_moving_mean=True) a_16 = array_ops.placeholder(dtypes.float16, shape=(10, 10, 10, 10)) - b_16 = _layers.batch_norm(a_16, center=False, data_format='NCHW', - zero_debias_moving_mean=True) + _layers.batch_norm( + a_16, center=False, data_format='NCHW', zero_debias_moving_mean=True) sess.run(variables_lib.global_variables_initializer()) def testVariablesAreFloat32(self): height, width = 3, 3 with self.test_session(): - images = random_ops.random_uniform((5, height, width, 3), - seed=1, dtype=dtypes.float16) + images = random_ops.random_uniform( + (5, height, width, 3), seed=1, dtype=dtypes.float16) _layers.batch_norm(images, scale=True) beta = variables.get_variables_by_name('beta')[0] gamma = variables.get_variables_by_name('gamma')[0] @@ -2683,17 +2692,13 @@ class BatchNormTest(test.TestCase): channels = shape[1] images = np.arange(np.product(shape), dtype=dtype).reshape(shape) beta = init_ops.constant_initializer( - np.arange( - 2, channels + 2, dtype=np.float32)) + np.arange(2, channels + 2, dtype=np.float32)) gamma = init_ops.constant_initializer( - np.arange( - 10, channels + 10, dtype=np.float32) * 2.0) + np.arange(10, channels + 10, dtype=np.float32) * 2.0) mean = init_ops.constant_initializer( - np.arange( - 3, channels + 3, dtype=np.float32) * 5.0) + np.arange(3, channels + 3, dtype=np.float32) * 5.0) variance = init_ops.constant_initializer( - np.arange( - 1, channels + 1, dtype=np.float32) * 4.0) + np.arange(1, channels + 1, dtype=np.float32) * 4.0) output = _layers.batch_norm( 
images, fused=True, @@ -2718,7 +2723,6 @@ class BatchNormTest(test.TestCase): res_16 = self._runFusedBatchNorm(shape, np.float16) self.assertAllClose(res_32, res_16, rtol=1e-3) - def testAdjustmentCreated(self): # Tests that the adjustment is appropriately passed to and used by the core # BN layer. @@ -3322,16 +3326,24 @@ class SeparableConv2dTest(test.TestCase): for model_variable in model_variables: self.assertEqual(trainable, model_variable in trainable_variables) - def testConvNCHW(self): - for num_filters, correct_output_filters in [(None, 6), (8, 8)]: + def testSepConvNCHW(self): + for num_filters, correct_output_filters in zip((None, 5), (6, 5)): with self.test_session(): - batch, height, width = 4, 5, 6 + batch, height, width = 4, 10, 12 + kernel_dim, stride = 3, 2 images = random_ops.random_uniform((batch, 3, height, width), seed=1) output = layers_lib.separable_conv2d( - images, num_filters, [3, 3], 2, padding='VALID', data_format='NCHW') - self.assertListEqual( - output.get_shape().as_list(), [batch, correct_output_filters, - height - 2, width - 2]) + images, + num_outputs=num_filters, + kernel_size=[kernel_dim, kernel_dim], + depth_multiplier=2, + stride=stride, + padding='VALID', + data_format='NCHW') + self.assertListEqual(output.get_shape().as_list(), [ + batch, correct_output_filters, (height - kernel_dim + 1) // stride, + (width - kernel_dim + 1) // stride + ]) class ScaleGradientTests(test.TestCase): diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..123275e1fde047cd3772528641b2e3b09742fbdc --- /dev/null +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -0,0 +1,583 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible Residual Block. + +From +[The Reversible Residual Network: Backpropagation Without Storing +Activations](https://arxiv.org/abs/1707.04585). + +Also contains the @recompute_grad decorator, which recomputes the forward +function on the backwards pass. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import re + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.contrib.framework.python import ops as contrib_framework_ops +from tensorflow.python.framework import function +from tensorflow.python.framework import ops as framework_ops +from tensorflow.python.layers import base +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest + +__all__ = ["rev_block", "RevBlock", "recompute_grad"] + +LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*") + + +def _acc_grads(*lists_of_grads): + """Accumulates lists of gradients.""" + acc_grads = [] + for grads in zip(*lists_of_grads): + grads = [g for g in grads if g is not None] + if grads: + acc_grads.append(math_ops.add_n(grads)) + else: + acc_grads.append(None) + return acc_grads + + +def _rev_layer_forward(xs, f, g, 
f_side_input, g_side_input, + gate_outputs=False): + """Forward for 1 reversible layer.""" + x1, x2 = xs + y1 = x1 + (f(x2, f_side_input) if f_side_input else f(x2)) + y2 = x2 + (g(y1, g_side_input) if g_side_input else g(y1)) + if gate_outputs: + return control_flow_ops.tuple([y1, y2]) + else: + return (y1, y2) + + +def _rev_layer_backward(ys, grad_ys, f, g, f_vars, f_side_input, g_vars, + g_side_input): + """Backprop for 1 layer.""" + y1, y2 = ys + grad_y1, grad_y2 = grad_ys + + # Reconstruct intermediates and inputs (x1, x2) + # stop_gradients required on fn inputs to prevent infinite recursion into this + # grad function on the calls to gradients. + y1_stop = array_ops.stop_gradient(y1) + g_side_input = [array_ops.stop_gradient(t) for t in g_side_input] + gy1 = g(y1_stop, g_side_input) if g_side_input else g(y1_stop) + + x2 = y2 - gy1 + x2_stop = array_ops.stop_gradient(x2) + f_side_input = [array_ops.stop_gradient(t) for t in f_side_input] + fx2 = f(x2_stop, f_side_input) if f_side_input else f(x2_stop) + + x1 = y1 - fx2 + + # Compute gradients wrt to inputs + # dL/dy2 * dG(y1)/y1 + grad_gy1_y2 = gradients_impl.gradients(gy1, y1_stop, grad_y2)[0] + grad_x1 = grad_y1 + grad_gy1_y2 + grad_x2 = ( + gradients_impl.gradients(fx2, x2_stop, grad_y1)[0] + grad_y2 + + gradients_impl.gradients(fx2, x2_stop, grad_gy1_y2)[0]) + + # Compute gradients wrt to vars and side inputs in f and g + grads1 = gradients_impl.gradients(gy1, g_vars + g_side_input, grad_y2) + grad_g_vars, grad_g_side = grads1[:len(g_vars)], grads1[len(g_vars):] + grads2 = gradients_impl.gradients(fx2, f_vars + f_side_input, grad_y1) + grad_f_y1, grad_f_side1 = grads2[:len(f_vars)], grads2[len(f_vars):] + grads3 = gradients_impl.gradients(fx2, f_vars + f_side_input, grad_gy1_y2) + grad_f_y2, grad_f_side2 = grads3[:len(f_vars)], grads3[len(f_vars):] + grad_f_vars = _acc_grads(grad_f_y1, grad_f_y2) + + grad_f_side = _acc_grads(grad_f_side1, grad_f_side2) + + # Put returns in a tuple to ensure a constant 
memory budget (i.e. don't want + # the subsequent layer to start computing and consuming memory based on a + # subset of these values). + outputs = ((x1, x2), (grad_x1, grad_x2), (grad_f_vars, grad_f_side), + (grad_g_vars, grad_g_side)) + tupled = control_flow_ops.tuple(nest.flatten(outputs)) + return nest.pack_sequence_as(outputs, tupled) + + +def _rev_block_forward(x1, + x2, + f, + g, + num_layers=1, + f_side_input=None, + g_side_input=None, + gate_outputs=False): + """Forward for a series of reversible layers.""" + out = (x1, x2) + for i in xrange(num_layers): + out = _rev_layer_forward( + out, f[i], g[i], f_side_input, g_side_input, gate_outputs=gate_outputs) + + y1, y2 = out + return y1, y2 + + +def _scope_wrap(fn, scope): + + @functools.wraps(fn) + def wrap(*args, **kwargs): + with variable_scope.variable_scope(scope): + return fn(*args, **kwargs) + + return wrap + + +class RevBlock(base.Layer): + """Block of reversible layers. See rev_block.""" + + def __init__(self, + f, + g, + num_layers=1, + f_side_input=None, + g_side_input=None, + use_efficient_backprop=True, + name="revblock", + **kwargs): + super(RevBlock, self).__init__(name=name, **kwargs) + + if isinstance(f, list): + assert len(f) == num_layers + else: + f = [f] * num_layers + + if isinstance(g, list): + assert len(g) == num_layers + else: + g = [g] * num_layers + + f = [_scope_wrap(fn, "revlayer_%d/f" % i) for i, fn in enumerate(f)] + g = [_scope_wrap(fn, "revlayer_%d/g" % i) for i, fn in enumerate(g)] + + self.f = f + self.g = g + + self.num_layers = num_layers + self.f_side_input = f_side_input or [] + self.g_side_input = g_side_input or [] + + self._use_efficient_backprop = use_efficient_backprop + + def call(self, inputs, forward=True): + vs = variable_scope.get_variable_scope() + vars_before = vs.global_variables() + + if forward: + x1, x2 = inputs + out = self._forward(x1, x2) + else: + y1, y2 = inputs + out = self._backward(y1, y2) + + # Add any created variables to the Layer's variable 
stores + new_vars = vs.global_variables()[len(vars_before):] + train_vars = vs.trainable_variables() + for new_var in new_vars: + if new_var in train_vars: + self._trainable_weights.append(new_var) + else: + self._non_trainable_weights.append(new_var) + + return out + + def forward(self, x1, x2): + return self.apply([x1, x2]) + + def backward(self, y1, y2): + return self.apply([y1, y2], forward=False) + + def build(self, _): + logging.warn("RevBlock constructs its variables on first call, not on " + "build.") + self.built = True + + def _efficient_grad_fn(self, inputs, variables, ys, grad_ys): + """Custom gradient fn for a block of reversible residual layers.""" + side_inputs = inputs[2:] + f_side_idxs = [None] * len(self.f_side_input) + g_side_idxs = [None] * len(self.g_side_input) + assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input) + + for i, t in enumerate(side_inputs): + if t in self.f_side_input: + f_side_idxs[self.f_side_input.index(t)] = i + elif t in self.g_side_input: + g_side_idxs[self.g_side_input.index(t)] = i + else: + assert False + + f_vars = [[] for _ in range(self.num_layers)] + g_vars = [[] for _ in range(self.num_layers)] + f_vars_idxs = [[] for _ in range(self.num_layers)] + g_vars_idxs = [[] for _ in range(self.num_layers)] + + for i, t in enumerate(variables): + ref = _underlying_variable_ref(t) + + # Use the name to identify the layer number and function (f or g) + regex = LAYER_RE.match(ref.name) + layer_no = int(regex.group(1)) + fn_name = regex.group(2) + if fn_name == "f": + f_vars[layer_no].append(ref) + f_vars_idxs[layer_no].append(i) + else: + assert fn_name == "g" + g_vars[layer_no].append(ref) + g_vars_idxs[layer_no].append(i) + + f_var_grads = [] + g_var_grads = [] + f_side_grads = [] + g_side_grads = [] + + # Reverse variable containers to go backward + f_vars.reverse() + g_vars.reverse() + f = list(self.f) + g = list(self.g) + f.reverse() + g.reverse() + + with 
variable_scope.variable_scope(self.scope_name, reuse=True): + for i in xrange(self.num_layers): + ys, grad_ys, f_ret, g_ret = _rev_layer_backward( + ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i], + self.g_side_input) + + grad_f_vars, grad_f_side = f_ret + grad_g_vars, grad_g_side = g_ret + f_var_grads.append(grad_f_vars) + g_var_grads.append(grad_g_vars) + f_side_grads.append(grad_f_side) + g_side_grads.append(grad_g_side) + + # Accumulate layer gradients for f_side_input and g_side_input + acc_f_side_grads = _acc_grads(*f_side_grads) + acc_g_side_grads = _acc_grads(*g_side_grads) + + # Use the stored idxs to put gradients in the passed-in order. + side_input_grads = [None] * len(side_inputs) + variable_grads = [None] * len(variables) + + # Variable gradients were collected in reverse layer order. Reverse to match + # idxs. + f_var_grads.reverse() + g_var_grads.reverse() + for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list( + zip(g_vars_idxs, g_var_grads)): + for i, grad in zip(idxs, grads): + variable_grads[i] = grad + + for i, grad in zip(f_side_idxs, acc_f_side_grads): + side_input_grads[i] = grad + for i, grad in zip(g_side_idxs, acc_g_side_grads): + side_input_grads[i] = grad + + grad_x1, grad_x2 = grad_ys + return [grad_x1, grad_x2] + side_input_grads, variable_grads + + def _forward(self, x1, x2): + """Run forward through the reversible layers.""" + + side_inputs = [self.f_side_input, self.g_side_input] + flat_side_inputs = nest.flatten(side_inputs) + + custom_grad_fn = ( + self._efficient_grad_fn if self._use_efficient_backprop else None) + + @_fn_with_custom_grad(custom_grad_fn) + def _forward_wrap(x1_, x2_, *flat_side_inputs): + f_side, g_side = nest.pack_sequence_as(side_inputs, flat_side_inputs) + return _rev_block_forward( + x1_, + x2_, + self.f, + self.g, + num_layers=self.num_layers, + f_side_input=f_side, + g_side_input=g_side, + gate_outputs=self._use_efficient_backprop) + + return _forward_wrap(x1, x2, 
*flat_side_inputs) + + def _backward(self, y1, y2): + """Run backward through the reversible layers.""" + + f = list(self.f) + g = list(self.g) + f.reverse() + g.reverse() + + for i in xrange(self.num_layers): + gy1 = g[i](y1, self.g_side_input) if self.g_side_input else g[i](y1) + x2 = y2 - gy1 + fx2 = f[i](x2, self.f_side_input) if self.f_side_input else f[i](x2) + x1 = y1 - fx2 + + y1, y2 = x1, x2 + + return x1, x2 + + +def rev_block(x1, + x2, + f, + g, + num_layers=1, + f_side_input=None, + g_side_input=None, + is_training=True): + """A block of reversible residual layers. + + A reversible residual layer is defined as: + + ``` + y1 = x1 + f(x2, f_side_input) + y2 = x2 + g(y1, g_side_input) + ``` + + A reversible residual block, defined here, is a series of reversible residual + layers. + + Limitations: + * f and g must not close over any Tensors; all side inputs to f and g should + be passed in with f_side_input and g_side_input which will be forwarded to + f and g. + * f and g must not change the dimensionality of their inputs in order for the + addition in the equations above to work. + + Args: + x1: a float Tensor. + x2: a float Tensor. + f: a function, (Tensor) -> (Tensor) (or list of such of length num_layers). + Should not change the shape of the Tensor. Can make calls to get_variable. + See f_side_input if there are side inputs. + g: a function, (Tensor) -> (Tensor) (or list of such of length num_layers). + Should not change the shape of the Tensor. Can make calls to get_variable. + See g_side_input if there are side inputs. + num_layers: int, number of reversible residual layers. Each layer will + apply f and g according to the equations above, with new variables in each + layer. + f_side_input: list of Tensors, side input to f. If not None, signature of f + should be (Tensor, list) -> (Tensor). + g_side_input: list of Tensors, side input to g. If not None, signature of g + should be (Tensor, list) -> (Tensor). 
+ is_training: bool, whether to actually use the efficient backprop codepath. + + Returns: + y1, y2: tuple of float Tensors. + """ + block = RevBlock( + f=f, + g=g, + num_layers=num_layers, + f_side_input=f_side_input, + g_side_input=g_side_input, + use_efficient_backprop=is_training, + _reuse=variable_scope.get_variable_scope().reuse) + return block.forward(x1, x2) + + +def recompute_grad(fn): + """Decorator that recomputes the function on the backwards pass. + + Args: + fn: a function that takes Tensors (all as positional arguments) and returns + a tuple of Tensors. + + Returns: + A wrapped fn that is identical to fn when called, but its activations will + be discarded and recomputed on the backwards pass (i.e. on a call to + tf.gradients). + """ + + @functools.wraps(fn) + def wrapped(*args): + return _recompute_grad(fn, args) + + return wrapped + + +def _recompute_grad(fn, args): + """See recompute_grad.""" + + cached_vs = [] + cached_arg_scope = [] + + def grad_fn(inputs, variables, outputs, output_grads): + """Recompute outputs for gradient computation.""" + del outputs + # Recompute outputs + with framework_ops.control_dependencies(output_grads): + with contrib_framework_ops.arg_scope(cached_arg_scope[0]): + with variable_scope.variable_scope(cached_vs[0], reuse=True): + outputs = fn(*inputs) + + if not (isinstance(outputs, list) or isinstance(outputs, tuple)): + outputs = [outputs] + outputs = list(outputs) + grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) + grad_inputs = grads[:len(inputs)] + grad_vars = grads[len(inputs):] + return grad_inputs, grad_vars + + @_fn_with_custom_grad(grad_fn) + def fn_with_recompute(*args): + cached_vs.append(variable_scope.get_variable_scope()) + # TODO(rsepassi): Rm conditional in TF 1.4 + if hasattr(contrib_framework_ops, "current_arg_scope"): + cached_arg_scope.append(contrib_framework_ops.current_arg_scope()) + else: + cached_arg_scope.append({}) + return fn(*args) + + return 
fn_with_recompute(*args) + + +def _underlying_variable_ref(t): + """Find the underlying variable ref. + + Traverses through Identity, ReadVariableOp, and Enter ops. + Stops when op type has Variable or VarHandle in name. + + Args: + t: a Tensor + + Returns: + a Tensor that is a variable ref, or None on error. + """ + while t.op.type in ["Identity", "ReadVariableOp", "Enter"]: + t = t.op.inputs[0] + + op_type = t.op.type + if "Variable" in op_type or "VarHandle" in op_type: + return t + else: + return None + + +def _fn_with_custom_grad(grad_fn, use_global_vars=False): + """Decorator to create a subgraph with a custom gradient function. + + The subgraph created by the decorated function is NOT put in a Defun and so + does not suffer from the limitations of the Defun (all subgraph ops on the + same device, no summaries). + + Args: + grad_fn: function with signature + (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars), + all of which are lists of Tensors. + use_global_vars: if True, variables will be the global variables created. + If False, will be the trainable variables. + + Returns: + Decorator for function such that the gradient is defined by grad_fn. + """ + + def dec(fn): + + @functools.wraps(fn) + def wrapped(*args): + return _fn_with_custom_grad_internal( + fn, args, grad_fn, use_global_vars=use_global_vars) + + return wrapped + + return dec + + +def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): + """Create a subgraph with a custom gradient. + + Args: + fn: function that takes inputs as arguments and produces 1 or more Tensors. + inputs: list, will be passed as fn(*inputs). + grad_fn: function with signature + (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars), + all of which are lists of Tensors. + use_global_vars: if True, variables will be the global variables created. + If False, will be the trainable variables. 
+ + Returns: + fn(*inputs) + """ + vs = variable_scope.get_variable_scope() + get_vars_fn = ( + vs.global_variables if use_global_vars else vs.trainable_variables) + len_before_vars = len(get_vars_fn()) + inputs = list(inputs) + outputs = fn(*inputs) + train_vars = get_vars_fn()[len_before_vars:] + + if grad_fn is None: + return outputs + + if not (isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = [outputs] + outputs = list(outputs) + + defun_inputs = [inputs, train_vars, outputs] + + def custom_grad_fn(op, *dys): + """Custom grad fn applying grad_fn for identity Defun.""" + fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as( + defun_inputs, list(op.inputs)) + dys = list(dys) + assert len(fn_outputs) == len(outputs) + assert len(fn_outputs) == len(dys) + + grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys) + grad_outputs = [None] * len(fn_outputs) + return tuple(grad_inputs + grad_vars + grad_outputs) + + # The Defun takes as input the original inputs, the trainable variables + # created in fn, and the outputs. In the forward it passes through the + # outputs. In the backwards, it produces gradients for the original inputs + # and the trainable variables. 
+ in_types = [t.dtype for t in inputs] + out_types = [t.dtype for t in outputs] + var_types = [t.dtype for t in train_vars] + + # Get a unique name for the Defun + with framework_ops.name_scope("identity_custom_grad") as ns: + defun_name = ns + + @function.Defun( + *(in_types + var_types + out_types), + func_name=defun_name, + python_grad_func=custom_grad_fn, + shape_func=lambda _: [t.get_shape() for t in outputs]) + def identity(*args): + _, _, outs = nest.pack_sequence_as(defun_inputs, args) + return tuple([array_ops.identity(t) for t in outs]) + + flat_inputs = nest.flatten(defun_inputs) + id_out = identity(*flat_inputs) + return id_out diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cbcbcd75114a522b95631e4e7e95c1641b0a9987 --- /dev/null +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -0,0 +1,364 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for RevBlock.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.layers.python.layers import rev_block_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import random_seed +from tensorflow.python.layers import convolutional +from tensorflow.python.layers import core as core_layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class RevBlockTest(test.TestCase): + CHANNELS = 8 + NUM_LAYERS = 4 + BATCH_SIZE = 16 + + def testForwardBackward(self): + + def f(x): + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + def g(x): + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + x = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32) + x1, x2 = array_ops.split(x, 2, axis=-1) + + block = rev_block_lib.RevBlock(f, g, num_layers=3) + y1, y2 = block.forward(x1, x2) + x1_inv, x2_inv = block.backward(y1, y2) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + x1, x2, x1_inv, x2_inv = sess.run([x1, x2, x1_inv, x2_inv]) + + self.assertAllClose(x1, x1_inv) + self.assertAllClose(x2, x2_inv) + + def testBackwardForward(self): + + def f(x): + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + def g(x): + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + y = random_ops.random_uniform( + [self.BATCH_SIZE, 
self.CHANNELS], dtype=dtypes.float32) + y1, y2 = array_ops.split(y, 2, axis=-1) + + block = rev_block_lib.RevBlock(f, g, num_layers=3) + x1, x2 = block.backward(y1, y2) + y1_inv, y2_inv = block.forward(x1, x2) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv]) + + self.assertAllClose(y1, y1_inv) + self.assertAllClose(y2, y2_inv) + + def _testRevBlock(self, + x=None, + f=None, + g=None, + f_side_input=None, + g_side_input=None): + random_seed.set_random_seed(1234) + + if f is None: + + def f(x): # pylint: disable=function-redefined + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + if g is None: + + def g(x): # pylint: disable=function-redefined + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + if f_side_input is None: + f_side_input = [] + + if g_side_input is None: + g_side_input = [] + + if x is None: + x = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32) + x1, x2 = array_ops.split(x, 2, axis=-1) + + with variable_scope.variable_scope("rev_test") as vs: + y1_rev, y2_rev = rev_block_lib.rev_block( + x1, + x2, + f, + g, + f_side_input=f_side_input, + g_side_input=g_side_input, + num_layers=self.NUM_LAYERS) + y_rev = array_ops.concat([y1_rev, y2_rev], axis=1) + fg_vars = vs.trainable_variables() + + num_vars = len(variables.global_variables()) + with variable_scope.variable_scope(vs, reuse=True): + y1, y2 = rev_block_lib.rev_block( + x1, + x2, + f, + g, + f_side_input=f_side_input, + g_side_input=g_side_input, + num_layers=self.NUM_LAYERS, + is_training=False) + y = array_ops.concat([y1, y2], axis=1) + # Ensure no new vars were created - full reuse + assert len(variables.global_variables()) == num_vars + + loss_rev = math_ops.reduce_mean(y_rev + 10.) + loss = math_ops.reduce_mean(y + 10.) 
+ + wrt = [x] + f_side_input + g_side_input + fg_vars + grads_rev = gradients_impl.gradients(loss_rev, wrt) + grads = gradients_impl.gradients(loss, wrt) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads]) + self.assertAllClose(y_val, yd_val) + for g1, g2 in zip(gd_val, g_val): + self.assertAllClose(g1, g2) + + def testRevBlock(self): + self._testRevBlock() + + def testSideInput(self): + f_side_input = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS // 2]) + + def f(x, side_input): + return core_layers.dense( + x, self.CHANNELS // 2, use_bias=True) + side_input[0] + + self._testRevBlock(f=f, f_side_input=[f_side_input]) + + def testMultipleFns(self): + + def f1(x): + return core_layers.dense(x, self.CHANNELS // 2) + + def f2(x): + return core_layers.dense(x, self.CHANNELS // 2, activation=nn_ops.relu) + + self._testRevBlock(f=[f1, f2, f1, f2]) + + # TODO(rsepassi): Recent change to conv seems to have broken this test. Find + # out why. 
+ def _testConvAndBatchNorm(self): + + x = random_ops.random_uniform( + [self.BATCH_SIZE, 10, self.CHANNELS], dtype=dtypes.float32) + + def f(x): + x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") + x = layers.batch_norm(x, is_training=True) + x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") + x = layers.batch_norm(x, is_training=True) + return x + + self._testRevBlock(x=x, f=f) + + def testReuse(self): + + def f(x): + return core_layers.dense(x, self.CHANNELS // 2) + + def g(x): + return core_layers.dense(x, self.CHANNELS // 2) + + x = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32) + x1, x2 = array_ops.split(x, 2, axis=-1) + + with variable_scope.variable_scope("test"): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_before = len(variables.global_variables()) + + with variable_scope.variable_scope("test", reuse=True): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_after = len(variables.global_variables()) + self.assertEqual(num_vars_before, num_vars_after) + + loss = math_ops.reduce_mean(y1 + y2) + _ = gradients_impl.gradients(loss, + [x] + variables.trainable_variables()) + + with variable_scope.variable_scope("test", reuse=True): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_after = len(variables.global_variables()) + self.assertEqual(num_vars_before, num_vars_after) + + +class RecomputeTest(test.TestCase): + + def testRecompute(self): + + def layer(x, name=None): + with variable_scope.variable_scope(name, default_name="layer"): + x = layers.layer_norm(x) + x = convolutional.conv1d( + x, + 10, + 1, + use_bias=False, + kernel_initializer=init_ops.constant_initializer(42.42)) + x = nn_ops.relu(x) + return x + + def fn(x): + out = x + for _ in range(3): + out = layer(out) + return out + + @rev_block_lib.recompute_grad + def fn_recompute(x): + return fn(x) + + x = 
random_ops.random_uniform((3, 1, 3)) + recompute_vars = None + with variable_scope.variable_scope("recompute") as vs: + out1 = math_ops.reduce_sum(fn_recompute(x)) + recompute_vars = vs.trainable_variables() + reg_vars = None + with variable_scope.variable_scope("regular") as vs: + out2 = math_ops.reduce_sum(fn(x)) + reg_vars = vs.trainable_variables() + + grad1 = gradients_impl.gradients(out1, recompute_vars) + grad2 = gradients_impl.gradients(out2, reg_vars) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + outs = sess.run([out1, out2, grad1, grad2]) + self.assertAllClose(outs[0], outs[1]) + for g1, g2 in zip(outs[2], outs[3]): + self.assertAllClose(g1, g2) + + +class FnWithCustomGradTest(test.TestCase): + + def testCorrectness(self): + + w = random_ops.random_uniform([6, 10]) + + def fn(a, b, c): + return core_layers.dense( + a, + 10, + use_bias=False, + kernel_initializer=lambda shape, dtype, partition_info: w + ) + math_ops.matmul(b, c) + + def grad_fn(inputs, trainable_variables, outputs, grad_outputs): + outputs = outputs[0] + grad_outputs = grad_outputs[0] + grad_inputs = gradients_impl.gradients( + outputs, inputs, grad_ys=grad_outputs) + grad_vars = gradients_impl.gradients( + outputs, trainable_variables, grad_ys=grad_outputs) + return grad_inputs, grad_vars + + custom_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn) + + a = random_ops.random_uniform([11, 6]) + b = random_ops.random_uniform([11, 7]) + c = random_ops.random_uniform([7, 10]) + + out = fn(a, b, c) + custom_out = custom_fn(a, b, c) + self.assertEqual(out.get_shape().as_list(), + custom_out.get_shape().as_list()) + + loss = math_ops.reduce_mean(out) + custom_loss = math_ops.reduce_mean(custom_out) + + grads = gradients_impl.gradients( + loss, [a, b, c] + [variables.trainable_variables()[0]]) + custom_grads = gradients_impl.gradients( + custom_loss, [a, b, c] + [variables.trainable_variables()[1]]) + + with self.test_session() as sess: + 
sess.run(variables.global_variables_initializer()) + out_val, custom_out_val, grads_val, custom_grads_val = sess.run( + [out, custom_out, grads, custom_grads]) + self.assertAllClose(out_val, custom_out_val) + for g1, g2 in zip(grads_val, custom_grads_val): + self.assertAllClose(g1, g2) + + def testCustomGrad(self): + + def fn(a, b, c): + return core_layers.dense(a, 10, use_bias=False) + math_ops.matmul(b, c) + + def grad_fn(inputs, trainable_variables, unused_outputs, + unused_grad_outputs): + grad_inputs = [ + array_ops.ones_like(t) * (i + 1.) for i, t in enumerate(inputs) + ] + grad_vars = [ + array_ops.ones_like(t) * (i + len(inputs) + 1.) + for i, t in enumerate(trainable_variables) + ] + return grad_inputs, grad_vars + + a = random_ops.random_uniform([11, 6]) + b = random_ops.random_uniform([11, 7]) + c = random_ops.random_uniform([7, 10]) + w = random_ops.random_uniform([6, 10]) + out = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)(a, b, c) + loss = math_ops.reduce_mean(out) + grads = gradients_impl.gradients( + loss, [a, b, c, variables.trainable_variables()[0]]) + expected_grads = [ + array_ops.ones_like(t) * (i + 1.) 
for i, t in enumerate([a, b, c, w]) + ] + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + g_val, eg_val = sess.run([grads, expected_grads]) + for g1, g2 in zip(g_val, eg_val): + self.assertAllClose(g1, g2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 2917a30a1770351a2315a8deb696d1841d260ff0..5df2c77249b81434125d838f896f0ace2a5ee130 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -10,7 +10,7 @@ package(default_visibility = [ "//tensorflow:internal", ]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") py_library( name = "learn", @@ -22,6 +22,8 @@ py_library( exclude = ["python/learn/**/*_test.py"], ), srcs_version = "PY2AND3", + # This library should not depend on sklearn, even though some of the code + # refers to it. (The code handles the presence of sklearn conditionally.) 
deps = [ "//tensorflow/contrib/factorization:factorization_py", "//tensorflow/contrib/framework:framework_py", @@ -152,12 +154,11 @@ py_test( ], ) -py_test( +tf_py_test( name = "experiment_test", size = "medium", srcs = ["python/learn/experiment_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", "//tensorflow/core:protos_all_py", @@ -344,6 +345,7 @@ py_test( srcs = ["python/learn/estimators/dnn_linear_combined_test.py"], shard_count = 4, srcs_version = "PY2AND3", + tags = ["no_oss"], # flaky b/70524820 deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", @@ -459,6 +461,7 @@ py_test( size = "medium", srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"], srcs_version = "PY2AND3", + tags = ["noasan"], deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", @@ -713,12 +716,11 @@ py_test( ], ) -py_test( +tf_py_test( name = "graph_io_test", size = "small", srcs = ["python/learn/learn_io/graph_io_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":learn", "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -734,6 +736,7 @@ py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], + grpc_enabled = True, ) py_test( diff --git a/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py b/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py index 14750961efa30128708430fac038498de0a42118..ef5e620e8f08cffa7c2b945089aa5d150baefefc 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from 
tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.estimators import composable_model @@ -55,7 +55,7 @@ def _base_model_fn(features, labels, mode, params): raise NotImplementedError def _train_op_fn(loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() assert global_step train_step = model.get_train_step(loss) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index cb15ef23e95d27c737d8ae08065b804bafd39a07..c17b41c0f767e19d9c3635a8f60347a49b297cfb 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -23,7 +23,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec @@ -189,7 +189,7 @@ def _dnn_model_fn(features, labels, mode, params, config=None): """Returns the op to optimize the loss.""" return optimizers.optimize_loss( loss=loss, - global_step=contrib_variables.get_global_step(), + global_step=training_util.get_global_step(), learning_rate=_LEARNING_RATE, optimizer=_get_optimizer(optimizer), gradient_multipliers=( diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 788d2d0b1a58fad16712c968593b40de0d3979f0..05ed8b3409e68ae54e5ef89b3a1592a6f285565b 100644 --- 
a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -30,7 +30,6 @@ import six from google.protobuf import message from tensorflow.contrib import layers -from tensorflow.contrib import metrics as metrics_lib from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework import list_variables @@ -60,6 +59,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -1230,7 +1230,7 @@ class Estimator(BaseEstimator): if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops: model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = ( - metrics_lib.streaming_mean(model_fn_ops.loss)) + metrics_lib.mean(model_fn_ops.loss)) return model_fn_ops def _get_predict_ops(self, features): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py index 248c6c733ffca351c848ba07110ba89928634a23..9d7c1a099aa4be64ca0296fa5b870597dabec7b4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py @@ -23,7 +23,7 @@ import tempfile import numpy as np -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec from tensorflow.contrib.learn.python.learn import models @@ -114,7 +114,7 @@ def 
linear_model_params_fn(features, labels, mode, params): prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( loss, - variables.get_global_step(), + training_util.get_global_step(), optimizer='Adagrad', learning_rate=params['learning_rate']) return prediction, loss, train_op @@ -129,7 +129,7 @@ def linear_model_fn(features, labels, mode): (_, features), = features.items() prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return prediction, loss, train_op @@ -139,7 +139,7 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode): model_fn.ModeKeys.INFER) prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return model_fn.ModelFnOps( mode=mode, predictions=prediction, loss=loss, train_op=train_op) @@ -150,7 +150,7 @@ def logistic_model_no_mode_fn(features, labels): labels = array_ops.one_hot(labels, 3, 1, 0) prediction, loss = (models.logistic_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return { 'class': math_ops.argmax(prediction, 1), 'prob': prediction diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index be2b0cb3ca959323b4de095ca072278f028be301..2a13a84627df35a68a4f04b25ab26ceecad0db0d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ 
b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -32,7 +32,7 @@ from google.protobuf import text_format from tensorflow.contrib import learn from tensorflow.contrib import lookup -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import experiment @@ -132,7 +132,7 @@ def linear_model_params_fn(features, labels, mode, params): prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( loss, - variables.get_global_step(), + training_util.get_global_step(), optimizer='Adagrad', learning_rate=params['learning_rate']) return prediction, loss, train_op @@ -147,7 +147,7 @@ def linear_model_fn(features, labels, mode): (_, features), = features.items() prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return prediction, loss, train_op @@ -157,7 +157,7 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode): model_fn.ModeKeys.INFER) prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return model_fn.ModelFnOps( mode=mode, predictions=prediction, loss=loss, train_op=train_op) @@ -168,7 +168,7 @@ def logistic_model_no_mode_fn(features, labels): labels = array_ops.one_hot(labels, 3, 1, 0) prediction, loss = (models.logistic_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, 
variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return { 'class': math_ops.argmax(prediction, 1), 'prob': prediction @@ -241,7 +241,7 @@ def _build_estimator_for_resource_export_test(): const = constant_op.constant(-1, dtype=dtypes.int64) table = lookup.MutableHashTable( dtypes.string, dtypes.int64, const, name='LookupTableModel') - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) if mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL): key = constant_op.constant(['key']) value = constant_op.constant([42], dtype=dtypes.int64) @@ -306,7 +306,7 @@ def _model_fn_ops( mode=mode, predictions=constant_op.constant(0.), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1)) + train_op=training_util.get_global_step().assign_add(1)) def _make_input_fn(features, labels): @@ -389,7 +389,7 @@ class EstimatorModelFnTest(test.TestCase): self.assertEqual(expected_param, params) self.assertEqual(model_dir, expected_model_dir) return (constant_op.constant(0.), constant_op.constant(0.), - variables.get_global_step().assign_add(1)) + training_util.get_global_step().assign_add(1)) est = estimator.Estimator(model_fn=_argument_checker, params=expected_param, model_dir=expected_model_dir) @@ -400,7 +400,7 @@ class EstimatorModelFnTest(test.TestCase): def _invalid_model_fn(features, labels): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): loss = 100.0 - w return None, loss, None @@ -415,7 +415,7 @@ class EstimatorModelFnTest(test.TestCase): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') loss = 100.0 - w - update_global_step = 
variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): train_op = w.assign_add(loss / 100.0) predictions = loss @@ -434,7 +434,7 @@ class EstimatorModelFnTest(test.TestCase): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') loss = 100.0 - w - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): train_op = w.assign_add(loss / 100.0) return None, loss, train_op @@ -464,7 +464,7 @@ class EstimatorModelFnTest(test.TestCase): mode=mode, predictions=constant_op.constant(0.), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1), + train_op=training_util.get_global_step().assign_add(1), scaffold=monitored_session.Scaffold(init_fn=_init_fn)) est = estimator.Estimator(model_fn=_model_fn_scaffold) @@ -483,7 +483,7 @@ class EstimatorModelFnTest(test.TestCase): mode=mode, predictions=constant_op.constant([[1.]]), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1), + train_op=training_util.get_global_step().assign_add(1), scaffold=monitored_session.Scaffold(saver=self.mock_saver)) def input_fn(): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py index 1d89dfb55b10b032cab7dcf434d396404d4eb83b..8131e0fde6fea5501cacc4714f53ed8d867ca70f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py @@ -22,7 +22,7 @@ import random import numpy as np -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python import learn from tensorflow.contrib.learn.python.learn import 
datasets from tensorflow.contrib.learn.python.learn import metric_spec @@ -62,7 +62,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["transformed_x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator = estimator_lib.Estimator( @@ -100,7 +100,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator = estimator_lib.Estimator( @@ -139,7 +139,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator_with_fe_fn = estimator_lib.Estimator( diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index 992b804f59ecd88fedc2fba10d3079f93c4fe83d..8f9d6fc318a357853bdb8e3264f6691b410006b1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -28,7 +28,7 @@ import time import numpy as np from tensorflow.contrib.factorization.python.ops import clustering_ops -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModelFnOps from tensorflow.python.framework import ops @@ -128,7 +128,7 @@ def 
_kmeans_clustering_model_fn(features, labels, mode, params, config): random_seed=params.get('random_seed'), kmeans_plus_plus_num_retries=params.get( 'kmeans_plus_plus_num_retries')).training_graph() - incr_step = state_ops.assign_add(variables.get_global_step(), 1) + incr_step = state_ops.assign_add(training_util.get_global_step(), 1) loss = math_ops.reduce_sum(losses, name=KMeansClustering.LOSS_OP_NAME) summary.scalar('loss/raw', loss) training_op = with_dependencies([training_op, incr_step], loss) diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index f5445ad4e728dbd3904279573771de9454b5d17c..37aa8b339622415d082933cdf66d2472a4119b48 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -26,7 +26,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib @@ -170,7 +170,7 @@ def _linear_model_fn(features, labels, mode, params, config=None): weight_collections=[parent_scope]) def _train_op_fn(loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() my_vars = ops.get_collection(parent_scope) grads = gradients.gradients(loss, my_vars) if gradient_clip_norm: @@ -252,7 +252,7 @@ def sdca_model_fn(features, labels, mode, params): _add_bias_column(feature_columns, features, bias, columns_to_variables) def _train_op_fn(unused_loss): - global_step = contrib_variables.get_global_step() + global_step = 
training_util.get_global_step() sdca_model, train_op = optimizer.get_train_step(columns_to_variables, weight_column_name, loss_type, features, diff --git a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py index 93c62f87e8495f299a8c456574c7b40534186304..656d68b76888d9319c0b9be481f9b0478ac4314c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.estimators import logistic_regressor @@ -57,7 +57,7 @@ def _logistic_regression_model_fn(features, labels, mode): predictions = math_ops.sigmoid(logits) loss = losses.sigmoid_cross_entropy(labels, logits) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return predictions, loss, train_op diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 307db76afe20a7743df16d169270a6f319497eb6..9576ff21c243022276bb0641882dfaf0decf05c0 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -35,6 +35,7 @@ from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.python.estimator import estimator as 
core_estimator +from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks @@ -46,6 +47,18 @@ from tensorflow.python.util import compat __all__ = ["Experiment"] +def _get_standardized_predicate_fn(predicate_fn): + pred_fn_args = estimator_util.fn_args(predicate_fn) + if "checkpoint_path" not in pred_fn_args: + # pylint: disable=unused-argument + def _pred_fn_wrapper(eval_results, checkpoint_path): + return predicate_fn(eval_results) + + return _pred_fn_wrapper + else: + return predicate_fn + + class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): """Listener that evaluates and exports a model after creating a checkpoint. @@ -140,7 +153,8 @@ class Experiment(object): delay_workers_by_global_step=False, export_strategies=None, train_steps_per_iteration=None, - checkpoint_and_export=False): + checkpoint_and_export=False, + saving_listeners=None): """Constructor for `Experiment`. Creates an Experiment instance. None of the functions passed to this @@ -200,6 +214,9 @@ class Experiment(object): `save_checkpoints_steps`. Also, this parameter leads to the creation of a default `CheckpointSaverHook` instead of a `ValidationMonitor`, so the provided `train_monitors` will need to be adjusted accordingly. + saving_listeners: list of `CheckpointSaverListener` objects. Used by + tf.estimator.Estimator for callbacks that run immediately before or + after checkpoint savings. 
Raises: ValueError: if `estimator` does not implement Estimator interface, @@ -221,6 +238,9 @@ class Experiment(object): raise ValueError( "`estimator` must implement `tf.contrib.learn.Trainable`" "or `tf.estimator.`Estimator`.") + if saving_listeners is not None: + raise ValueError("`saving_listeners` must be `None` with " + "`tf.contrib.learn.Estimator`.") if isinstance(estimator, tpu_estimator.TPUEstimator): logging.warn( @@ -242,6 +262,7 @@ class Experiment(object): self._eval_delay_secs = eval_delay_secs self._continuous_eval_throttle_secs = continuous_eval_throttle_secs self._checkpoint_and_export = checkpoint_and_export + self._saving_listeners = saving_listeners # Using 1 on a non-cached file system requires a lot of overhead to # read the checkpoint state file. This is particular bad on GCS, so # we use a different default. This is a temporary band-aid, to be @@ -362,9 +383,11 @@ class Experiment(object): logging.info("Waiting %d secs before starting training.", remaining) time.sleep(delay_secs) - return self._call_train(input_fn=self._train_input_fn, - max_steps=self._train_steps, - hooks=self._train_monitors + extra_hooks) + return self._call_train( + input_fn=self._train_input_fn, + max_steps=self._train_steps, + hooks=self._train_monitors + extra_hooks, + saving_listeners=self._saving_listeners) def evaluate(self, delay_secs=None, name=None): """Evaluate on the evaluation data. @@ -436,22 +459,33 @@ class Experiment(object): evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints that have already been evaluated. Default is `True`. continuous_eval_predicate_fn: A predicate function determining whether to - continue eval after each iteration. `predicate_fn` takes the evaluation - results as arguments. At the beginning of evaluation, the passed eval - results will be None so it's expected that the predicate function - handles that gracefully. 
When `predicate_fn` is not specified, - continuous eval will run in an infinite loop (if `train_steps` is None) - or exit once global step reaches `train_steps`. + continue eval after each iteration. A `predicate_fn` has one of the + following signatures: + * (eval_results) -> boolean + * (eval_results, checkpoint_path) -> boolean + Where `eval_results` is the dictionary of metric evaluations and + checkpoint_path is the path to the checkpoint containing the parameters + on which that evaluation was based. + At the beginning of evaluation, the passed `eval_results` will be None + so it's expected that the predicate function handles that gracefully. + When `predicate_fn` is not specified, continuous eval will run in an + infinite loop (if `train_steps` is None). or exit once global step + reaches `train_steps`. + export: Whether to export from this step. Default is 'True'. Raises: ValueError: if `continuous_eval_predicate_fn` is neither None nor callable. """ - if (continuous_eval_predicate_fn is not None and - not callable(continuous_eval_predicate_fn)): - raise ValueError( - "`continuous_eval_predicate_fn` must be a callable, or None.") + if continuous_eval_predicate_fn is not None: + if not callable(continuous_eval_predicate_fn): + raise ValueError( + "`continuous_eval_predicate_fn` must be a callable, or None.") + predicate_fn = _get_standardized_predicate_fn( + continuous_eval_predicate_fn) + else: + predicate_fn = None if delay_secs is None: delay_secs = self._eval_delay_secs @@ -465,8 +499,10 @@ class Experiment(object): previous_path = None eval_result = None last_warning_time = 0 - while (not continuous_eval_predicate_fn or - continuous_eval_predicate_fn(eval_result)): + while (not predicate_fn or + predicate_fn( + eval_result, + checkpoint_path=previous_path if eval_result else None)): # Exit if we have already reached number of steps to train. 
if self._has_training_stopped(eval_result): logging.info("Exiting continuous eval, global_step=%s >= " @@ -672,11 +708,19 @@ class Experiment(object): Args: continuous_eval_predicate_fn: A predicate function determining whether to - continue after each iteration. `predicate_fn` takes the evaluation - results as its arguments. At the beginning of evaluation, the passed - eval results will be None so it's expected that the predicate function - handles that gracefully. When `predicate_fn` is not specified, this will - run in an infinite loop or exit when global_step reaches `train_steps`. + continue eval after each iteration. A `predicate_fn` has one of the + following signatures: + * (eval_results) -> boolean + * (eval_results, checkpoint_path) -> boolean + Where `eval_results` is the dictionary of metric evaluations and + checkpoint_path is the path to the checkpoint containing the parameters + on which that evaluation was based. + At the beginning of evaluation, the passed `eval_results` and + `checkpoint_path` will be None so it's expected that the predicate + function handles that gracefully. + When `predicate_fn` is not specified, continuous eval will run in an + infinite loop (if `train_steps` is None). or exit once global step + reaches `train_steps`. Returns: A tuple of the result of the `evaluate` call to the `Estimator` and the @@ -687,13 +731,18 @@ class Experiment(object): callable. 
""" - if (continuous_eval_predicate_fn is not None and - not callable(continuous_eval_predicate_fn)): - raise ValueError( - "`continuous_eval_predicate_fn` must be a callable, or None.") + if continuous_eval_predicate_fn is not None: + if not callable(continuous_eval_predicate_fn): + raise ValueError( + "`continuous_eval_predicate_fn` must be a callable, or None.") + predicate_fn = _get_standardized_predicate_fn( + continuous_eval_predicate_fn) + else: + predicate_fn = None - eval_result = None export_results = None + latest_checkpoint = None + eval_result = None # Set the default value for train_steps_per_iteration, which will be # overridden by other settings. @@ -703,8 +752,10 @@ class Experiment(object): elif self._train_steps is not None: train_steps_per_iteration = int(self._train_steps / 10) - while (not continuous_eval_predicate_fn or - continuous_eval_predicate_fn(eval_result)): + while (not predicate_fn or + predicate_fn( + eval_result, + checkpoint_path=latest_checkpoint if eval_result else None)): if self._has_training_stopped(eval_result): # Exits once max steps of training is satisfied. 
@@ -712,16 +763,21 @@ class Experiment(object): break logging.info("Training model for %s steps", train_steps_per_iteration) - self._call_train(input_fn=self._train_input_fn, - steps=train_steps_per_iteration, - hooks=self._train_monitors) + self._call_train( + input_fn=self._train_input_fn, + steps=train_steps_per_iteration, + hooks=self._train_monitors, + saving_listeners=self._saving_listeners) logging.info("Evaluating model now.") - eval_result = self._call_evaluate(input_fn=self._eval_input_fn, - steps=self._eval_steps, - metrics=self._eval_metrics, - name="one_pass", - hooks=self._eval_hooks) + latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir) + eval_result = self._call_evaluate( + input_fn=self._eval_input_fn, + steps=self._eval_steps, + metrics=self._eval_metrics, + name="one_pass", + checkpoint_path=latest_checkpoint, + hooks=self._eval_hooks) export_results = self._maybe_export(eval_result) return eval_result, export_results @@ -762,9 +818,11 @@ class Experiment(object): Returns: The result of the `evaluate` call to the `Estimator`. """ - self._call_train(input_fn=self._train_input_fn, - steps=1, - hooks=self._train_monitors) + self._call_train( + input_fn=self._train_input_fn, + steps=1, + hooks=self._train_monitors, + saving_listeners=self._saving_listeners) eval_result = self._call_evaluate(input_fn=self._eval_input_fn, steps=1, @@ -792,7 +850,8 @@ class Experiment(object): return server def _call_train(self, _sentinel=None, # pylint: disable=invalid-name, - input_fn=None, steps=None, hooks=None, max_steps=None): + input_fn=None, steps=None, hooks=None, max_steps=None, + saving_listeners=None): if _sentinel is not None: raise ValueError("_call_train should be called with keyword args only") @@ -801,10 +860,12 @@ class Experiment(object): # safe to convert for both cases. 
hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator) if self._core_estimator_used: - return self._estimator.train(input_fn=input_fn, - steps=steps, - max_steps=max_steps, - hooks=hooks) + return self._estimator.train( + input_fn=input_fn, + steps=steps, + max_steps=max_steps, + hooks=hooks, + saving_listeners=saving_listeners) else: return self._estimator.fit(input_fn=input_fn, steps=steps, diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index fe40d27c445d4f560c96fc9b50ceb0daed30ee93..545d7d8924c0c10544e6113e2968b7ae3d2090fc 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -232,14 +232,19 @@ class ExperimentTest(test.TestCase): def test_train(self): for est in self._estimators_for_tests(): - eval_metrics = 'eval_metrics' if not isinstance( - est, core_estimator.Estimator) else None + if isinstance(est, core_estimator.Estimator): + eval_metrics = None + saving_listeners = 'saving_listeners' + else: + eval_metrics = 'eval_metrics' + saving_listeners = None ex = experiment.Experiment( est, train_input_fn='train_input', train_steps='train_steps', eval_input_fn='eval_input', - eval_metrics=eval_metrics) + eval_metrics=eval_metrics, + saving_listeners=saving_listeners) fit_args = ex.train(delay_secs=0) self.assertEqual(1, est.fit_count) self.assertIn(('max_steps', 'train_steps'), fit_args) @@ -487,6 +492,33 @@ class ExperimentTest(test.TestCase): self.assertEqual(3, est.eval_count) self.assertEqual([noop_hook], est.eval_hooks) + def test_continuous_eval_predicate_fn_with_checkpoint(self): + for est in self._estimators_for_tests(): + eval_metrics = 'eval_metrics' if not isinstance( + est, core_estimator.Estimator) else None + est.fake_checkpoint() + noop_hook = _NoopHook() + + def _predicate_fn(eval_result, checkpoint_path): + self.assertEqual(not eval_result, + checkpoint_path is None) + 
return est.eval_count < 3 # pylint: disable=cell-var-from-loop + + ex = experiment.Experiment( + est, + train_input_fn='train_input', + eval_input_fn='eval_input', + eval_metrics=eval_metrics, + eval_hooks=[noop_hook], + eval_delay_secs=0, + continuous_eval_throttle_secs=0) + ex.continuous_eval( + evaluate_checkpoint_only_once=False, + continuous_eval_predicate_fn=_predicate_fn) + self.assertEqual(0, est.fit_count) + self.assertEqual(3, est.eval_count) + self.assertEqual([noop_hook], est.eval_hooks) + def test_run_local(self): for est in self._estimators_for_tests(): eval_metrics = 'eval_metrics' if not isinstance( @@ -675,8 +707,12 @@ class ExperimentTest(test.TestCase): def test_continuous_train_and_eval(self): for est in self._estimators_for_tests(eval_dict={'global_step': 100}): - eval_metrics = 'eval_metrics' if not isinstance( - est, core_estimator.Estimator) else None + if isinstance(est, core_estimator.Estimator): + eval_metrics = None + saving_listeners = 'saving_listeners' + else: + eval_metrics = 'eval_metrics' + saving_listeners = None noop_hook = _NoopHook() export_strategy = saved_model_export_utils.make_export_strategy( est, @@ -690,7 +726,8 @@ class ExperimentTest(test.TestCase): eval_hooks=[noop_hook], train_steps=100, eval_steps=100, - export_strategies=export_strategy) + export_strategies=export_strategy, + saving_listeners=saving_listeners) ex.continuous_train_and_eval() self.assertEqual(1, est.fit_count) self.assertEqual(1, est.eval_count) @@ -742,9 +779,10 @@ class ExperimentTest(test.TestCase): ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn) mock_estimator.train.assert_called_once_with( input_fn='train_input', - steps=int(total_steps/10), + steps=int(total_steps / 10), max_steps=test.mock.ANY, - hooks=test.mock.ANY) + hooks=test.mock.ANY, + saving_listeners=test.mock.ANY) def test_continuous_train_and_eval_with_steps_per_iteration_from_user(self): mock_estimator = test.mock.Mock(core_estimator.Estimator) @@ -768,7 
+806,8 @@ class ExperimentTest(test.TestCase): input_fn='train_input', steps=1234, max_steps=test.mock.ANY, - hooks=test.mock.ANY) + hooks=test.mock.ANY, + saving_listeners=test.mock.ANY) def test_continuous_train_and_eval_with_default_steps_per_iteration(self): mock_estimator = test.mock.Mock(core_estimator.Estimator) @@ -791,7 +830,8 @@ class ExperimentTest(test.TestCase): input_fn='train_input', steps=1000, max_steps=test.mock.ANY, - hooks=test.mock.ANY) + hooks=test.mock.ANY, + saving_listeners=test.mock.ANY) def test_continuous_train_and_eval_with_invalid_predicate_fn(self): for est in self._estimators_for_tests(): @@ -857,11 +897,19 @@ class ExperimentTest(test.TestCase): est, None if isinstance(est, core_estimator.Estimator) else 'export_input', exports_to_keep=None) + if isinstance(est, core_estimator.Estimator): + eval_metrics = None + saving_listeners = 'saving_listeners' + else: + eval_metrics = 'eval_metrics' + saving_listeners = None ex = experiment.Experiment( est, train_input_fn='train_input', eval_input_fn='eval_input', - export_strategies=(exp_strategy,)) + export_strategies=(exp_strategy,), + eval_metrics=eval_metrics, + saving_listeners=saving_listeners) ex.test() self.assertEqual(1, est.fit_count) self.assertEqual(1, est.eval_count) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index db18ebf05d5fb98e28e767be7bcccdf992a56fd8..86fad4c5535a918d87e0741687cfebe3afaf9ddf 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -28,7 +28,6 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging @@ -369,10 
+368,11 @@ class DataFeeder(object): if x_is_dict: num_samples = list(self._x.values())[0].shape[0] elif tensor_util.is_tensor(self._x): - num_samples = self._x.shape[0].value # shape will be a Dimension, extract an int + num_samples = self._x.shape[ + 0].value # shape will be a Dimension, extract an int else: num_samples = self._x.shape[0] - + if self._shuffle: self.indices = self.random_state.permutation(num_samples) else: diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 4b34fc62849766370979bb2002d42ee03ea7161a..3a46c239688017f9204d2c6182a6f81cd325a417 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import io_ops @@ -280,14 +281,33 @@ def _get_file_names(file_pattern, randomize_input): def _get_examples(file_name_queue, reader, num_threads, read_batch_size, filter_fn, parse_fn): + """Get example filenames matching. + + Args: + file_name_queue: A queue implementation that dequeues elements in + first-in first-out order. + reader: A function or class that returns an object with + `read` method, (filename tensor) -> (example tensor). + num_threads: The number of threads enqueuing examples. + read_batch_size: An int or scalar `Tensor` specifying the number of + records to read at once. + filter_fn: Filtering function, takes both keys as well as an `Example` + Tensors and returns a boolean mask of the same shape as the input Tensors + to be applied for filtering. If `None`, no filtering is done. 
+ parse_fn: Parsing function, takes `Example` Tensor returns parsed + representation. If `None`, no parsing is done. + + Returns: + List of example file names matching `file_name_queue`. + """ with ops.name_scope('read'): example_list = [] for _ in range(num_threads): - if read_batch_size > 1: - keys, examples_proto = reader().read_up_to(file_name_queue, - read_batch_size) - else: - keys, examples_proto = reader().read(file_name_queue) + keys, examples_proto = utils.smart_cond( + read_batch_size > 1, + lambda: reader().read_up_to(file_name_queue, read_batch_size), + lambda: reader().read(file_name_queue)) + if filter_fn: mask = filter_fn(keys, examples_proto) keys = array_ops.boolean_mask(keys, mask) @@ -379,14 +399,15 @@ def _read_keyed_batch_examples_helper(file_pattern, capacity=1, dtypes=[dtypes.string], shapes=[[]]) enqueue_op = file_name_queue.enqueue( input_pipeline_ops.seek_next( - file_names, shuffle=randomize_input, num_epochs=num_epochs, + file_names, + shuffle=randomize_input, + num_epochs=num_epochs, seed=seed)) queue_runner.add_queue_runner( queue_runner.QueueRunner(file_name_queue, [enqueue_op])) else: file_name_queue = input_ops.string_input_producer( - constant_op.constant( - file_names, name='input'), + constant_op.constant(file_names, name='input'), shuffle=randomize_input, num_epochs=num_epochs, name=file_name_queue_scope, @@ -496,7 +517,8 @@ def read_keyed_batch_features(file_pattern, """ with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope: - if read_batch_size is None: read_batch_size = batch_size + if read_batch_size is None: + read_batch_size = batch_size keys, examples = read_keyed_batch_examples( file_pattern, batch_size, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py index 6f0fd9a2976d37d1c701a96f50c2b987562cb191..e11e8b698adc113486bbb45572c8129e964cc931 100644 --- 
a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py @@ -204,8 +204,7 @@ class GraphIOTest(test.TestCase): shape = (0,) features = { "feature": - parsing_ops.FixedLenFeature( - shape=shape, dtype=dtypes_lib.float32) + parsing_ops.FixedLenFeature(shape=shape, dtype=dtypes_lib.float32) } with ops.Graph().as_default() as g, self.test_session(graph=g) as sess: @@ -255,8 +254,8 @@ class GraphIOTest(test.TestCase): self.assertAllEqual((None,), inputs.get_shape().as_list()) self.assertEqual("%s:1" % name, inputs.name) file_name_queue_name = "%s/file_name_queue" % name - file_name_queue_limit_name = ("%s/limit_epochs/epochs" % - file_name_queue_name) + file_name_queue_limit_name = ( + "%s/limit_epochs/epochs" % file_name_queue_name) file_names_name = "%s/input" % file_name_queue_name example_queue_name = "%s/random_shuffle_queue" % name op_nodes = test_util.assert_ops_in_graph({ @@ -354,8 +353,8 @@ class GraphIOTest(test.TestCase): json_lines = [ "".join([ '{"features": { "feature": { "sequence": {', - '"bytes_list": { "value": ["', base64.b64encode(l).decode("ascii"), - '"]}}}}}\n' + '"bytes_list": { "value": ["', + base64.b64encode(l).decode("ascii"), '"]}}}}}\n' ]) for l in lines ] return self._create_temp_file("".join(json_lines)) @@ -823,6 +822,31 @@ class GraphIOTest(test.TestCase): coord.request_stop() coord.join(threads) + def test_read_keyed_batch_features_shared_queue(self): + batch_size = 17 + shape = (0,) + fixed_feature = parsing_ops.FixedLenFeature( + shape=shape, dtype=dtypes_lib.float32) + feature = {"feature": fixed_feature} + reader = io_ops.TFRecordReader + + _, queued_feature = graph_io.read_keyed_batch_features_shared_queue( + _VALID_FILE_PATTERN, batch_size, feature, reader) + + with ops.Graph().as_default() as g, self.test_session(graph=g) as session: + features_result = graph_io.read_batch_features( + _VALID_FILE_PATTERN, batch_size, feature, reader) + 
session.run(variables.local_variables_initializer()) + + self.assertAllEqual( + queued_feature.get("feature").get_shape().as_list(), + features_result.get("feature").get_shape().as_list()) + + def test_get_file_names_errors(self): + # Raise bad file_pattern. + with self.assertRaises(ValueError): + graph_io._get_file_names([], True) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/learn/python/learn/metric_spec.py b/tensorflow/contrib/learn/python/learn/metric_spec.py index ed6683abedbb8ae76ba364405158eb52cbb6d762..6440bc204b8e339ff51311dcc87b36f556b94092 100644 --- a/tensorflow/contrib/learn/python/learn/metric_spec.py +++ b/tensorflow/contrib/learn/python/learn/metric_spec.py @@ -42,10 +42,8 @@ def _args(fn): """ if hasattr(fn, 'func') and hasattr(fn, 'keywords'): # Handle functools.partial and similar objects. - return tuple([ - arg for arg in tf_inspect.getargspec(fn.func).args - if arg not in set(fn.keywords.keys()) - ]) + return tuple( + [arg for arg in _args(fn.func) if arg not in set(fn.keywords.keys())]) # Handle function. 
return tuple(tf_inspect.getargspec(fn).args) diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index 6af2287761299f6725f9547917101c18b0cc0164..cb34cb1d26b6812c7f3f39e9f965615de5a8ef07 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -20,7 +20,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework import deprecated -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.session_bundle import exporter from tensorflow.contrib.session_bundle import gc from tensorflow.python.client import session as tf_session @@ -78,7 +78,7 @@ def _export_graph(graph, saver, checkpoint_path, export_dir, default_graph_signature=default_graph_signature, named_graph_signatures=named_graph_signatures, assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)) - return export.export(export_dir, contrib_variables.get_global_step(), + return export.export(export_dir, training_util.get_global_step(), session, exports_to_keep=exports_to_keep) @@ -295,7 +295,7 @@ def _export_estimator(estimator, checkpoint_path = (checkpoint_path or tf_saver.latest_checkpoint(estimator._model_dir)) with ops.Graph().as_default() as g: - contrib_variables.create_global_step(g) + training_util.create_global_step(g) if use_deprecated_input_fn: examples = array_ops.placeholder(dtype=dtypes.string, diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 49413092a6bae547ddd2cad272b1abb3af1de046..6ffd2a133995a6ff8b35540221fb5676bf5de19f 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ 
b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -33,6 +33,7 @@ from __future__ import division from __future__ import print_function import os +import tempfile import time from tensorflow.contrib.layers.python.layers import feature_column @@ -644,18 +645,22 @@ def make_best_model_export_strategy(serving_input_fn, # TODO(b/67013778): Revisit this approach when corresponding changes to # TF Core are finalized. -def extend_export_strategy(base_export_strategy, post_export_fn, - post_export_name): +def extend_export_strategy(base_export_strategy, + post_export_fn, + post_export_name=None): """Extend ExportStrategy, calling post_export_fn after export. Args: base_export_strategy: An ExportStrategy that can be passed to the Experiment constructor. post_export_fn: A user-specified function to call after exporting the - SavedModel. Takes the export directory as an argument, and returns - a string path to a (potentially different) SavedModel. + SavedModel. Takes two arguments - the path to the SavedModel exported by + base_export_strategy and the directory where to export the SavedModel + modified by the post_export_fn. Returns the path to the exported + SavedModel. post_export_name: The directory name under the export base directory where - SavedModels generated by the post_export_fn will be written. + SavedModels generated by the post_export_fn will be written. If None, the + directory name of base_export_strategy is used. Returns: An ExportStrategy that can be passed to the Experiment constructor. @@ -675,12 +680,24 @@ def extend_export_strategy(base_export_strategy, post_export_fn, Raises: ValueError: If `estimator` is a ${tf.estimator.Estimator} instance - and `default_output_alternative_key` was specified. + and `default_output_alternative_key` was specified or if post_export_fn + does not return a valid directory. 
""" - export_dir = base_export_strategy.export(estimator, export_dir_base, - checkpoint_path) - if post_export_fn: - export_dir = post_export_fn(export_dir) - return export_dir - - return export_strategy.ExportStrategy(post_export_name, export_fn) + tmp_base_export_dir = tempfile.mkdtemp() + tmp_base_export = base_export_strategy.export( + estimator, tmp_base_export_dir, checkpoint_path) + tmp_post_export_dir = tempfile.mkdtemp() + tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir) + + if not tmp_post_export.startswith(tmp_post_export_dir): + raise ValueError('post_export_fn must return a sub-directory of {}' + .format(tmp_post_export_dir)) + export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir) + + gfile.Rename( + os.path.join(tmp_post_export_dir, export_relpath), + os.path.join(export_dir_base, export_relpath)) + return os.path.join(export_dir_base, export_relpath) + + name = post_export_name if post_export_name else base_export_strategy.name + return export_strategy.ExportStrategy(name, export_fn) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py index 27f17b54221ea442baafb382aa3fb034d1bb82e6..ec3a88003f01b3b62591c13472029601b11ba491 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -743,12 +743,19 @@ class SavedModelExportUtilsTest(test.TestCase): None) def test_extend_export_strategy(self): - def _base_export_fn(unused_estimator, export_dir_base, + + def _base_export_fn(unused_estimator, + export_dir_base, unused_checkpoint_path=None): - return export_dir_base + "/e1" + base_path = os.path.join(export_dir_base, "e1") + gfile.MkDir(base_path) + return base_path - def _post_export_fn(orig_path): - return orig_path + "/rewrite" + def _post_export_fn(orig_path, new_path): + assert 
orig_path.endswith("/e1") + post_export_path = os.path.join(new_path, "rewrite") + gfile.MkDir(post_export_path) + return post_export_path base_export_strategy = export_strategy_lib.ExportStrategy( "Servo", _base_export_fn) @@ -758,9 +765,67 @@ class SavedModelExportUtilsTest(test.TestCase): self.assertEqual(final_export_strategy.name, "Servo2") test_estimator = TestEstimator() - final_path = final_export_strategy.export(test_estimator, "/path/to/orig", - "/path/to/checkpoint") - self.assertEqual("/path/to/orig/e1/rewrite", final_path) + tmpdir = tempfile.mkdtemp() + final_path = final_export_strategy.export(test_estimator, tmpdir, + os.path.join( + tmpdir, "checkpoint")) + self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + + def test_extend_export_strategy_same_name(self): + + def _base_export_fn(unused_estimator, + export_dir_base, + unused_checkpoint_path=None): + base_path = os.path.join(export_dir_base, "e1") + gfile.MkDir(base_path) + return base_path + + def _post_export_fn(orig_path, new_path): + assert orig_path.endswith("/e1") + post_export_path = os.path.join(new_path, "rewrite") + gfile.MkDir(post_export_path) + return post_export_path + + base_export_strategy = export_strategy_lib.ExportStrategy( + "Servo", _base_export_fn) + + final_export_strategy = saved_model_export_utils.extend_export_strategy( + base_export_strategy, _post_export_fn) + self.assertEqual(final_export_strategy.name, "Servo") + + test_estimator = TestEstimator() + tmpdir = tempfile.mkdtemp() + final_path = final_export_strategy.export(test_estimator, tmpdir, + os.path.join( + tmpdir, "checkpoint")) + self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + + def test_extend_export_strategy_raises_error(self): + + def _base_export_fn(unused_estimator, + export_dir_base, + unused_checkpoint_path=None): + base_path = os.path.join(export_dir_base, "e1") + gfile.MkDir(base_path) + return base_path + + def _post_export_fn(unused_orig_path, unused_new_path): + return 
tempfile.mkdtemp() + + base_export_strategy = export_strategy_lib.ExportStrategy( + "Servo", _base_export_fn) + + final_export_strategy = saved_model_export_utils.extend_export_strategy( + base_export_strategy, _post_export_fn) + + test_estimator = TestEstimator() + tmpdir = tempfile.mkdtemp() + with self.assertRaises(ValueError) as ve: + final_export_strategy.export(test_estimator, tmpdir, + os.path.join(tmpdir, "checkpoint")) + + self.assertTrue( + "post_export_fn must return a sub-directory" in str(ve.exception)) def _create_test_export_dir(export_dir_base): diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index 86d848439191aeb0dfa88bbe0fb9b3b654499423..7526f3ae0dbdb3d6827e9d7f690090b8438e4f6e 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -251,8 +251,9 @@ class SdcaModel(object): result_dense = 0.0 for i in range(len(dense_variables)): - result_dense += math_ops.matmul( - dense_features[i], array_ops.expand_dims(dense_variables[i], -1)) + result_dense += math_ops.matmul(dense_features[i], + array_ops.expand_dims( + dense_variables[i], -1)) # Reshaping to allow shape inference at graph construction time. 
return array_ops.reshape(result_dense, [-1]) + result_sparse diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index 701fc1c0597d1de0b0189e86feafbd1c5bbdc818..05794a42c5f2d0eece6adab36fb5610078cece31 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import prediction_key @@ -154,7 +154,7 @@ def sdca_model_fn(features, labels, mode, params, config=None): _add_bias_column(feature_columns, features, bias, columns_to_variables) def _train_op_fn(unused_loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() sdca_model, train_op = optimizer.get_train_step( columns_to_variables, weight_column_name, loss_type, features, labels, global_step) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..3f1b0be1a73a3ff1da3452f4ee1a9125f9e26178 --- /dev/null +++ b/tensorflow/contrib/lite/BUILD @@ -0,0 +1,204 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") + +exports_files(glob([ + "testdata/*.bin", + "models/testdata/*", +])) + +config_setting( + name = "mips", + values = { + "cpu": "mips", + }, +) + +config_setting( + name = "mips64", + values = { + "cpu": "mips64", + }, +) 
+ +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +cc_library( + name = "schema_fbs_version", + hdrs = ["version.h"], +) + +# Main library. No ops are included here. +# TODO(aselle): Resolve problems preventing C99 usage. +cc_library( + name = "context", + srcs = ["context.c"], + hdrs = ["context.h"], +) + +cc_library( + name = "builtin_op_data", + hdrs = [ + "builtin_op_data.h", + ], +) + +cc_library( + name = "string", + hdrs = [ + "string.h", + ], + deps = [ + "//tensorflow/core:lib_platform", + ], +) + +# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts. +cc_library( + name = "framework", + srcs = [ + "allocation.cc", + "error_reporter.cc", + "interpreter.cc", + "model.cc", + "nnapi_delegate.cc", + "optional_debug_tools.cc", + "simple_memory_arena.cc", + ], + hdrs = [ + "allocation.h", + "context.h", + "error_reporter.h", + "interpreter.h", + "model.h", + "nnapi_delegate.h", + "optional_debug_tools.h", + "simple_memory_arena.h", + ], + copts = tflite_copts(), + deps = [ + ":builtin_op_data", + ":context", + ":schema_fbs_version", + "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/nnapi:nnapi_lib", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/core:lib_platform", + ], +) + +cc_library( + name = "string_util", + srcs = ["string_util.cc"], + hdrs = ["string_util.h"], + deps = [ + ":framework", + ":string", + ], +) + +cc_test( + name = "string_util_test", + size = "small", + srcs = ["string_util_test.cc"], + deps = [ + ":framework", + ":string_util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +# Test main interpreter +cc_test( + name = "interpreter_test", + size = "small", + srcs = ["interpreter_test.cc"], + deps = [ + ":framework", + ":string_util", + "@com_google_googletest//:gtest", + ], +) + +# Test arena allocator +cc_test( + name = "simple_memory_arena_test", + size = "small", + srcs = ["simple_memory_arena_test.cc"], + 
deps = [ + ":framework", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +# Test model framework. +cc_test( + name = "model_test", + size = "small", + srcs = ["model_test.cc"], + data = [ + "testdata/0_subgraphs.bin", + "testdata/2_subgraphs.bin", + "testdata/empty_model.bin", + "testdata/test_model.bin", + "testdata/test_model_broken.bin", + ], + deps = [ + ":framework", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +# Test the C extension API code. +cc_test( + name = "context_test", + size = "small", + srcs = ["context_test.cc"], + deps = [ + ":framework", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +# Test the serialization of a model with optional tensors. + +# Model tests + +cc_library( + name = "models_test_utils", + testonly = 1, + hdrs = ["models/test_utils.h"], + deps = select({ + "//tensorflow:android": [], + "//conditions:default": [ + "@com_google_absl//absl/strings", + "//tensorflow/core:test", + ], + }), +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + "downloads", + "examples", + "gen", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..78402727abdd2742ffff54bf59ca076d8b97b042 --- /dev/null +++ b/tensorflow/contrib/lite/Makefile @@ -0,0 +1,147 @@ + +# Find where we're running from, so we can store generated files here. 
+ifeq ($(origin MAKEFILE_DIR), undefined) + MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) +endif + +# Try to figure out the host system +HOST_OS := +ifeq ($(OS),Windows_NT) + HOST_OS := WINDOWS +else + UNAME_S := $(shell uname -s) + ifeq ($(UNAME_S),Linux) + HOST_OS := LINUX + endif + ifeq ($(UNAME_S),Darwin) + HOST_OS := OSX + endif +endif + +ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) + +# Where compiled objects are stored. +OBJDIR := $(MAKEFILE_DIR)/gen/obj/ +BINDIR := $(MAKEFILE_DIR)/gen/bin/ +LIBDIR := $(MAKEFILE_DIR)/gen/lib/ +GENDIR := $(MAKEFILE_DIR)/gen/obj/ + +# Settings for the host compiler. +CXX := $(CC_PREFIX) gcc +CXXFLAGS := --std=c++11 -O3 -DNDEBUG +CC := $(CC_PREFIX) gcc +CFLAGS := +LDOPTS := +LDOPTS += -L/usr/local/lib +ARFLAGS := -r + +INCLUDES := \ +-I. \ +-I$(MAKEFILE_DIR)/../../../ \ +-I$(MAKEFILE_DIR)/downloads/ \ +-I$(MAKEFILE_DIR)/downloads/eigen \ +-I$(MAKEFILE_DIR)/downloads/gemmlowp \ +-I$(MAKEFILE_DIR)/downloads/neon_2_sse \ +-I$(MAKEFILE_DIR)/downloads/farmhash/src \ +-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ +-I$(GENDIR) +# This is at the end so any globally-installed frameworks like protobuf don't +# override local versions in the source tree. +INCLUDES += -I/usr/local/include + +LIBS := \ +-lstdc++ \ +-lpthread \ +-lm \ +-lz + +# If the host detected above (HOST_OS) is Linux, also link in the dl library. +ifeq ($(HOST_OS),LINUX) + LIBS += -ldl -lpthread +endif + +include $(MAKEFILE_DIR)/ios_makefile.inc + +# This library is the main target for this makefile. It will contain a minimal +# runtime that can be linked in to other programs. +LIB_NAME := libtensorflow-lite.a +LIB_PATH := $(LIBDIR)$(LIB_NAME) + +# A small example program that shows how to link against the library. 
+BENCHMARK_PATH := $(BINDIR)benchmark_model + +BENCHMARK_SRCS := \ +tensorflow/contrib/lite/tools/benchmark_model.cc +BENCHMARK_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS)))) + +# What sources we want to compile, must be kept in sync with the main Bazel +# build files. + +CORE_CC_ALL_SRCS := \ +$(wildcard tensorflow/contrib/lite/*.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/*.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \ +$(wildcard tensorflow/contrib/lite/*.c) \ +$(wildcard tensorflow/contrib/lite/kernels/*.c) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \ +$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) +# Remove any duplicates. +CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) +CORE_CC_EXCLUDE_SRCS := \ +$(wildcard tensorflow/contrib/lite/*test.cc) \ +$(wildcard tensorflow/contrib/lite/*/*test.cc) \ +$(wildcard tensorflow/contrib/lite/*/*/*test.cc) \ +$(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \ +$(BENCHMARK_SRCS) +# Filter out all the excluded files. +TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) +# File names of the intermediate files target compilation generates. +TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS)))) +LIB_OBJS := $(TF_LITE_CC_OBJS) + +# For normal manually-created TensorFlow C++ source files. +$(OBJDIR)%.o: %.cc + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ + +# For normal manually-created TensorFlow C++ source files. 
+$(OBJDIR)%.o: %.c + @mkdir -p $(dir $@) + $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ + +# The target that's compiled if there's no command-line arguments. +all: $(LIB_PATH) $(BENCHMARK_PATH) + +# Gathers together all the objects we've compiled into a single '.a' archive. +$(LIB_PATH): $(LIB_OBJS) + @mkdir -p $(dir $@) + $(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS) + +$(BENCHMARK_PATH): $(BENCHMARK_OBJS) $(LIB_PATH) + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) \ + -o $(BENCHMARK_PATH) $(BENCHMARK_OBJS) \ + $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS) + +# Gets rid of all generated files. +clean: + rm -rf $(MAKEFILE_DIR)/gen + +# Gets rid of target files only, leaving the host alone. Also leaves the lib +# directory untouched deliberately, so we can persist multiple architectures +# across builds for iOS and Android. +cleantarget: + rm -rf $(OBJDIR) + rm -rf $(BINDIR) + +$(DEPDIR)/%.d: ; +.PRECIOUS: $(DEPDIR)/%.d + +-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(TF_CC_SRCS))) diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2fb40070cb25df16d32569ca764c181bf6333506 --- /dev/null +++ b/tensorflow/contrib/lite/README.md @@ -0,0 +1,222 @@ +# TensorFlow Lite +TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration. + +TensorFlow Lite uses many techniques for achieving low latency like optimizing the kernels for specific mobile apps, pre-fused activations, quantized kernels that allow smaller and faster (fixed-point math) models, and in the future, leverage specialized machine learning hardware to get the best possible performance for a particular model on a particular device. 
+ +![image](g3doc/TFLite-Architecture.jpg) +# Getting Started with an Android Demo App + +This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using a quantized Mobilenet model. A device running Android 5.0 ( API 21) or higher is required to run the demo. + +There are 3 ways to get the demo app to your device + - Download the prebuilt binary or + - Use Android Studio to build the application or + - Download the source code for TensorFlow Lite and the demo and build it using bazel + +## Description +In the demo app, inference is done using the TensorFlow Lite Java API. The demo app classifies frames in real-time, displaying the top most probable classifications. It also displays the time taken to detect the object. + +## Downloading the pre-built binary +The fastest path to trying the demo, is to download the pre-built binary +[TfLiteCameraDemo.apk](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) + +Once the apk is installed, click the app icon to start the app. The first-time the app is opened, the app asks for runtime permissions to access the device camera. The demo app opens the back-camera of the device and recognizes the objects in the camera's field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified. + +## Building in Android Studio using TensorFlow Lite AAR from JCenter +The simplest way to compile the demo app, and try out changes to the project code is to use AndroidStudio. + + - Install the latest version of Android Studio 3 as specified [here](https://developer.android.com/studio/index.html). + - Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings). 
+ - Import the `tensorflow/contrib/lite/java/demo` directory as a new Android Studio project. + - Click through installing all the Gradle extensions it requests. + - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) + - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory: + `tensorflow/contrib/lite/java/demo/app/src/main/assets/` + - Build and run the demo app + +## Building TensorFlow Lite and the demo app from source + +### Clone the TensorFlow repo +- git clone + [https://github.com/tensorflow/tensorflow](https://github.com/tensorflow/tensorflow) + +### Install Bazel +If bazel is not installed on your system, install it now by following [these directions](https://bazel.build/versions/master/docs/install.html) + +NOTE: Bazel does not fully support building Android on Windows yet. Full support for Gradle/CMake builds is coming soon, but in the meantime Windows users should download the [prebuilt binary](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) instead. + +### Install Android NDK and SDK +Bazel is the primary build system for TensorFlow. Bazel and the Android NDK and SDK must be installed on your system. + - Install the latest version of Bazel as per the instructions on the [Bazel website](https://bazel.build/versions/master/docs/install.html) + - The Android NDK is required to build the native (C/C++) TensorFlow Lite code. The current recommended version is 14b, which can be found [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads). + - The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android Studio](https://developer.android.com/studio/index.html). 
Build tools API >= 23 is required to build the TF Android demo (though it will run on API >= 21 devices). + - In the root of the TensorFlow repository update the `WORKSPACE` file with the `api_level` and location of the SDK and NDK. If you installed it with AndroidStudio the SDK path can be found in the SDK manager, and the default NDK path is:`{SDK path}/ndk-bundle.` + +``` +android_sdk_repository ( + name = "androidsdk", + api_level = 23, + build_tools_version = "23.0.2", + path = "/home/xxxx/android-sdk-linux/", +) + +android_ndk_repository( + name = "androidndk", + path = "/home/xxxx/android-ndk-r10e/", + api_level = 19, +) +``` + +Additional details on building with Android can be found [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). + +### Build the source code +Run bazel with the following command to build the demo. + +Build the demo app: + +``` +bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo +``` + +### Note + +Currently, we only support building the Android demo app within a Python 2 +environment (due to a Bazel bug). + +### More about the demo +The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch 224 * 224 is the width and height of the image 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. 
The Mobilenet quantized model is bundled within the assets directory of the app. + +# iOS Demo App + +Similar to the Android demo app, there's an iOS camera app that uses exactly the same model (224 * 224 quantized Mobilenet). + +This demo app requires a camera so it doesn't work with simulators. It needs to be run on a real iOS device. Follow the instructions to build and run the demo app: + +1. Follow the Building section [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md#building) to build the universal iOS library for TensorFlow Lite. +1. Install [CocoaPods](https://cocoapods.org/) if it isn't installed yet: `sudo gem install cocoapods`. +1. Run `pod install` in `tensorflow/contrib/lite/examples/ios/camera` to generate the workspace file. +1. Open the project by running `open tflite_camera_example.xcworkspace`, and build the app in Xcode. + +# TensorFlow Lite Quick Start + +## Step 1. Decide which GraphDef to use + Depending on the use case, the developer may choose to use one of the popular + open-sourced models such as InceptionV3 or MobileNets, re-train these models + with their own custom data set or even build their own custom model. + +### Using a pre-trained model + +[MobileNets](https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html) is a family of mobile-first computer vision models for [TensorFlow](https://www.tensorflow.org/) designed to effectively maximize accuracy while being mindful of the restricted resources for an on-device or embedded application. MobileNets are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as [Inception](https://arxiv.org/pdf/1602.07261.pdf), are used. 
Google provides 16 pre-trained [ImageNet](http://www.image-net.org/challenges/LSVRC/) classification checkpoints for MobileNets for use in mobile projects of all sizes. + +[Inception-v3](https://arxiv.org/abs/1512.00567) is an image recognition model which achieves fairly high accuracy in recognizing general objects with 1000 classes, like "Zebra", "Dalmatian", and "Dishwasher". The model extracts general features from input images using a convolutional neural network and classifies them based on those features with fully-connected and softmax layers. + +[On Device Smart Reply](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html) is an on-device model which provides one-touch replies for an incoming text message by suggesting contextually relevant messages. The model is built specifically for memory constrained devices such as watches & phones and it has been successfully used to surface [Smart Replies on Android Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html). Note that this model only works on Android as of now. + +These pre-trained models can be downloaded from [here](g3doc/models.md). + +### Retrain Inception-V3 or MobileNet for a custom data set +The above pre-trained models have been trained on the ImageNet data set, which consists of 1000 predefined classes. A model will need to be re-trained if these classes are not relevant or useful for a given use case. This technique is called transfer learning, which starts with a model that has been already trained on a problem and will then be retrained on a similar problem. Deep learning from scratch can take days, but transfer learning can be done fairly quickly. In order to do this, a developer will need to generate their custom data set labeled with the relevant classes. + +The [TensorFlow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/) codelab walks through this process step-by-step. 
The retraining code supports retraining for both floating point and quantized inference. + + +### Train a custom model +A developer may choose to train a custom model using Tensorflow. TensorFlow documentation has [several tutorials](https://www.tensorflow.org/tutorials/) for building and training models. If the user has written a model using TensorFlow's Slim Framework the first step is to export this to a GraphDef file. This is necessary because Slim does not store the model structure outside the code, so to communicate with other parts of the framework it needs to be exported. Documentation for the export can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#Export). The output of this step will be a .pb file for the custom model. + +TensorFlow Lite currently supports a subset of TensorFlow operators. Please refer to [this document](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for details of supported operators and their usage. This +set will continue to expand in future releases of Tensorflow Lite. + + +## Step 2. Model format conversion + +The model generated in Step 1 is a standard Tensorflow model. After the completion of Step 1 a user should have a standard .pb or .pbtxt GraphDef file. If the application developer is using a pre-trained model (as defined in Step 1 above), they can download a ready to use, already converted model for use from [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/models.md). Models generated using retraining (aka transfer learning) or custom models will need to be converted using the steps mentioned below. + +A prerequisite to converting the model to the Tensorflow Lite format is to freeze the graph. + +Since we employ several formats, the following definitions may be useful: + - GraphDef (.pb) - a protobuf that represents the TensorFlow training and or computation graph. 
This contains operators, tensors, and variables definitions. + + - CheckPoint (.ckpt) - Serialized variables from a TensorFlow graph. Note, this does not contain the graph structure, so alone it cannot typically be interpreted. + + - FrozenGraphDef - a subclass of GraphDef that contains no variables. A GraphDef can be converted to a frozen graphdef by taking a checkpoint and a graphdef and converting every variable into a constant with the value looked up in the checkpoint. + + - SavedModel - A collection of GraphDef and CheckPoint together with a signature that labels input and output arguments to a model. A GraphDef and Checkpoint can be extracted from a saved model. + + - TensorFlow lite model (.lite) - a serialized flatbuffer, containing TensorFlow lite operators and Tensors for the TensorFlow lite interpreter. This is most analogous to TensorFlow frozen GraphDefs. + +### Freeze Graph +To use this .pb GraphDef file within TensorFlow Lite, the application developer will need checkpoints containing trained weight parameters. The .pb contains only the structure of the graph. The process of merging the checkpoint values with the graph structure is known as "freezing" the graph. + +The developer should know where the checkpoints folder is present or checkpoints can also be downloaded for a pre-trained model (Example: Here is a link to the [MobileNets](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md)). + +Graph freezing can be done using the command below (and modifying the arguments appropriately) + +``` +bazel build tensorflow/python/tools:freeze_graph + +bazel-bin/tensorflow/python/tools/freeze_graph\ + --input_graph=/tmp/mobilenet_v1_224.pb \ + --input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \ + --input_binary=true --output_graph=/tmp/frozen_mobilenet_v1_224.pb \ + --output_node_names=MobileNet/Predictions/Reshape_1 +``` + +The user has to first build the freeze_graph script using bazel and then run the script. 
The input_binary flag has to be enabled to ensure that the protobuf is read and written in binary format. The user has to input the .pb and the .ckpt files to freeze the graph. The output_node_names may not be obvious outside of the code that built the model. The easiest way to find them is to visualize the graph, either with +graphviz, or [in tensorboard](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/#3). + +This frozen Graphdef is now ready to be converted to flatbuffer format (.lite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool. + +Here is a sample command line to convert the frozen Graphdef to '.lite' format. The Tensorflow Optimizing Converter supports both float and quantized models; however, different configuration parameters are needed depending on whether a FLOAT or QUANTIZED mode is being used. +(Here is a link to the pb [file](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz)). + +``` +bazel build tensorflow/contrib/lite/toco:toco + +bazel-bin/tensorflow/contrib/lite/toco/toco -- \ + --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \ + --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \ + --output_file=/tmp/mobilenet_v1_1.0_224.lite --inference_type=FLOAT \ + --input_type=FLOAT --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 --input_shapes=1,224,224,3 +``` + +- The input_file argument should point to the frozen GraphDef file that holds the model architecture. +- The output_file argument should point to where the TensorFlow Lite model file should be generated. +- The input_type and inference_type arguments should be set to FLOAT, unless converting a [quantized](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/) model. 
+- Setting the input_arrays, output_arrays and input_shapes arguments is a bit trickier. The easiest way to find these values is to explore the graph in tensorboard. The user should reuse the arguments that were used for specifying the output nodes for inference in the `freeze_graph` step. + +Note, it is also possible to use the Tensorflow Optimizing Converter through protos either from Python or from the command line; see the +documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/python) (the `toco_from_protos` target). A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, + +``` +import tensorflow as tf + +img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) +val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) +out = tf.identity(val, name="out") +with tf.Session() as sess: + tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out]) + open("converteds_model.tflite", "wb").write(tflite_model) + +``` +For detailed instructions on how to use the Tensorflow Optimizing Converter, please see [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md). + +You may refer to the [Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for troubleshooting help. If that doesn't help, please file an [issue](https://github.com/tensorflow/tensorflow/issues). + +## Step 3. Use the TensorFlow Lite model for inference in a mobile app + +After completion of Step 2 the developer should have a .lite model. + +### For Android +Because Android apps need to be written in Java, and core TensorFlow is in C++, a JNI library is provided to interface between the two. 
Its interface is aimed only at inference, so it provides the ability to load a graph, set up inputs, and run the model to calculate particular outputs. The full documentation for the set of methods can be seen [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/). The demo app is also open sourced on [GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). + +The [demo app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app) uses this interface, so it's a good place to look for example usage. You can also download the prebuilt binary [here](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). + +Note that you'd need to follow instructions for installing TensorFlow on Android, setting up bazel and Android Studio outlined [here](https://www.tensorflow.org/mobile/android_build). + +### For iOS +Follow the documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) to integrate a TFLite model into your app. + +## Core ML support + +Core ML is a machine learning framework used across Apple products. In addition to using Tensorflow Lite models directly in their applications, developers have the option to convert their trained Tensorflow models to the [CoreML](https://developer.apple.com/machine-learning/) format for use on Apple devices. For information on how to use the converter please refer to the [Tensorflow-CoreML converter documentation](https://github.com/tf-coreml/tf-coreml). diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b322e027d48f4bf9f90d5b873c449d1ec31cc49 --- /dev/null +++ b/tensorflow/contrib/lite/allocation.cc @@ -0,0 +1,122 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <fcntl.h> +#include <sys/mman.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> +#include <cstddef> +#include <cstdio> +#include <memory> +#include <utility> + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/nnapi_delegate.h" + +namespace tflite { + +MMAPAllocation::MMAPAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) { + mmap_fd_ = open(filename, O_RDONLY); + if (mmap_fd_ == -1) { + error_reporter_->Report("Could not open '%s'.", filename); + return; + } + struct stat sb; + // If fstat fails the size stays 0, so the mmap below fails and is reported. + if (fstat(mmap_fd_, &sb) == 0) buffer_size_bytes_ = sb.st_size; + mmapped_buffer_ = + mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0); + if (mmapped_buffer_ == MAP_FAILED) { + error_reporter_->Report("Mmap of '%s' failed.", filename); + return; + } +} + +MMAPAllocation::~MMAPAllocation() { + if (valid()) { + munmap(const_cast<void*>(mmapped_buffer_), buffer_size_bytes_); + } + if (mmap_fd_ != -1) close(mmap_fd_); +} + +const void* MMAPAllocation::base() const { return mmapped_buffer_; } + +size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; } + +bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; } + +FileCopyAllocation::FileCopyAllocation(const char* filename, + ErrorReporter* error_reporter) + :
Allocation(error_reporter) { + // Open the file with stdio (closed automatically by the unique_ptr deleter); + // its size is read below with fstat on the underlying descriptor. + std::unique_ptr<FILE, decltype(&fclose)> file(fopen(filename, "rb"), fclose); + if (!file) { + error_reporter_->Report("Could not open '%s'.", filename); + return; + } + // TODO(ahentz): Why did you think using fseek here was better for finding + // the size? + struct stat sb; + if (fstat(fileno(file.get()), &sb) != 0) { + error_reporter_->Report("Failed to get file size of '%s'.", filename); + return; + } + buffer_size_bytes_ = sb.st_size; + std::unique_ptr<char[]> buffer(new char[buffer_size_bytes_]); + if (!buffer) { + error_reporter_->Report("Malloc of buffer to hold copy of '%s' failed.", + filename); + return; + } + size_t bytes_read = + fread(buffer.get(), sizeof(char), buffer_size_bytes_, file.get()); + if (bytes_read != buffer_size_bytes_) { + error_reporter_->Report("Read of '%s' failed (too few bytes read).", + filename); + return; + } + copied_buffer_ = std::move(buffer); +} + +FileCopyAllocation::~FileCopyAllocation() {} + +const void* FileCopyAllocation::base() const { return copied_buffer_.get(); } + +size_t FileCopyAllocation::bytes() const { return buffer_size_bytes_; } + +bool FileCopyAllocation::valid() const { return copied_buffer_ != nullptr; } + +MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes, + ErrorReporter* error_reporter) + : Allocation(error_reporter) { + buffer_ = ptr; + buffer_size_bytes_ = num_bytes; +} + +MemoryAllocation::~MemoryAllocation() {} + +const void* MemoryAllocation::base() const { return buffer_; } + +size_t MemoryAllocation::bytes() const { return buffer_size_bytes_; } + +bool MemoryAllocation::valid() const { return buffer_ != nullptr; } + +} // namespace tflite diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h new file mode 100644 index 0000000000000000000000000000000000000000..ee8a7ccd0b232f9e48095567fd4aefe94f595bc3 --- /dev/null +++ 
// A memory allocation handle. This could be a mmap or shared memory.
//
// Abstract interface: concrete subclasses supply the base pointer, the size
// and the validity check for the region they manage.
class Allocation {
 public:
  // Stores `error_reporter` so subclasses can report failures through it.
  // The reporter is not deleted by this class.
  Allocation(ErrorReporter* error_reporter) : error_reporter_(error_reporter) {}
  virtual ~Allocation() {}

  // Base pointer of this allocation
  virtual const void* base() const = 0;
  // Size in bytes of the allocation
  virtual size_t bytes() const = 0;
  // Whether the allocation is valid
  virtual bool valid() const = 0;

 protected:
  // Sink for error messages produced by subclasses.
  ErrorReporter* error_reporter_;
};
// An Allocation backed by a heap buffer that holds a copy of a file's
// contents (no mmap involved).
class FileCopyAllocation : public Allocation {
 public:
  // Reads `filename` into memory. Failures are reported through
  // `error_reporter` and leave the allocation invalid.
  FileCopyAllocation(const char* filename, ErrorReporter* error_reporter);
  virtual ~FileCopyAllocation();
  const void* base() const override;
  size_t bytes() const override;
  bool valid() const override;

 private:
  // Heap copy of the file contents; stays null until the whole file has been
  // read successfully, which is what valid() tests.
  std::unique_ptr copied_buffer_;
  // File size in bytes as obtained while opening the file.
  size_t buffer_size_bytes_ = 0;
};
def tflite_linkopts_unstripped():
  """Returns size-reducing linker flags for TFLite that keep the symbols.

  Handy when investigating the relative size of the symbols in TFLite,
  because the symbol table is left in place.

  Returns:
    a select object with proper linkopts
  """
  android_flags = [
      "-Wl,--no-export-dynamic",  # Only inc syms referenced by dynamic obj.
      "-Wl,--exclude-libs,ALL",  # Exclude syms in all libs from auto export.
      "-Wl,--gc-sections",  # Eliminate unused code and data.
      "-Wl,--as-needed",  # Don't link unused libs.
  ]
  default_flags = [
      "-Wl,--icf=all",  # Identical code folding.
  ]
  return select({
      "//tensorflow:android": android_flags,
      "//tensorflow/contrib/lite:mips": [],
      "//tensorflow/contrib/lite:mips64": [],
      "//conditions:default": default_flags,
  })
def tflite_to_json(name, src, out):
  """Convert a TF Lite flatbuffer to JSON.

  The conversion runs flatc inside a private scratch directory that is
  removed when the command exits, whether it succeeded or not.

  Args:
    name: Name of rule.
    src: name of the input flatbuffer file.
    out: name of the output JSON file.
  """

  flatc = "@flatbuffers//:flatc"
  schema = "//tensorflow/contrib/lite/schema:schema.fbs"
  native.genrule(
      name = name,
      srcs = [schema, src],
      outs = [out],
      # flatc names its output after the input's basename and writes it under
      # the `-o` directory. Pointing `-o` at the scratch directory (rather
      # than a hard-coded /tmp) keeps input and output together even when
      # TMPDIR is not /tmp, and the trap cleans everything up.
      cmd = ("TMP=`mktemp -d` && trap 'rm -rf \"$${TMP}\"' EXIT && " +
             "cp $(location %s) $${TMP}/model.bin && " +
             "$(location %s) --raw-binary --strict-json -t" +
             " -o $${TMP} $(location %s) -- $${TMP}/model.bin && " +
             "cp $${TMP}/model.json $(location %s)")
            % (src, flatc, schema, out),
      tools = [flatc],
  )
def gen_zipped_test_files(name, files):
  """Generate zipped test files by running :generate_examples once per zip.

  For every entry `f` of `files` a genrule named "`name`_`f`.files" produces
  "`name`/`f`"; a filegroup called `name` collects all of those outputs.

  Args:
    name: Name of the filegroup and prefix of the per-zip targets.
    files: A list of zip file basenames.
  """
  toco = "//tensorflow/contrib/lite/toco:toco"
  for zip_name in files:
    native.genrule(
        name = "%s_%s.files" % (name, zip_name),
        cmd = "$(locations :generate_examples) --toco $(locations %s)  --zip_to_output %s $(@D) zipped" % (toco, zip_name),
        outs = ["%s/%s" % (name, zip_name)],
        tools = [
            ":generate_examples",
            toco,
        ],
    )

  native.filegroup(
      name = name,
      srcs = ["%s/%s" % (name, zip_name) for zip_name in files],
  )
+ """ + out = name + "_registration.cc" + tool = "//tensorflow/contrib/lite/tools:generate_op_registrations" + tflite_path = "//tensorflow/contrib/lite" + native.genrule( + name = name, + srcs = [model], + outs = [out], + cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") + % (tool, model, out, tflite_path[2:]), + tools = [tool], + ) diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh new file mode 100755 index 0000000000000000000000000000000000000000..cbc96e6edd4358f6666731caa4c208c77d9c6c54 --- /dev/null +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -0,0 +1,31 @@ +#!/bin/bash -x +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
set -e

# Build the static library once per iOS architecture, in a fixed order, and
# remember where each per-architecture archive ends up.
ios_archs=(x86_64 i386 armv7 armv7s arm64)
arch_libs=()
for arch in "${ios_archs[@]}"; do
  make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH="${arch}" -j 8
  arch_libs+=("tensorflow/contrib/lite/gen/lib/ios_${arch}/libtensorflow-lite.a")
done

# Merge the per-architecture archives into one universal library.
lipo \
"${arch_libs[@]}" \
-create \
-output tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a
// Parameters of a pooling op.
typedef struct {
  // Padding scheme requested for the op (see TfLitePadding).
  TfLitePadding padding;
  // Strides along the width and height dimensions.
  int stride_width;
  int stride_height;
  // Pooling window size along the width and height dimensions.
  int filter_width;
  int filter_height;
  // Activation fused into the op's output.
  TfLiteFusedActivation activation;
  // NOTE(review): presumably filled in by the kernel from `padding` and the
  // tensor shapes rather than read from the model -- confirm against the
  // pooling kernel.
  struct {
    TfLitePaddingValues padding;
  } computed;
} TfLitePoolParams;
// Parameters of a pad op, stored in fixed-size arrays of at most 8 entries.
typedef struct {
  // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
  // For now we will fix the maximum possible number of dimensions.
  // Padding amounts, one slot per dimension.
  // NOTE(review): presumably only the first `num_dimensions` slots are
  // meaningful -- confirm against the code that fills this struct.
  int before_padding[8];
  int after_padding[8];
  // Number of dimensions described by the arrays above; the arrays cannot
  // hold more than 8.
  int num_dimensions;
} TfLitePadParams;
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/context.h" +#include +#include + +TfLiteIntArray* TfLiteIntArrayCreate(int size) { + TfLiteIntArray* ret = + (TfLiteIntArray*)malloc(sizeof(*ret) + sizeof(ret->data[0]) * size); + ret->size = size; + return ret; +} + +void TfLiteIntArrayPrint(const char* s, TfLiteIntArray* a) { + printf("%s: length=%d [", s, a->size); + if (a->size) printf("%d", a->data[0]); + int i = 1; + for (; i < a->size; i++) { + printf(" %d", a->data[i]); + } + printf("]\n"); +} + +int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) { + if (a == b) return 1; + if (a == NULL || b == NULL) return 0; + if (a->size != b->size) return 0; + int i = 0; + for (; i < a->size; i++) + if (a->data[i] != b->data[i]) return 0; + return 1; +} + +TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) { + if (!src) return NULL; + TfLiteIntArray* ret = TfLiteIntArrayCreate(src->size); + if (ret) { + memcpy(ret->data, src->data, src->size * sizeof(int)); + } + return ret; +} + +void TfLiteIntArrayFree(TfLiteIntArray* a) { free(a); } + +void TfLiteTensorFree(TfLiteTensor* t) { + if (t->allocation_type == kTfLiteDynamic && t->data.raw) { + free(t->data.raw); + } + if (t->dims) TfLiteIntArrayFree(t->dims); + t->data.raw = NULL; + t->dims = NULL; +} + +void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, + TfLiteQuantizationParams 
quantization, char* buffer, + size_t size, TfLiteAllocationType allocation_type, + const void* allocation, TfLiteTensor* tensor) { + TfLiteTensorFree(tensor); + tensor->type = type; + tensor->name = name; + tensor->dims = dims; + tensor->params = quantization; + tensor->data.raw = buffer; + tensor->bytes = size; + tensor->allocation_type = allocation_type; + tensor->allocation = allocation; +} + +void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) { + if (tensor->allocation_type != kTfLiteDynamic) { + return; + } + if (!tensor->data.raw) { + tensor->data.raw = malloc(num_bytes); + } else if (num_bytes > tensor->bytes) { + tensor->data.raw = realloc(tensor->data.raw, num_bytes); + } + tensor->bytes = num_bytes; +} diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h new file mode 100644 index 0000000000000000000000000000000000000000..41257a53b145cbe7e252c9d4de6ea7ef654431b5 --- /dev/null +++ b/tensorflow/contrib/lite/context.h @@ -0,0 +1,298 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file defines a C API for implementing operations in tflite. +// These operations can be defined using c++ but the interface between +// the interpreter and the operations are C. 
// Create an array of a given `size` (uninitialized entries).
// This returns a pointer that you must free using TfLiteIntArrayFree().
TfLiteIntArray* TfLiteIntArrayCreate(int size);

// Check if two arrays are equal. Returns 1 if they are equal, 0 otherwise.
// The same pointer passed twice compares equal; otherwise a NULL argument or
// differing sizes compare unequal.
int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);

// Create a copy of an array passed as `src`. Returns NULL if `src` is NULL.
// You are expected to free the copy with TfLiteIntArrayFree().
TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);

// Free memory of array `v`.
void TfLiteIntArrayFree(TfLiteIntArray* v);
// Evaluate `status` exactly once and, if the result is not kTfLiteOk, return
// that result from the current function. `context` is not used by this macro.
// The temporary has a deliberately unusual name so it cannot shadow a
// variable that appears in the `status` expression.
#define TF_LITE_ENSURE_OK(context, status)                     \
  do {                                                         \
    const TfLiteStatus tf_lite_ensure_ok_status_ = (status);   \
    if (tf_lite_ensure_ok_status_ != kTfLiteOk) {              \
      return tf_lite_ensure_ok_status_;                        \
    }                                                          \
  } while (0)
typedef union {
  // Typed views of the same buffer; pick the member that matches the
  // tensor's `type`.
  int* i32;
  float* f;
  // Untyped byte view; this is the member the allocation helpers
  // (TfLiteTensorFree/Reset/Realloc) read and write.
  char* raw;
  // Read-only byte view of the same pointer.
  const char* raw_const;
  uint8_t* uint8;
} TfLitePtrUnion;
// Handle passed to op implementations: exposes the tensor array together
// with callbacks for resizing tensors, reporting errors and adding tensors.
typedef struct TfLiteContext {
  // Number of tensors in the context.
  int tensors_size;
  // An array of tensors in the interpreter context (of length `tensors_size`)
  TfLiteTensor* tensors;

  // opaque full context ptr (an opaque c++ data structure)
  void* impl_;

  // Request memory pointer be resized. Updates dimensions on the tensor.
  // NOTE: ResizeTensor takes ownership of `new_size`.
  TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
                               TfLiteIntArray* new_size);
  // Request that an error be reported with format string msg.
  void (*ReportError)(struct TfLiteContext*, const char* msg, ...);

  // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If
  // non-null, the value pointed to by `first_new_tensor_index` will be set to
  // the index of the first new tensor.
  TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
                             int* first_new_tensor_index);

  // TODO(ahentz): we should create a more general mechanism for this sort of
  // library-global objects.
  void* gemm_context;
} TfLiteContext;
+ TfLiteIntArray* inputs; + + // Outputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* outputs; + + // Temporary tensors uses during the computations. This usually contains no + // tensors, but ops are allowed to change that if they need scratch space of + // any sort. + TfLiteIntArray* temporaries; + + // Opaque data provided by the node implementer through `Registration.init`. + void* user_data; + + // Opaque data provided to the node if the node is a builtin. + void* builtin_data; +} TfLiteNode; + +typedef struct { + // Initializes the op from serialized data. + // If a built-in op: + // `buffer` is the op's params data (TfLiteLSTMParams*). + // `length` is zero. + // If custom op: + // `buffer` is the op's `custom_options`. + // `length` is the size of the buffer. + // + // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer + // or an instance of a struct). + // + // The returned pointer will be stored with the node in the `user_data` field, + // accessible within prepare and invoke functions below. + // NOTE: if the data is already in the desired format, simply implement this + // function to return `nullptr` and implement the free function to be a no-op. + void* (*init)(TfLiteContext* context, const char* buffer, size_t length); + + // The pointer `buffer` is the data previously returned by an init invocation. + void (*free)(TfLiteContext* context, void* buffer); + + // prepare is called when the inputs this node depends on have been resized. + // context->ResizeTensor() can be called to request output tensors to be + // resized. + // + // Returns kTfLiteOk on success. + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); + + // Execute the node (should read node->inputs and output to node->outputs). + // Returns kTfLiteOk on success. + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); + + // Builtin codes. 
If this kernel refers to a builtin this is the code + // of the builtin. This is so we can do marshaling to other frameworks like + // NN API. Note, it is the responsibility of the registration binder to + // set this properly. + int32_t builtin_code; +} TfLiteRegistration; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/context_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..20d6f69a25e9f0bb4323cf5d067b8ebd37bb3c23 --- /dev/null +++ b/tensorflow/contrib/lite/context_test.cc @@ -0,0 +1,75 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/context.h" +#include +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { + +// NOTE: this tests only the TfLiteIntArray part of context. +// most of context.h is provided in the context of using it with interpreter.h +// and interpreter.cc, so interpreter_test.cc tests context structures more +// thoroughly. 
+ +TEST(IntArray, TestIntArrayCreate) { + TfLiteIntArray* a = TfLiteIntArrayCreate(0); + TfLiteIntArray* b = TfLiteIntArrayCreate(3); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); +} + +TEST(IntArray, TestIntArrayCopy) { + TfLiteIntArray* a = TfLiteIntArrayCreate(2); + a->data[0] = 22; + a->data[1] = 24; + TfLiteIntArray* b = TfLiteIntArrayCopy(a); + ASSERT_NE(a, b); + ASSERT_EQ(a->size, b->size); + ASSERT_EQ(a->data[0], b->data[0]); + ASSERT_EQ(a->data[1], b->data[1]); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); +} + +TEST(IntArray, TestIntArrayEqual) { + TfLiteIntArray* a = TfLiteIntArrayCreate(1); + a->data[0] = 1; + TfLiteIntArray* b = TfLiteIntArrayCreate(2); + b->data[0] = 5; + b->data[1] = 6; + TfLiteIntArray* c = TfLiteIntArrayCreate(2); + c->data[0] = 5; + c->data[1] = 6; + TfLiteIntArray* d = TfLiteIntArrayCreate(2); + d->data[0] = 6; + d->data[1] = 6; + ASSERT_FALSE(TfLiteIntArrayEqual(a, b)); + ASSERT_TRUE(TfLiteIntArrayEqual(b, c)); + ASSERT_TRUE(TfLiteIntArrayEqual(b, b)); + ASSERT_FALSE(TfLiteIntArrayEqual(c, d)); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); + TfLiteIntArrayFree(c); + TfLiteIntArrayFree(d); +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh new file mode 100755 index 0000000000000000000000000000000000000000..7fce1ba3461066e6dada95246781440258d844c1 --- /dev/null +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -0,0 +1,106 @@ +#!/bin/bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e + +DOWNLOADS_DIR=tensorflow/contrib/lite/downloads +BZL_FILE_PATH=tensorflow/workspace.bzl + +# Ensure it is being run from repo root +if [ ! -f $BZL_FILE_PATH ]; then + echo "Could not find ${BZL_FILE_PATH}": + echo "Likely you are not running this from the root directory of the repository."; + exit 1; +fi + +EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" +ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" +NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" +FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" +FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/master.zip" +MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_ios_lite_float_2017_11_08.zip" +QUANTIZED_MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip" + +# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, +# so work around it by patching the source. 
+replace_by_sed() { + local regex="${1}" + shift + # Detect the version of sed by the return value of "--version" flag. GNU-sed + # supports "--version" while BSD-sed doesn't. + if ! sed --version >/dev/null 2>&1; then + # BSD-sed. + sed -i '' -e "${regex}" "$@" + else + # GNU-sed. + sed -i -e "${regex}" "$@" + fi +} + +download_and_extract() { + local usage="Usage: download_and_extract URL DIR" + local url="${1:?${usage}}" + local dir="${2:?${usage}}" + echo "downloading ${url}" >&2 + mkdir -p "${dir}" + if [[ "${url}" == *gz ]]; then + curl -Ls "${url}" | tar -C "${dir}" --strip-components=1 -xz + elif [[ "${url}" == *zip ]]; then + tempdir=$(mktemp -d) + tempdir2=$(mktemp -d) + + curl -L ${url} > ${tempdir}/zipped.zip + unzip ${tempdir}/zipped.zip -d ${tempdir2} + + # If the zip file contains nested directories, extract the files from the + # inner directory. + if ls ${tempdir2}/*/* 1> /dev/null 2>&1; then + # unzip has no strip components, so unzip to a temp dir, and move the + # files we want from the tempdir to destination. + cp -R ${tempdir2}/*/* ${dir}/ + else + cp -R ${tempdir2}/* ${dir}/ + fi + rm -rf ${tempdir2} ${tempdir} + fi + + # Delete any potential BUILD files, which would interfere with Bazel builds. 
+ find "${dir}" -type f -name '*BUILD' -delete +} + +download_and_extract "${EIGEN_URL}" "${DOWNLOADS_DIR}/eigen" +download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp" +download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest" +download_and_extract "${ABSL_URL}" "${DOWNLOADS_DIR}/absl" +download_and_extract "${NEON_2_SSE_URL}" "${DOWNLOADS_DIR}/neon_2_sse" +download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" +download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" +download_and_extract "${MODELS_URL}" "${DOWNLOADS_DIR}/models" +download_and_extract "${QUANTIZED_MODELS_URL}" "${DOWNLOADS_DIR}/quantized_models" + +replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ + "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" +replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#static uint32x2_t p2ui_CONJ_XOR;// = vld1_u32( conj_XOR_DATA ); - Removed by scripts#' \ + "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" +replace_by_sed 's#static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DATA );#static uint64x2_t p2ul_CONJ_XOR;// = vld1q_u64( p2ul_conj_XOR_DATA ); - Removed by script#' \ + "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" + +cp ${DOWNLOADS_DIR}/models/models/* tensorflow/contrib/lite/examples/ios/simple/data/ +cp ${DOWNLOADS_DIR}/quantized_models/* tensorflow/contrib/lite/examples/ios/camera/data/ + +echo "download_dependencies.sh completed successfully." >&2 diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/error_reporter.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ba5384a94dbf9de03fb2e4e2f63074525eafa2d --- /dev/null +++ b/tensorflow/contrib/lite/error_reporter.cc @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/error_reporter.h" +#include +#include + +namespace tflite { + +ErrorReporter::~ErrorReporter() {} + +int ErrorReporter::Report(const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +// TODO(aselle): Make the name of ReportError on context the same, so +// we can use the ensure functions w/o a context and w/ a reporter. +int ErrorReporter::ReportError(void*, const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +int StderrReporter::Report(const char* format, va_list args) { + return vfprintf(stderr, format, args); +} + +ErrorReporter* DefaultErrorReporter() { + static StderrReporter* error_reporter = new StderrReporter; + return error_reporter; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h new file mode 100644 index 0000000000000000000000000000000000000000..637d456ce7a754c7da34e551869e49b4efd18e3b --- /dev/null +++ b/tensorflow/contrib/lite/error_reporter.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ + +#include +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// A functor that reports error to supporting system. Invoked similar to +// printf. +// +// Usage: +// ErrorReporter foo; +// foo.Report("test %d\n", 5); +// or +// va_list args; +// foo.Report("test %d\n", args); // where args is va_list +// +// Subclass ErrorReporter to provide another reporting destination. +// For example, if you have a GUI program, you might redirect to a buffer +// that drives a GUI error log box. +class ErrorReporter { + public: + virtual ~ErrorReporter(); + virtual int Report(const char* format, va_list args) = 0; + int Report(const char* format, ...); + int ReportError(void*, const char* format, ...); +}; + +// An error reporter that simply writes the message to stderr. +struct StderrReporter : public ErrorReporter { + int Report(const char* format, va_list args) override; +}; + +// Return the default error reporter (output to stderr). 
+ErrorReporter* DefaultErrorReporter(); + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/examples/ios/camera/.gitignore b/tensorflow/contrib/lite/examples/ios/camera/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9e8962f4c63562dd95896833f563abfbfb578ccc --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/.gitignore @@ -0,0 +1,2 @@ +/data/*.txt +/data/*.tflite diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.h b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.h new file mode 100644 index 0000000000000000000000000000000000000000..55891c3ee18318037fd14fe4160c6f012aeaae66 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import + +@interface CameraExampleAppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow* window; + +@end diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.m b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.m new file mode 100644 index 0000000000000000000000000000000000000000..128266d53f560f3009f6435939ab48ae1c117a3a --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.m @@ -0,0 +1,44 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "CameraExampleAppDelegate.h" + +@implementation CameraExampleAppDelegate + +@synthesize window = _window; + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + [self.window makeKeyAndVisible]; + return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application { + [[UIApplication sharedApplication] setIdleTimerDisabled:NO]; +} + +- (void)applicationDidEnterBackground:(UIApplication *)application { +} + +- (void)applicationWillEnterForeground:(UIApplication *)application { +} + +- (void)applicationDidBecomeActive:(UIApplication *)application { + [[UIApplication sharedApplication] setIdleTimerDisabled:YES]; +} + +- (void)applicationWillTerminate:(UIApplication *)application { +} + +@end diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.h b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.h new file mode 100644 index 0000000000000000000000000000000000000000..fb5800e86d365b56f1b52147c3f9cc8d7211f8c3 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.h @@ -0,0 +1,48 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import +#import + +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" + +@interface CameraExampleViewController + : UIViewController { + IBOutlet UIView* previewView; + AVCaptureVideoPreviewLayer* previewLayer; + AVCaptureVideoDataOutput* videoDataOutput; + dispatch_queue_t videoDataOutputQueue; + UIView* flashView; + BOOL isUsingFrontFacingCamera; + NSMutableDictionary* oldPredictionValues; + NSMutableArray* labelLayers; + AVCaptureSession* session; + + std::vector labels; + std::unique_ptr model; + tflite::ops::builtin::BuiltinOpResolver resolver; + std::unique_ptr interpreter; + + double total_latency; + int total_count; +} +@property(strong, nonatomic) CATextLayer* predictionTextLayer; + +- (IBAction)takePicture:(id)sender; +- (IBAction)switchCameras:(id)sender; + +@end diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm new file mode 100644 index 0000000000000000000000000000000000000000..10f31bb6f17242c9f7f70f0648ec643f99c5ac86 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -0,0 +1,510 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "CameraExampleViewController.h" +#import +#import +#import +#import + +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +#define LOG(x) std::cerr + +// If you have your own model, modify this to the file name, and make sure +// you've added the file to your app resources too. +static NSString* model_file_name = @"mobilenet_quant_v1_224"; +static NSString* model_file_type = @"tflite"; + +// If you have your own model, point this to the labels file. +static NSString* labels_file_name = @"labels"; +static NSString* labels_file_type = @"txt"; + +// These dimensions need to match those the model was trained with. +static const int wanted_input_width = 224; +static const int wanted_input_height = 224; +static const int wanted_input_channels = 3; + +static NSString* FilePathForResourceName(NSString* name, NSString* extension) { + NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; + if (file_path == NULL) { + LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] + << "' in bundle."; + } + return file_path; +} + +static void LoadLabels(NSString* file_name, NSString* file_type, + std::vector* label_strings) { + NSString* labels_path = FilePathForResourceName(file_name, file_type); + if (!labels_path) { + LOG(ERROR) << "Failed to find model proto at" << [file_name UTF8String] + << [file_type UTF8String]; + } + std::ifstream t; + t.open([labels_path UTF8String]); + std::string line; + while (t) { + std::getline(t, line); + label_strings->push_back(line); + } + t.close(); +} + +// Returns the top N confidence values over threshold in the provided vector, +// sorted by confidence in descending order. 
+static void GetTopN(const uint8_t* prediction, const int prediction_size, const int num_results, + const float threshold, std::vector>* top_results) { + // Will contain top N results in ascending order. + std::priority_queue, std::vector>, + std::greater>> + top_result_pq; + + const long count = prediction_size; + for (int i = 0; i < count; ++i) { + const float value = prediction[i] / 255.0; + // Only add it if it beats the threshold and has a chance at being in + // the top N. + if (value < threshold) { + continue; + } + + top_result_pq.push(std::pair(value, i)); + + // If at capacity, kick the smallest value out. + if (top_result_pq.size() > num_results) { + top_result_pq.pop(); + } + } + + // Copy to output vector and reverse into descending order. + while (!top_result_pq.empty()) { + top_results->push_back(top_result_pq.top()); + top_result_pq.pop(); + } + std::reverse(top_results->begin(), top_results->end()); +} + +@interface CameraExampleViewController (InternalMethods) +- (void)setupAVCapture; +- (void)teardownAVCapture; +@end + +@implementation CameraExampleViewController + +- (void)setupAVCapture { + NSError* error = nil; + + session = [AVCaptureSession new]; + if ([[UIDevice currentDevice] userInterfaceIdiom] == UIUserInterfaceIdiomPhone) + [session setSessionPreset:AVCaptureSessionPreset640x480]; + else + [session setSessionPreset:AVCaptureSessionPresetPhoto]; + + AVCaptureDevice* device = [AVCaptureDevice defaultDeviceWithMediaType:AVMediaTypeVideo]; + AVCaptureDeviceInput* deviceInput = + [AVCaptureDeviceInput deviceInputWithDevice:device error:&error]; + + if (error != nil) { + NSLog(@"Failed to initialize AVCaptureDeviceInput. 
Note: This app doesn't work with simulator"); + assert(NO); + } + + if ([session canAddInput:deviceInput]) [session addInput:deviceInput]; + + videoDataOutput = [AVCaptureVideoDataOutput new]; + + NSDictionary* rgbOutputSettings = + [NSDictionary dictionaryWithObject:[NSNumber numberWithInt:kCMPixelFormat_32BGRA] + forKey:(id)kCVPixelBufferPixelFormatTypeKey]; + [videoDataOutput setVideoSettings:rgbOutputSettings]; + [videoDataOutput setAlwaysDiscardsLateVideoFrames:YES]; + videoDataOutputQueue = dispatch_queue_create("VideoDataOutputQueue", DISPATCH_QUEUE_SERIAL); + [videoDataOutput setSampleBufferDelegate:self queue:videoDataOutputQueue]; + + if ([session canAddOutput:videoDataOutput]) [session addOutput:videoDataOutput]; + [[videoDataOutput connectionWithMediaType:AVMediaTypeVideo] setEnabled:YES]; + + previewLayer = [[AVCaptureVideoPreviewLayer alloc] initWithSession:session]; + [previewLayer setBackgroundColor:[[UIColor blackColor] CGColor]]; + [previewLayer setVideoGravity:AVLayerVideoGravityResizeAspect]; + CALayer* rootLayer = [previewView layer]; + [rootLayer setMasksToBounds:YES]; + [previewLayer setFrame:[rootLayer bounds]]; + [rootLayer addSublayer:previewLayer]; + [session startRunning]; + + if (error) { + NSString* title = [NSString stringWithFormat:@"Failed with error %d", (int)[error code]]; + UIAlertController* alertController = + [UIAlertController alertControllerWithTitle:title + message:[error localizedDescription] + preferredStyle:UIAlertControllerStyleAlert]; + UIAlertAction* dismiss = + [UIAlertAction actionWithTitle:@"Dismiss" style:UIAlertActionStyleDefault handler:nil]; + [alertController addAction:dismiss]; + [self presentViewController:alertController animated:YES completion:nil]; + [self teardownAVCapture]; + } +} + +- (void)teardownAVCapture { + [previewLayer removeFromSuperlayer]; +} + +- (AVCaptureVideoOrientation)avOrientationForDeviceOrientation: + (UIDeviceOrientation)deviceOrientation { + AVCaptureVideoOrientation result = 
(AVCaptureVideoOrientation)(deviceOrientation); + if (deviceOrientation == UIDeviceOrientationLandscapeLeft) + result = AVCaptureVideoOrientationLandscapeRight; + else if (deviceOrientation == UIDeviceOrientationLandscapeRight) + result = AVCaptureVideoOrientationLandscapeLeft; + return result; +} + +- (IBAction)takePicture:(id)sender { + if ([session isRunning]) { + [session stopRunning]; + [sender setTitle:@"Continue" forState:UIControlStateNormal]; + + flashView = [[UIView alloc] initWithFrame:[previewView frame]]; + [flashView setBackgroundColor:[UIColor whiteColor]]; + [flashView setAlpha:0.f]; + [[[self view] window] addSubview:flashView]; + + [UIView animateWithDuration:.2f + animations:^{ + [flashView setAlpha:1.f]; + } + completion:^(BOOL finished) { + [UIView animateWithDuration:.2f + animations:^{ + [flashView setAlpha:0.f]; + } + completion:^(BOOL finished) { + [flashView removeFromSuperview]; + flashView = nil; + }]; + }]; + + } else { + [session startRunning]; + [sender setTitle:@"Freeze Frame" forState:UIControlStateNormal]; + } +} + +- (void)captureOutput:(AVCaptureOutput*)captureOutput + didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer + fromConnection:(AVCaptureConnection*)connection { + CVPixelBufferRef pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer); + CFRetain(pixelBuffer); + [self runModelOnFrame:pixelBuffer]; + CFRelease(pixelBuffer); +} + +- (void)runModelOnFrame:(CVPixelBufferRef)pixelBuffer { + assert(pixelBuffer != NULL); + + OSType sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); + int doReverseChannels; + if (kCVPixelFormatType_32ARGB == sourcePixelFormat) { + doReverseChannels = 1; + } else if (kCVPixelFormatType_32BGRA == sourcePixelFormat) { + doReverseChannels = 0; + } else { + assert(false); // Unknown source format + } + + const int sourceRowBytes = (int)CVPixelBufferGetBytesPerRow(pixelBuffer); + const int image_width = (int)CVPixelBufferGetWidth(pixelBuffer); + const int fullHeight = 
(int)CVPixelBufferGetHeight(pixelBuffer); + + CVPixelBufferLockFlags unlockFlags = kNilOptions; + CVPixelBufferLockBaseAddress(pixelBuffer, unlockFlags); + + unsigned char* sourceBaseAddr = (unsigned char*)(CVPixelBufferGetBaseAddress(pixelBuffer)); + int image_height; + unsigned char* sourceStartAddr; + if (fullHeight <= image_width) { + image_height = fullHeight; + sourceStartAddr = sourceBaseAddr; + } else { + image_height = image_width; + const int marginY = ((fullHeight - image_width) / 2); + sourceStartAddr = (sourceBaseAddr + (marginY * sourceRowBytes)); + } + const int image_channels = 4; + assert(image_channels >= wanted_input_channels); + uint8_t* in = sourceStartAddr; + + int input = interpreter->inputs()[0]; + + uint8_t* out = interpreter->typed_tensor(input); + for (int y = 0; y < wanted_input_height; ++y) { + uint8_t* out_row = out + (y * wanted_input_width * wanted_input_channels); + for (int x = 0; x < wanted_input_width; ++x) { + const int in_x = (y * image_width) / wanted_input_width; + const int in_y = (x * image_height) / wanted_input_height; + uint8_t* in_pixel = in + (in_y * image_width * image_channels) + (in_x * image_channels); + uint8_t* out_pixel = out_row + (x * wanted_input_channels); + for (int c = 0; c < wanted_input_channels; ++c) { + out_pixel[c] = in_pixel[c]; + } + } + } + + double startTimestamp = [[NSDate new] timeIntervalSince1970]; + if (interpreter->Invoke() != kTfLiteOk) { + LOG(FATAL) << "Failed to invoke!"; + } + double endTimestamp = [[NSDate new] timeIntervalSince1970]; + total_latency += (endTimestamp - startTimestamp); + total_count += 1; + NSLog(@"Time: %.4lf, avg: %.4lf, count: %d", endTimestamp - startTimestamp, + total_latency / total_count, total_count); + + const int output_size = 1000; + const int kNumResults = 5; + const float kThreshold = 0.1f; + + std::vector> top_results; + + uint8_t* output = interpreter->typed_output_tensor(0); + GetTopN(output, output_size, kNumResults, kThreshold, &top_results); + + 
NSMutableDictionary* newValues = [NSMutableDictionary dictionary]; + for (const auto& result : top_results) { + const float confidence = result.first; + const int index = result.second; + NSString* labelObject = [NSString stringWithUTF8String:labels[index].c_str()]; + NSNumber* valueObject = [NSNumber numberWithFloat:confidence]; + [newValues setObject:valueObject forKey:labelObject]; + } + dispatch_async(dispatch_get_main_queue(), ^(void) { + [self setPredictionValues:newValues]; + }); + + CVPixelBufferUnlockBaseAddress(pixelBuffer, unlockFlags); + + CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); +} + +- (void)dealloc { + [self teardownAVCapture]; +} + +- (void)didReceiveMemoryWarning { + [super didReceiveMemoryWarning]; +} + +- (void)viewDidLoad { + [super viewDidLoad]; + labelLayers = [[NSMutableArray alloc] init]; + oldPredictionValues = [[NSMutableDictionary alloc] init]; + + NSString* graph_path = FilePathForResourceName(model_file_name, @"tflite"); + model = tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]); + if (!model) { + LOG(FATAL) << "Failed to mmap model " << graph_path; + } + LOG(INFO) << "Loaded model " << graph_path; + model->error_reporter(); + LOG(INFO) << "resolved reporter"; + + tflite::ops::builtin::BuiltinOpResolver resolver; + LoadLabels(labels_file_name, labels_file_type, &labels); + + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + LOG(FATAL) << "Failed to construct interpreter"; + } + if (interpreter->AllocateTensors() != kTfLiteOk) { + LOG(FATAL) << "Failed to allocate tensors!"; + } + + [self setupAVCapture]; +} + +- (void)viewDidUnload { + [super viewDidUnload]; +} + +- (void)viewWillAppear:(BOOL)animated { + [super viewWillAppear:animated]; +} + +- (void)viewDidAppear:(BOOL)animated { + [super viewDidAppear:animated]; +} + +- (void)viewWillDisappear:(BOOL)animated { + [super viewWillDisappear:animated]; +} + +- (void)viewDidDisappear:(BOOL)animated { + [super 
viewDidDisappear:animated]; +} + +- (BOOL)shouldAutorotateToInterfaceOrientation:(UIInterfaceOrientation)interfaceOrientation { + return (interfaceOrientation == UIInterfaceOrientationPortrait); +} + +- (BOOL)prefersStatusBarHidden { + return YES; +} + +- (void)setPredictionValues:(NSDictionary*)newValues { + const float decayValue = 0.75f; + const float updateValue = 0.25f; + const float minimumThreshold = 0.01f; + + NSMutableDictionary* decayedPredictionValues = [[NSMutableDictionary alloc] init]; + for (NSString* label in oldPredictionValues) { + NSNumber* oldPredictionValueObject = [oldPredictionValues objectForKey:label]; + const float oldPredictionValue = [oldPredictionValueObject floatValue]; + const float decayedPredictionValue = (oldPredictionValue * decayValue); + if (decayedPredictionValue > minimumThreshold) { + NSNumber* decayedPredictionValueObject = [NSNumber numberWithFloat:decayedPredictionValue]; + [decayedPredictionValues setObject:decayedPredictionValueObject forKey:label]; + } + } + oldPredictionValues = decayedPredictionValues; + + for (NSString* label in newValues) { + NSNumber* newPredictionValueObject = [newValues objectForKey:label]; + NSNumber* oldPredictionValueObject = [oldPredictionValues objectForKey:label]; + if (!oldPredictionValueObject) { + oldPredictionValueObject = [NSNumber numberWithFloat:0.0f]; + } + const float newPredictionValue = [newPredictionValueObject floatValue]; + const float oldPredictionValue = [oldPredictionValueObject floatValue]; + const float updatedPredictionValue = (oldPredictionValue + (newPredictionValue * updateValue)); + NSNumber* updatedPredictionValueObject = [NSNumber numberWithFloat:updatedPredictionValue]; + [oldPredictionValues setObject:updatedPredictionValueObject forKey:label]; + } + NSArray* candidateLabels = [NSMutableArray array]; + for (NSString* label in oldPredictionValues) { + NSNumber* oldPredictionValueObject = [oldPredictionValues objectForKey:label]; + const float oldPredictionValue = 
[oldPredictionValueObject floatValue]; + if (oldPredictionValue > 0.05f) { + NSDictionary* entry = @{@"label" : label, @"value" : oldPredictionValueObject}; + candidateLabels = [candidateLabels arrayByAddingObject:entry]; + } + } + NSSortDescriptor* sort = [NSSortDescriptor sortDescriptorWithKey:@"value" ascending:NO]; + NSArray* sortedLabels = + [candidateLabels sortedArrayUsingDescriptors:[NSArray arrayWithObject:sort]]; + + const float leftMargin = 10.0f; + const float topMargin = 10.0f; + + const float valueWidth = 48.0f; + const float valueHeight = 18.0f; + + const float labelWidth = 246.0f; + const float labelHeight = 18.0f; + + const float labelMarginX = 5.0f; + const float labelMarginY = 5.0f; + + [self removeAllLabelLayers]; + + int labelCount = 0; + for (NSDictionary* entry in sortedLabels) { + NSString* label = [entry objectForKey:@"label"]; + NSNumber* valueObject = [entry objectForKey:@"value"]; + const float value = [valueObject floatValue]; + const float originY = topMargin + ((labelHeight + labelMarginY) * labelCount); + const int valuePercentage = (int)roundf(value * 100.0f); + + const float valueOriginX = leftMargin; + NSString* valueText = [NSString stringWithFormat:@"%d%%", valuePercentage]; + + [self addLabelLayerWithText:valueText + originX:valueOriginX + originY:originY + width:valueWidth + height:valueHeight + alignment:kCAAlignmentRight]; + + const float labelOriginX = (leftMargin + valueWidth + labelMarginX); + + [self addLabelLayerWithText:[label capitalizedString] + originX:labelOriginX + originY:originY + width:labelWidth + height:labelHeight + alignment:kCAAlignmentLeft]; + + labelCount += 1; + if (labelCount > 4) { + break; + } + } +} + +- (void)removeAllLabelLayers { + for (CATextLayer* layer in labelLayers) { + [layer removeFromSuperlayer]; + } + [labelLayers removeAllObjects]; +} + +- (void)addLabelLayerWithText:(NSString*)text + originX:(float)originX + originY:(float)originY + width:(float)width + height:(float)height + 
alignment:(NSString*)alignment { + CFTypeRef font = (CFTypeRef) @"Menlo-Regular"; + const float fontSize = 12.0; + const float marginSizeX = 5.0f; + const float marginSizeY = 2.0f; + + const CGRect backgroundBounds = CGRectMake(originX, originY, width, height); + const CGRect textBounds = CGRectMake((originX + marginSizeX), (originY + marginSizeY), + (width - (marginSizeX * 2)), (height - (marginSizeY * 2))); + + CATextLayer* background = [CATextLayer layer]; + [background setBackgroundColor:[UIColor blackColor].CGColor]; + [background setOpacity:0.5f]; + [background setFrame:backgroundBounds]; + background.cornerRadius = 5.0f; + + [[self.view layer] addSublayer:background]; + [labelLayers addObject:background]; + + CATextLayer* layer = [CATextLayer layer]; + [layer setForegroundColor:[UIColor whiteColor].CGColor]; + [layer setFrame:textBounds]; + [layer setAlignmentMode:alignment]; + [layer setWrapped:YES]; + [layer setFont:font]; + [layer setFontSize:fontSize]; + layer.contentsScale = [[UIScreen mainScreen] scale]; + [layer setString:text]; + + [[self.view layer] addSublayer:layer]; + [labelLayers addObject:layer]; +} + +@end diff --git a/tensorflow/contrib/lite/examples/ios/camera/Info.plist b/tensorflow/contrib/lite/examples/ios/camera/Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..f3d96bab162a707df4df8655354af5a54d1e985e --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/Info.plist @@ -0,0 +1,44 @@ + + + + + CFBundleDevelopmentRegion + en + CFBundleDisplayName + tflite_camera_example + CFBundleExecutable + ${EXECUTABLE_NAME} + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + ${PRODUCT_NAME} + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleSignature + ???? 
+ CFBundleVersion + 1.0 + LSRequiresIPhoneOS + + NSCameraUsageDescription + Capture images to detect object + UIMainStoryboardFile + MainStoryboard_iPhone + UIRequiresFullScreen + + UIStatusBarHidden + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + + + diff --git a/tensorflow/contrib/lite/examples/ios/camera/MainStoryboard_iPhone.storyboard b/tensorflow/contrib/lite/examples/ios/camera/MainStoryboard_iPhone.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..0f10a22e415bd2519e90dd6bfac8b2ad6230caab --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/MainStoryboard_iPhone.storyboard @@ -0,0 +1,46 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile new file mode 100644 index 0000000000000000000000000000000000000000..4ae6fb6b94e4489f63506b05a2f348b7daafd3b7 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile @@ -0,0 +1,5 @@ +platform :ios, '8.0' +inhibit_all_warnings! + +target 'tflite_camera_example' + pod 'TensorFlow-experimental' diff --git a/tensorflow/contrib/lite/examples/ios/camera/data/.gitignore b/tensorflow/contrib/lite/examples/ios/camera/data/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tensorflow/contrib/lite/examples/ios/camera/main.mm b/tensorflow/contrib/lite/examples/ios/camera/main.mm new file mode 100644 index 0000000000000000000000000000000000000000..1a9e542f7c9a5b09be6463437c3a8e4a5afeda6d --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/main.mm @@ -0,0 +1,28 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "CameraExampleAppDelegate.h" + +int main(int argc, char* argv[]) { + int retVal = 0; + + @autoreleasepool { + retVal = + UIApplicationMain(argc, argv, nil, NSStringFromClass([CameraExampleAppDelegate class])); + } + return retVal; +} diff --git a/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..c98183276bd60d2a0ad023ba26aad12572a02786 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj @@ -0,0 +1,419 @@ +// !$*UTF8*$! 
+{ + archiveVersion = 1; + classes = { + }; + objectVersion = 46; + objects = { + +/* Begin PBXBuildFile section */ + 1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */; }; + 1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */; }; + 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */; }; + 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */; }; + 1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */; }; + 1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */; }; + 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 1CDB2D4D1ED3AA35007929E9 /* Info.plist */; }; + 54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */; }; + AC1F82661FBA3CBD0052BA77 /* labels.txt in Resources */ = {isa = PBXBuildFile; fileRef = AC1F82641FBA3CBD0052BA77 /* labels.txt */; }; + AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */; }; + ACA1A4CA1FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = 
PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; }; + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; + 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; + 1C564C0D1ED3A92E00087306 /* tflite_camera_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tflite_camera_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.storyboard; path = MainStoryboard_iPhone.storyboard; sourceTree = ""; }; + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; + 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreMedia.framework; path = System/Library/Frameworks/CoreMedia.framework; sourceTree = SDKROOT; }; + 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = AVFoundation.framework; path = System/Library/Frameworks/AVFoundation.framework; sourceTree = SDKROOT; }; + 1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleAppDelegate.h; sourceTree = ""; }; + 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = 
sourcecode.c.objc; path = CameraExampleAppDelegate.m; sourceTree = ""; }; + 1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleViewController.h; sourceTree = ""; }; + 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = CameraExampleViewController.mm; sourceTree = ""; }; + 1CDB2D4D1ED3AA35007929E9 /* Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + 3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tflite_camera_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; + 3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example.debug.xcconfig"; sourceTree = ""; }; + 55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example.release.xcconfig"; sourceTree = ""; }; + AC1F82641FBA3CBD0052BA77 /* labels.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = labels.txt; sourceTree = ""; }; + AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libtensorflow-lite.a"; path = "../../../gen/lib/libtensorflow-lite.a"; sourceTree = ""; }; + ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */ = {isa = 
PBXFileReference; lastKnownFileType = file; path = mobilenet_quant_v1_224.tflite; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 1C564C0A1ED3A92E00087306 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */, + 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */, + 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */, + 54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 24D7686C331131624F4454A0 /* Frameworks */ = { + isa = PBXGroup; + children = ( + AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */, + 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */, + 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */, + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */, + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */, + 3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */, + ); + name = Frameworks; + sourceTree = ""; + }; + 3E9FC355632FB928EA23BEED /* Pods */ = { + isa = PBXGroup; + children = ( + 3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */, + 55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */, + ); + name = Pods; + sourceTree = ""; + }; + 591157921CF4011C00C31E3A = { + isa = PBXGroup; + children = ( + 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */, + 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */, + 1CDB2D4D1ED3AA35007929E9 /* Info.plist */, + 1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */, + 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */, + 1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */, + 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */, + 
59A3CFF31CF4E68100C4259F /* data */, + 5911579C1CF4011C00C31E3A /* Products */, + 3E9FC355632FB928EA23BEED /* Pods */, + 24D7686C331131624F4454A0 /* Frameworks */, + ); + sourceTree = ""; + }; + 5911579C1CF4011C00C31E3A /* Products */ = { + isa = PBXGroup; + children = ( + 1C564C0D1ED3A92E00087306 /* tflite_camera_example.app */, + ); + name = Products; + sourceTree = ""; + }; + 59A3CFF31CF4E68100C4259F /* data */ = { + isa = PBXGroup; + children = ( + ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */, + AC1F82641FBA3CBD0052BA77 /* labels.txt */, + ); + path = data; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 1C564C0C1ED3A92E00087306 /* tflite_camera_example */ = { + isa = PBXNativeTarget; + buildConfigurationList = 1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tflite_camera_example" */; + buildPhases = ( + 66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */, + 1C564C091ED3A92E00087306 /* Sources */, + 1C564C0A1ED3A92E00087306 /* Frameworks */, + 1C564C0B1ED3A92E00087306 /* Resources */, + 00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */, + 5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = tflite_camera_example; + productName = tflite_camera_example; + productReference = 1C564C0D1ED3A92E00087306 /* tflite_camera_example.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 591157931CF4011C00C31E3A /* Project object */ = { + isa = PBXProject; + attributes = { + LastSwiftUpdateCheck = 0830; + LastUpgradeCheck = 0830; + ORGANIZATIONNAME = Google; + TargetAttributes = { + 1C564C0C1ED3A92E00087306 = { + CreatedOnToolsVersion = 8.3.2; + DevelopmentTeam = EQHXZ8M8AV; + ProvisioningStyle = Automatic; + }; + }; + }; + buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject 
"tflite_camera_example" */; + compatibilityVersion = "Xcode 3.2"; + developmentRegion = English; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 591157921CF4011C00C31E3A; + productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 1C564C0C1ED3A92E00087306 /* tflite_camera_example */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 1C564C0B1ED3A92E00087306 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + ACA1A4CA1FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite in Resources */, + 1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */, + 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */, + AC1F82661FBA3CBD0052BA77 /* labels.txt in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXShellScriptBuildPhase section */ + 00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Embed Pods Frameworks"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example-frameworks.sh\"\n"; + showEnvVarsInLog = 0; + }; + 5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Copy Pods Resources"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example-resources.sh\"\n"; + showEnvVarsInLog = 0; + }; + 66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */ = { + isa = 
PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + "${PODS_PODFILE_DIR_PATH}/Podfile.lock", + "${PODS_ROOT}/Manifest.lock", + ); + name = "[CP] Check Pods Manifest.lock"; + outputPaths = ( + "$(DERIVED_FILE_DIR)/Pods-tflite_camera_example-checkManifestLockResult.txt", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; + showEnvVarsInLog = 0; + }; +/* End PBXShellScriptBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 1C564C091ED3A92E00087306 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */, + 1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */, + 1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + 1C564C361ED3A92E00087306 /* Debug */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = 3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + DEVELOPMENT_TEAM = EQHXZ8M8AV; + INFOPLIST_FILE = Info.plist; + IPHONEOS_DEPLOYMENT_TARGET = 10.3; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example"; + 
PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 3.0; + }; + name = Debug; + }; + 1C564C371ED3A92E00087306 /* Release */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = 55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + DEVELOPMENT_TEAM = EQHXZ8M8AV; + INFOPLIST_FILE = Info.plist; + IPHONEOS_DEPLOYMENT_TARGET = 10.3; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_OPTIMIZATION_LEVEL = "-Owholemodule"; + SWIFT_VERSION = 3.0; + }; + name = Release; + }; + 591157B01CF4011D00C31E3A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + 
GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + HEADER_SEARCH_PATHS = ( + "$(inherited)", + ../../../../../../, + ../../../downloads/flatbuffers/include/, + ../../../downloads/eigen/, + ../../../downloads/, + ); + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + LIBRARY_SEARCH_PATHS = ../../../gen/lib/; + MTL_ENABLE_DEBUG_INFO = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 591157B11CF4011D00C31E3A /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + HEADER_SEARCH_PATHS = ( + "$(inherited)", + ../../../../../../, + ../../../downloads/flatbuffers/include/, + ../../../downloads/eigen/, + ../../../downloads/, + ); + 
IPHONEOS_DEPLOYMENT_TARGET = 8.0; + LIBRARY_SEARCH_PATHS = ../../../gen/lib/; + MTL_ENABLE_DEBUG_INFO = NO; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tflite_camera_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 1C564C361ED3A92E00087306 /* Debug */, + 1C564C371ED3A92E00087306 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tflite_camera_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 591157B01CF4011D00C31E3A /* Debug */, + 591157B11CF4011D00C31E3A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 591157931CF4011C00C31E3A /* Project object */; +} diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h new file mode 100644 index 0000000000000000000000000000000000000000..94046d9728258901091f018fd0d081651145f400 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm new file mode 100644 index 0000000000000000000000000000000000000000..d1215fa0bffd978b4aaadbd8bc13b07723703c9a --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm @@ -0,0 +1,48 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "AppDelegate.h" + +#import "RunModelViewController.h" + +@implementation AppDelegate + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + + UITabBarController *bar = [[UITabBarController alloc] init]; + [bar setViewControllers:@[ [[RunModelViewController alloc] init] ]]; + bar.selectedIndex = 0; + self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]]; + self.window.rootViewController = bar; + [self.window makeKeyAndVisible]; + return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application { +} + +- (void)applicationDidEnterBackground:(UIApplication *)application { +} + +- (void)applicationWillEnterForeground:(UIApplication *)application { +} + +- (void)applicationDidBecomeActive:(UIApplication *)application { +} + +- (void)applicationWillTerminate:(UIApplication *)application { +} + +@end diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/contrib/lite/examples/ios/simple/Podfile new file mode 100644 index 0000000000000000000000000000000000000000..1740ad64573a84fae6de0fcf284eb06afec67e25 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/Podfile @@ -0,0 +1,5 @@ +platform :ios, '8.0' +inhibit_all_warnings! 
+ +target 'tf_simple_example' + pod 'TensorFlow-experimental' diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist b/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..1a3eaa8a2c18d1cd24dfd475d396b00ec4d86c9d --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist @@ -0,0 +1,47 @@ + + + + + CFBundleDevelopmentRegion + en + CFBundleDisplayName + tflite-simple-example + CFBundleExecutable + tf_simple_example + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + ios-app + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleSignature + ???? + CFBundleVersion + 1.0 + LSRequiresIPhoneOS + + UILaunchStoryboardName + RunModelViewController + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + UIInterfaceOrientationPortraitUpsideDown + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + + diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h new file mode 100644 index 0000000000000000000000000000000000000000..a4b358b4eb7f6ba109638405091b798d30bd1768 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h @@ -0,0 +1,24 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface RunModelViewController : UIViewController + +- (IBAction)getUrl:(id)sender; + +@property(weak, nonatomic) IBOutlet UITextView *urlContentTextView; +@property(weak, nonatomic) IBOutlet UITextField *urlTextField; + +@end diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm new file mode 100644 index 0000000000000000000000000000000000000000..0dafb1f61e19f46bb3b17f07c55e09f5813ed560 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm @@ -0,0 +1,221 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "RunModelViewController.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +#include "ios_image_load.h" + +#define LOG(x) std::cerr +#define CHECK(x) \ + if (!(x)) { \ + LOG(ERROR) << #x << "failed"; \ + exit(1); \ + } + +NSString* RunInferenceOnImage(); + +@interface RunModelViewController () +@end + +@implementation RunModelViewController { +} + +- (IBAction)getUrl:(id)sender { + NSString* inference_result = RunInferenceOnImage(); + self.urlContentTextView.text = inference_result; +} + +@end + +// Returns the top N confidence values over threshold in the provided vector, +// sorted by confidence in descending order. +static void GetTopN(const float* prediction, const int prediction_size, const int num_results, + const float threshold, std::vector >* top_results) { + // Will contain top N results in ascending order. + std::priority_queue, std::vector >, + std::greater > > + top_result_pq; + + const long count = prediction_size; + for (int i = 0; i < count; ++i) { + const float value = prediction[i]; + + // Only add it if it beats the threshold and has a chance at being in + // the top N. + if (value < threshold) { + continue; + } + + top_result_pq.push(std::pair(value, i)); + + // If at capacity, kick the smallest value out. + if (top_result_pq.size() > num_results) { + top_result_pq.pop(); + } + } + + // Copy to output vector and reverse into descending order. 
+ while (!top_result_pq.empty()) { + top_results->push_back(top_result_pq.top()); + top_result_pq.pop(); + } + std::reverse(top_results->begin(), top_results->end()); +} + +NSString* FilePathForResourceName(NSString* name, NSString* extension) { + NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; + if (file_path == NULL) { + LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] + << "' in bundle."; + } + return file_path; +} + +NSString* RunInferenceOnImage() { + std::string graph; + const int num_threads = 1; + std::string input_layer_type = "float"; + std::vector sizes = {1, 224, 224, 3}; + + NSString* graph_path = FilePathForResourceName(@"mobilenet_v1_1.0_224", @"tflite"); + + std::unique_ptr model( + tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String])); + if (!model) { + LOG(FATAL) << "Failed to mmap model " << graph; + } + LOG(INFO) << "Loaded model " << graph; + model->error_reporter(); + LOG(INFO) << "resolved reporter"; + +#ifdef TFLITE_CUSTOM_OPS_HEADER + tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); +#else + tflite::ops::builtin::BuiltinOpResolver resolver; +#endif + + std::unique_ptr interpreter; + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + LOG(FATAL) << "Failed to construct interpreter"; + } + + if (num_threads != -1) { + interpreter->SetNumThreads(num_threads); + } + + int input = interpreter->inputs()[0]; + + if (input_layer_type != "string") { + interpreter->ResizeInputTensor(input, sizes); + } + + if (interpreter->AllocateTensors() != kTfLiteOk) { + LOG(FATAL) << "Failed to allocate tensors!"; + } + + // Read the label list + NSString* labels_path = FilePathForResourceName(@"labels", @"txt"); + std::vector label_strings; + std::ifstream t; + t.open([labels_path UTF8String]); + std::string line; + while (t) { + std::getline(t, line); + label_strings.push_back(line); + } + t.close(); + + // Read the Grace 
Hopper image. + NSString* image_path = FilePathForResourceName(@"grace_hopper", @"jpg"); + int image_width; + int image_height; + int image_channels; + std::vector image_data = + LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels); + const int wanted_width = 224; + const int wanted_height = 224; + const int wanted_channels = 3; + const float input_mean = 127.5f; + const float input_std = 127.5f; + assert(image_channels >= wanted_channels); + uint8_t* in = image_data.data(); + float* out = interpreter->typed_tensor(input); + for (int y = 0; y < wanted_height; ++y) { + const int in_y = (y * image_height) / wanted_height; + uint8_t* in_row = in + (in_y * image_width * image_channels); + float* out_row = out + (y * wanted_width * wanted_channels); + for (int x = 0; x < wanted_width; ++x) { + const int in_x = (x * image_width) / wanted_width; + uint8_t* in_pixel = in_row + (in_x * image_channels); + float* out_pixel = out_row + (x * wanted_channels); + for (int c = 0; c < wanted_channels; ++c) { + out_pixel[c] = (in_pixel[c] - input_mean) / input_std; + } + } + } + + if (interpreter->Invoke() != kTfLiteOk) { + LOG(FATAL) << "Failed to invoke!"; + } + + float* output = interpreter->typed_output_tensor(0); + const int output_size = 1000; + const int kNumResults = 5; + const float kThreshold = 0.1f; + std::vector > top_results; + GetTopN(output, output_size, kNumResults, kThreshold, &top_results); + + std::stringstream ss; + ss.precision(3); + for (const auto& result : top_results) { + const float confidence = result.first; + const int index = result.second; + + ss << index << " " << confidence << " "; + + // Write out the result as a string + if (index < label_strings.size()) { + // just for safety: theoretically, the output is under 1000 unless there + // is some numerical issues leading to a wrong prediction. 
+ ss << label_strings[index]; + } else { + ss << "Prediction: " << index; + } + + ss << "\n"; + } + + LOG(INFO) << "Predictions: " << ss.str(); + + std::string predictions = ss.str(); + NSString* result = @""; + result = [NSString stringWithFormat:@"%@ - %s", result, predictions.c_str()]; + + return result; +} diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.xib b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.xib new file mode 100644 index 0000000000000000000000000000000000000000..93f334b9850c6f5f22455b3d14a075c17a7c9171 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.xib @@ -0,0 +1,46 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/examples/ios/simple/data/grace_hopper.jpg b/tensorflow/contrib/lite/examples/ios/simple/data/grace_hopper.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d2a427810f679db537236c5430873a81a62ef412 Binary files /dev/null and b/tensorflow/contrib/lite/examples/ios/simple/data/grace_hopper.jpg differ diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h new file mode 100644 index 0000000000000000000000000000000000000000..98934ce41d349b33d4fc010a39a956e52f3d5721 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h @@ -0,0 +1,23 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ +#define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ + +#include + +std::vector LoadImageFromFile(const char* file_name, int* out_width, + int* out_height, int* out_channels); + +#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm new file mode 100644 index 0000000000000000000000000000000000000000..cb0fe1a7650c572d3745066431f2759daa94ffc9 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm @@ -0,0 +1,82 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ios_image_load.h" + +#include +#include +#include +#include + +#import +#import + +std::vector LoadImageFromFile(const char* file_name, int* out_width, int* out_height, + int* out_channels) { + FILE* file_handle = fopen(file_name, "rb"); + fseek(file_handle, 0, SEEK_END); + const size_t bytes_in_file = ftell(file_handle); + fseek(file_handle, 0, SEEK_SET); + std::vector file_data(bytes_in_file); + fread(file_data.data(), 1, bytes_in_file, file_handle); + fclose(file_handle); + + CFDataRef file_data_ref = + CFDataCreateWithBytesNoCopy(NULL, file_data.data(), bytes_in_file, kCFAllocatorNull); + CGDataProviderRef image_provider = CGDataProviderCreateWithCFData(file_data_ref); + + const char* suffix = strrchr(file_name, '.'); + if (!suffix || suffix == file_name) { + suffix = ""; + } + CGImageRef image; + if (strcasecmp(suffix, ".png") == 0) { + image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault); + } else if ((strcasecmp(suffix, ".jpg") == 0) || (strcasecmp(suffix, ".jpeg") == 0)) { + image = + CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault); + } else { + CFRelease(image_provider); + CFRelease(file_data_ref); + fprintf(stderr, "Unknown suffix for file '%s'\n", file_name); + *out_width = 0; + *out_height = 0; + *out_channels = 0; + return std::vector(); + } + + const int width = (int)CGImageGetWidth(image); + const int height = (int)CGImageGetHeight(image); + const int channels = 4; + CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB(); + const int bytes_per_row = (width * channels); + const int bytes_in_image = (bytes_per_row * height); + std::vector result(bytes_in_image); + const int bits_per_component = 8; + + CGContextRef context = + CGBitmapContextCreate(result.data(), width, height, bits_per_component, bytes_per_row, + color_space, kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); + CGColorSpaceRelease(color_space); + CGContextDrawImage(context, 
CGRectMake(0, 0, width, height), image); + CGContextRelease(context); + CFRelease(image); + CFRelease(image_provider); + CFRelease(file_data_ref); + + *out_width = width; + *out_height = height; + *out_channels = channels; + return result; +} diff --git a/tensorflow/contrib/lite/examples/ios/simple/main.mm b/tensorflow/contrib/lite/examples/ios/simple/main.mm new file mode 100644 index 0000000000000000000000000000000000000000..05cb55ddd7a230593863e64b351f6aac31a1b4d7 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/main.mm @@ -0,0 +1,22 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +int main(int argc, char *argv[]) { + @autoreleasepool { + NSString *delegateClassName = @"AppDelegate"; + return UIApplicationMain(argc, argv, nil, delegateClassName); + } +} diff --git a/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..9277c230b8cce1b5673a50d32d7640d52e2e8f9d --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj @@ -0,0 +1,359 @@ +// !$*UTF8*$! 
+{ + archiveVersion = 1; + classes = { + }; + objectVersion = 46; + objects = { + +/* Begin PBXBuildFile section */ + 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */; }; + 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */; }; + 594C14AE1FB8F9B500EE8BFE /* libtensorflow-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */; }; + 594C14B11FB9037100EE8BFE /* labels.txt in Resources */ = {isa = PBXBuildFile; fileRef = 594C14AF1FB9037100EE8BFE /* labels.txt */; }; + 594C14B21FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */; }; + 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; }; + 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; }; + 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */; }; + 59A3D0091CF4E68100C4259F /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFC1CF4E68100C4259F /* main.mm */; }; + 59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */; }; + 59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */ = {isa = PBXBuildFile; fileRef = 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = 
CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; }; + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; + 5911579B1CF4011C00C31E3A /* tf_simple_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_simple_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libtensorflow-lite.a"; path = "../../../gen/lib/libtensorflow-lite.a"; sourceTree = ""; }; + 594C14AF1FB9037100EE8BFE /* labels.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = labels.txt; sourceTree = ""; }; + 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_v1_1.0_224.tflite; sourceTree = ""; }; + 59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; + 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = ""; }; + 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = ""; }; + 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = ""; }; + 59A3CFFB1CF4E68100C4259F 
/* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = ""; }; + 59A3CFFC1CF4E68100C4259F /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; + 59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "RunModel-Info.plist"; sourceTree = ""; }; + 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = RunModelViewController.h; sourceTree = ""; }; + 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = RunModelViewController.mm; sourceTree = ""; }; + 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = RunModelViewController.xib; sourceTree = ""; }; + 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_simple_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 591157981CF4011C00C31E3A /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 594C14AE1FB8F9B500EE8BFE /* libtensorflow-lite.a in Frameworks */, + 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */, + 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 24D7686C331131624F4454A0 /* Frameworks */ = { + isa = PBXGroup; + children = ( + 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */, + 
1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */, + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */, + 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */, + ); + name = Frameworks; + sourceTree = ""; + }; + 591157921CF4011C00C31E3A = { + isa = PBXGroup; + children = ( + 59A3CFF11CF4E68100C4259F /* AppDelegate.h */, + 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */, + 59A3CFF31CF4E68100C4259F /* data */, + 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */, + 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */, + 59A3CFFC1CF4E68100C4259F /* main.mm */, + 59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */, + 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */, + 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */, + 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */, + 5911579C1CF4011C00C31E3A /* Products */, + 24D7686C331131624F4454A0 /* Frameworks */, + ); + sourceTree = ""; + }; + 5911579C1CF4011C00C31E3A /* Products */ = { + isa = PBXGroup; + children = ( + 5911579B1CF4011C00C31E3A /* tf_simple_example.app */, + ); + name = Products; + sourceTree = ""; + }; + 59A3CFF31CF4E68100C4259F /* data */ = { + isa = PBXGroup; + children = ( + 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */, + 594C14AF1FB9037100EE8BFE /* labels.txt */, + 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */, + ); + path = data; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 5911579A1CF4011C00C31E3A /* tf_simple_example */ = { + isa = PBXNativeTarget; + buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */; + buildPhases = ( + 591157971CF4011C00C31E3A /* Sources */, + 591157981CF4011C00C31E3A /* Frameworks */, + 591157991CF4011C00C31E3A /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = tf_simple_example; + productName = tf_ios_makefile_example; + productReference = 
5911579B1CF4011C00C31E3A /* tf_simple_example.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 591157931CF4011C00C31E3A /* Project object */ = { + isa = PBXProject; + attributes = { + LastUpgradeCheck = 0830; + ORGANIZATIONNAME = Google; + TargetAttributes = { + 5911579A1CF4011C00C31E3A = { + CreatedOnToolsVersion = 7.2; + DevelopmentTeam = EQHXZ8M8AV; + ProvisioningStyle = Manual; + }; + }; + }; + buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "simple" */; + compatibilityVersion = "Xcode 3.2"; + developmentRegion = English; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 591157921CF4011C00C31E3A; + productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 5911579A1CF4011C00C31E3A /* tf_simple_example */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 591157991CF4011C00C31E3A /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */, + 594C14B11FB9037100EE8BFE /* labels.txt in Resources */, + 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */, + 594C14B21FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 591157971CF4011C00C31E3A /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 59A3D0091CF4E68100C4259F /* main.mm in Sources */, + 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */, + 59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */, + 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End 
PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + 591157B01CF4011D00C31E3A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + MTL_ENABLE_DEBUG_INFO = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 591157B11CF4011D00C31E3A /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + 
CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + MTL_ENABLE_DEBUG_INFO = NO; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 591157B31CF4011D00C31E3A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + CLANG_DEBUG_INFORMATION_LEVEL = default; + CODE_SIGN_IDENTITY = "iPhone Developer"; + DEVELOPMENT_TEAM = EQHXZ8M8AV; + ENABLE_BITCODE = NO; + GCC_ENABLE_CPP_EXCEPTIONS = YES; + GCC_ENABLE_CPP_RTTI = YES; + HEADER_SEARCH_PATHS = ( + "$(inherited)", + ../../../../../../, + ../../../downloads/flatbuffers/include/, + ../../../downloads/eigen/, + ../../../downloads/, + ); + INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; + IPHONEOS_DEPLOYMENT_TARGET = 9.2; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + LIBRARY_SEARCH_PATHS = ../../../gen/lib/; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + OTHER_LDFLAGS = "$(inherited)"; + PRODUCT_BUNDLE_IDENTIFIER = "com.google.tflite-simple-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE = "1072bd47-ff19-4e5f-8107-d912748f83f1"; + PROVISIONING_PROFILE_SPECIFIER = "Google Development"; + SEPARATE_STRIP = NO; + }; + name = Debug; + }; + 591157B41CF4011D00C31E3A /* 
Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + CLANG_DEBUG_INFORMATION_LEVEL = default; + CODE_SIGN_IDENTITY = "iPhone Developer"; + DEVELOPMENT_TEAM = ""; + ENABLE_BITCODE = NO; + GCC_ENABLE_CPP_EXCEPTIONS = YES; + GCC_ENABLE_CPP_RTTI = YES; + HEADER_SEARCH_PATHS = ( + "$(inherited)", + ../../../../../../, + ../../../downloads/flatbuffers/include/, + ../../../downloads/eigen/, + ../../../downloads/, + ); + INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; + IPHONEOS_DEPLOYMENT_TARGET = 9.2; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + LIBRARY_SEARCH_PATHS = ../../../gen/lib/; + ONLY_ACTIVE_ARCH = YES; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + OTHER_LDFLAGS = "$(inherited)"; + PRODUCT_BUNDLE_IDENTIFIER = "com.google.tflite-simple-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE_SPECIFIER = ""; + SEPARATE_STRIP = NO; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "simple" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 591157B01CF4011D00C31E3A /* Debug */, + 591157B11CF4011D00C31E3A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 591157B31CF4011D00C31E3A /* Debug */, + 591157B41CF4011D00C31E3A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 591157931CF4011C00C31E3A /* Project object */; +} diff --git a/tensorflow/contrib/lite/g3doc/TFLite-Architecture.jpg b/tensorflow/contrib/lite/g3doc/TFLite-Architecture.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bc83946647c6a923a8a0bd3a041b42e4febe6a31 Binary files 
/dev/null and b/tensorflow/contrib/lite/g3doc/TFLite-Architecture.jpg differ diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md new file mode 100644 index 0000000000000000000000000000000000000000..fe208e47d1ac10995881e55c8596ae14ff4242df --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/apis.md @@ -0,0 +1,359 @@ +# TensorFlow Lite APIs + +TensorFlow Lite provides programming APIs in C++ and Java, and in both cases +the API design reflects a preference for performance over ease of use. +TensorFlow Lite is designed for fast inference on small devices so it should be +no surprise that the APIs try to avoid unnecessary copies at the expense of +convenience. Similarly, consistency with TensorFlow APIs was not an explicit +goal and some variance is to be expected. + +## C++ + +In order to run the inference model in TensorFlow Lite, one has to load the +model into a `FlatBufferModel` object which then can be executed by an +`Interpreter`. The `FlatBufferModel` needs to remain valid for the whole +lifetime of the `Interpreter`, and a single `FlatBufferModel` can be +simultaneously used by more than one `Interpreter`. In concrete terms, the +`FlatBufferModel` object must be created before any `Interpreter` objects that +use it, and must be kept around until they have all been destroyed. + +The simplest usage of TensorFlow Lite will look like this: + +```c++ +std::unique_ptr<tflite::FlatBufferModel> model = tflite::FlatBufferModel::BuildFromFile(path_to_model); +tflite::ops::builtin::BuiltinOpResolver resolver; +std::unique_ptr<tflite::Interpreter> interpreter; +tflite::InterpreterBuilder(*model, resolver)(&interpreter); +// Resize input tensors, if desired. +interpreter->AllocateTensors(); +float* input = interpreter->typed_input_tensor<float>(0); +// Fill `input`. +interpreter->Invoke(); +float* output = interpreter->typed_output_tensor<float>(0); +``` +### Data Alignment + +TensorFlow Lite data is usually aligned to 32-bit boundaries. It is recommended +that all data provided to TensorFlow Lite be aligned that way.
+ +### Error Reporting + +In many places TensorFlow Lite returns status information through +`TfLiteStatus` objects: + +```c++ +typedef enum { + kTfLiteOk = 0, + kTfLiteError = 1 +} TfLiteStatus; + +``` + +Failures can be easily verified with: +```c++ +if (status != kTfLiteOk) { + // ... error handling here ... +} +``` + +In order to obtain detailed error information an ErrorReporter must be +provided: + +```c++ +class ErrorReporter { + virtual int Report(const char* format, va_list args) = 0; +}; +``` + +The `DefaultErrorReporter` takes care of reporting to `stderr`. + +### Loading a Model + +The `FlatBufferModel` class encapsulates a model and can be built in a couple of +slightly different ways depending on where the model is stored: + +```c++ +class FlatBufferModel { +  // Build a model based on a file. Return a nullptr in case of failure. +  static std::unique_ptr BuildFromFile( +      const char* filename, +      ErrorReporter* error_reporter); + +  // Build a model based on a pre-loaded flatbuffer. The caller retains +  // ownership of the buffer and should keep it alive until the returned object +  // is destroyed. Return a nullptr in case of failure. +  static std::unique_ptr BuildFromBuffer( +      const char* buffer, +      size_t buffer_size, +      ErrorReporter* error_reporter); +}; +``` + +Note that if TensorFlow Lite detects the presence of Android's NNAPI it will +automatically try to use shared memory to store the FlatBufferModel. + +### Running a Model + +Running a model involves a few simple steps: + + * Build an `Interpreter` based on an existing `FlatBufferModel` + * Optionally resize input tensors if the predefined sizes are not desired. + * Set input tensor values + * Invoke inference + * Read output tensor values + +The important parts of public interface of the `Interpreter` are provided +below. 
It should be noted that: + + * Tensors are represented by integers, in order to avoid string comparisons + (and any fixed dependency on string libraries). + * An interpreter must not be accessed from concurrent threads + * Memory allocation for input and output tensors must be triggered + by calling AllocateTensors() right after resizing tensors. + +```c++ +class Interpreter { + Interpreter(ErrorReporter* error_reporter); + + // Read only access to list of inputs. + const std::vector& inputs() const; + + // Read only access to list of outputs. + const std::vector& outputs() const; + + // Change the dimensionality of a given tensor. + TfLiteStatus ResizeInputTensor(int tensor_index, + const std::vector& dims); + + // Returns status of success or failure. + TfLiteStatus AllocateTensors(); + + // Return a pointer into the data of a given input tensor. + template + T* typed_input_tensor(int index) { + return typed_tensor(inputs_[index]); + } + + // Return a pointer into the data of a given output tensor. + template + T* typed_output_tensor(int index) { + return typed_tensor(outputs_[index]); + } + + // Execute the model, populating output tensors. + TfLiteStatus Invoke(); +}; +``` + +### Writing Custom Operators + +All TensorFlow Lite operators (both custom and builtin) are defined using a +simple pure-C interface that consists of four functions: + +```c++ +typedef struct { + void* (*init)(TfLiteContext* context, const char* buffer, size_t length); + void (*free)(TfLiteContext* context, void* buffer); + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); +} TfLiteRegistration; +``` + +Refer to `context.h` for details on `TfLiteContext` and `TfLiteNode`. The +former provides error reporting facilities and access to global objects, +including all the tensors. The latter allows implementations to access their +inputs and outputs. 
+ +When the interpreter loads a model, it calls `init()` once for each node in the +graph. A given `init()` will be called more than once if the op is used +multiple times in the graph. For custom ops a configuration buffer will be +provided, containing a flexbuffer that maps parameter names to their values. +The buffer is empty for builtin ops because the interpreter has already parsed +the op parameters. Kernel implementations that require state should initialize +it here and transfer ownership to the caller. For each `init()` call, there +will be a corresponding call to `free()`, allowing implementations to dispose +of the buffer they might have allocated in `init()`. + +Whenever the input tensors are resized the interpreter will go through the +graph notifying implementations of the change. This gives them the chance to +resize their internal buffer, check validity of input shapes and types, and +recalculate output shapes. This is all done through `prepare()` and +implementations can access their state using `node->user_data`. + +Finally, each time inference runs the interpreter traverses the graph calling +`invoke()`, and here too the state is available as `node->user_data`. + +Custom ops can be implemented in exactly the same way as builtin ops, by +defining those four functions and a global registration function that usually +looks like this: + +```c++ +namespace tflite { +namespace ops { +namespace custom { + TfLiteRegistration* Register_MY_CUSTOM_OP() { + static TfLiteRegistration r = {my_custom_op::Init, + my_custom_op::Free, + my_custom_op::Prepare, + my_custom_op::Eval}; + return &r; + } +} // namespace custom +} // namespace ops +} // namespace tflite +``` + +Note that registration is not automatic and an explicit call to +`Register_MY_CUSTOM_OP` should be made somewhere. While the standard +`:builtin_ops` takes care of the registration of builtins, custom ops will have +to be collected in separate custom libraries.
+ +### Customizing the kernel library + +Behind the scenes the interpreter will load a library of kernels which will be +assigned to execute each of the operators in the model. While the default +library only contains builtin kernels, it is possible to replace it with a +custom library. + +The interpreter uses an `OpResolver` to translate operator codes and names into +actual code: + +```c++ +class OpResolver { + virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; + virtual TfLiteRegistration* FindOp(const char* op) const = 0; + virtual void AddOp(tflite::BuiltinOperator op, TfLiteRegistration* registration) = 0; + virtual void AddOp(const char* op, TfLiteRegistration* registration) = 0; +}; +``` + +The regular usage will require the developer to use the `BuiltinOpResolver` and +write: + +```c++ +tflite::ops::builtin::BuiltinOpResolver resolver; +``` + +They can then optionally register custom ops: + +```c++ +resolver.AddOp("MY_CUSTOM_OP", Register_MY_CUSTOM_OP()); +``` + +before the resolver is passed to the `InterpreterBuilder`. + +If the set of builtin ops is deemed to be too large, a new `OpResolver` could +be code-generated based on a given subset of ops, possibly only the ones +contained in a given model. This is the equivalent of TensorFlow's selective +registration (and a simple version of it is available in the `tools` +directory). + +## Java + +TensorFlow Lite's Java API supports on-device inference and is provided as an +Android Studio Library that allows loading models, feeding inputs, and +retrieving inference outputs. + +The simplest usage of the TensorFlow Lite Java API looks like this: + +```java +try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) { + interpreter.run(input, output); +} +``` + +### Loading a Model + +The `Interpreter.java` class drives model inference with TensorFlow Lite. In +most cases, this is the only class an app developer will need. 
+ +#### Initializing an `Interpreter` With a Model File + +The `Interpreter` can be initialized with a model file using the constructor: + +```java +public Interpreter(@NotNull File modelFile); +``` + +or with a `MappedByteBuffer`: + +```java +public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer); +``` + +In both cases a valid TensorFlow Lite model must be provided or an +`IllegalArgumentException` will be thrown. If a `MappedByteBuffer` is used to +initialize an `Interpreter`, it should remain unchanged for the whole lifetime of +the `Interpreter`. + +### Running a Model + +#### Supported Data Types + +To use TensorFlow Lite, the data types of the input and output tensors must be +one of the following primitive types: + +* `float` +* `int` +* `long` +* `byte` + +If other data types, including boxed types like `Integer` and `Float`, are used, +an `IllegalArgumentException` will be thrown. + +#### Inputs + +Each input should be an array, a multi-dimensional array, or a `ByteBuffer` of +the supported primitive types. + +The use of `ByteBuffer` is preferred since it allows the `Interpreter` to avoid +unnecessary copies. Each `ByteBuffer` needs to be a direct byte buffer, and its +order must be `ByteOrder.nativeOrder()`. After it is used for a model inference, +it must remain unchanged until the model inference is finished. + +#### Outputs + +Each output should be an array, or a multi-dimensional array of the supported +primitive types. + +#### Running Model Inference + +If a model takes only one input and returns only one output, the following will +trigger an inference run: + +```java +interpreter.run(input, output); +``` + +For models with multiple inputs, or multiple outputs, use: + +```java +interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs); +``` + +where each entry in `inputs` corresponds to an input tensor and +`map_of_indices_to_outputs` maps indices of output tensors to the +corresponding output data. 
In both cases the tensor indices should correspond to +the values given to the `TensorFlow Lite Optimizing Converter` when the model was +created. Be aware that the order of tensors in `inputs` must match the order +given to the `TensorFlow Lite Optimizing Converter`. + +The Java API also provides convenient functions for app developers to get the +index of any model input or output using a tensor name: + +```java +public int getInputIndex(String tensorName); +public int getOutputIndex(String tensorName); +``` + +If `tensorName` is not a valid name in the model, an `IllegalArgumentException` will +be thrown. + +### Releasing Resources After Use + +An `Interpreter` owns resources. To avoid memory leaks, the resources must be +released after use by: + +```java +interpreter.close(); +``` diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md new file mode 100644 index 0000000000000000000000000000000000000000..204a489a93519309bb09238f1b2c8bbd4f1f19e4 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -0,0 +1,91 @@ +# How to use custom operators + +TensorFlow Lite currently supports a subset of TensorFlow operators. However, it +does support the use of user-provided implementations (also known as custom +implementations) if the model contains an operator that is not supported. + +Let’s walk through this via an example. Assume we are using the `Sin` operator +and that we are building a very simple model for a function `y = sin(x + +offset)`, where `offset` is trainable. 
+ +The code to train the TensorFlow model will be something like: + +```python +offset = tf.get_variable("offset", [1,], tf.float32) +x = tf.placeholder(tf.float32, shape=(None,)) +y = tf.sin(x + offset) +y_ = tf.placeholder(tf.float32, shape=(None,)) +loss = tf.reduce_sum(tf.square(y - y_)) +optimizer = tf.train.GradientDescentOptimizer(0.001) +train = optimizer.minimize(loss) +``` + +If you convert this model to Tensorflow Lite format using the TensorFlow Lite +Optimizing Converter with `--allow_custom_ops` argument, and run it with the +default interpreter, the interpreter will raise the following error messages: + +``` +Didn't find custom op for name 'Sin' +Registration failed. +``` + +All we need to do to use the op in TensorFlow Lite is define two functions +(`Prepare` and `Eval`), and construct a `TfLiteRegistration`. This code would +look something like this: + +```cpp +TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { + using namespace tflite; + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + + int num_dims = NumDimensions(input); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(num_dims); + for (int i=0; idata[i] = input->dims->data[i]; + } + + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { + using namespace tflite; + TfLiteTensor* input = GetInput(context, node,0); + TfLiteTensor* output = GetOutput(context, node,0); + + float* input_data = input->data.f; + float* output_data = output->data.f; + + size_t count = 1; + int num_dims = NumDimensions(input); + for (int i = 0; i < num_dims; ++i) { + count *= input->dims->data[i]; + } + + for (size_t i=0; i +#include +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include 
"tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/nnapi_delegate.h" + +namespace { + +// Memory allocation tuning +constexpr const int kDefaultArenaAlignment = 64; +constexpr const int kDefaultTensorAlignment = 4; +// std::vector preallocation tuning. +constexpr const int kSlotsToReserve = 128; + +} // namespace + +namespace tflite { + +Interpreter::Interpreter(ErrorReporter* error_reporter) + : arena_(kDefaultArenaAlignment), + persistent_arena_(kDefaultArenaAlignment), + error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + context_.impl_ = static_cast(this); + context_.ResizeTensor = ResizeTensor; + context_.ReportError = ReportError; + context_.AddTensors = AddTensors; + context_.tensors = nullptr; + context_.tensors_size = 0; + context_.gemm_context = nullptr; + // Reserve some space for the tensors to avoid excessive resizing. + tensors_.reserve(kSlotsToReserve); + nodes_and_registration_.reserve(kSlotsToReserve); + next_allocate_node_id_ = 0; + UseNNAPI(false); +} + +Interpreter::~Interpreter() { + for (auto& nodeAndReg : nodes_and_registration_) { + TfLiteNode& node = nodeAndReg.first; + TfLiteIntArrayFree(node.inputs); + TfLiteIntArrayFree(node.outputs); + TfLiteIntArrayFree(node.temporaries); + if (node.builtin_data) free(node.builtin_data); + OpFree(nodeAndReg.second, node.user_data); + node.builtin_data = nullptr; + } + + for (int i = 0; i < context_.tensors_size; i++) { + TfLiteTensorFree(&context_.tensors[i]); + } +} + +TfLiteStatus Interpreter::SetInputs(std::vector inputs) { + TF_LITE_ENSURE_OK(&context_, + CheckTensorIndices("inputs", inputs.data(), inputs.size())); + inputs_ = std::move(inputs); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::SetOutputs(std::vector outputs) { + TF_LITE_ENSURE_OK( + &context_, CheckTensorIndices("outputs", outputs.data(), outputs.size())); + outputs_ = std::move(outputs); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::CheckTensorIndices(const 
char* label, + const int* indices, int length) { + // Making sure kOptionalTensor is not re-defined to something other than -1. + static_assert(kOptionalTensor == -1, "kOptionalTensor should be defined -1"); + + for (int i = 0; i < length; i++) { + int index = indices[i]; + if (index < kOptionalTensor || index >= context_.tensors_size) { + ReportError(&context_, "Invalid tensor index %d in %s\n", index, label); + consistent_ = false; + return kTfLiteError; + } + } + return kTfLiteOk; +} + +TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, + int dims_size, size_t* bytes) { + // TODO(aselle): Check for overflow here using overflow.h in TensorFlow + // MultiplyWithoutOverflow. + TF_LITE_ENSURE(&context_, bytes != nullptr); + size_t count = 1; + for (int k = 0; k < dims_size; k++) count *= dims[k]; + switch (type) { + case kTfLiteFloat32: + *bytes = sizeof(float) * count; + break; + case kTfLiteInt32: + *bytes = sizeof(int32_t) * count; + break; + case kTfLiteUInt8: + *bytes = sizeof(uint8_t) * count; + break; + case kTfLiteInt64: + *bytes = sizeof(int64_t) * count; + break; + default: + ReportError(&context_, + "Only float32, int32, int64, uint8 supported currently."); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus Interpreter::AllocateTensorsWhoseSizesAreKnown() { + if (!consistent_) { + ReportError(&context_, "AllocateTensors() called on inconsistent model."); + return kTfLiteError; + } + if (next_allocate_node_id_ == nodes_and_registration_.size() && invokable_) { + return kTfLiteOk; + } + allocs_and_refcounts_.resize(context_.tensors_size); + + int new_next_allocate_node_id = next_allocate_node_id_; + invokable_ = false; + + // Allocate graph input nodes. 
+ if (next_allocate_node_id_ == 0) { + for (int i = 0; i < inputs_.size(); ++i) { + int tensor_index = inputs_[i]; + if (tensor_index == kOptionalTensor) { + continue; + } + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + // Add 1 to output tensors, so they will not get overwritten. + for (int i = 0; i < outputs_.size(); ++i) { + allocs_and_refcounts_[outputs_[i]].count++; + } + } + + // Count references to node input tensors, and resize node-referenced tensors + // until we encounter a node that has a dynamic output tensor. + for (int k = next_allocate_node_id_; k < nodes_and_registration_.size(); + k++) { + new_next_allocate_node_id++; + TfLiteNode& node = nodes_and_registration_[k].first; + const TfLiteRegistration& registration = nodes_and_registration_[k].second; + if (OpPrepare(registration, &node) == kTfLiteError) { + return kTfLiteError; + } + + TfLiteIntArray* node_inputs = node.inputs; + for (int i = 0; i < node_inputs->size; ++i) { + int tensor_index = node_inputs->data[i]; + if (tensor_index != kOptionalTensor) { + allocs_and_refcounts_[node_inputs->data[i]].count++; + } + } + + // Discontinue if the node has dynamic outputs. + bool has_unallocated_dynamic_tensor = false; + TfLiteIntArray* node_outputs = node.outputs; + for (int i = 0; i < node_outputs->size; ++i) { + TfLiteTensor& tensor = context_.tensors[node_outputs->data[i]]; + if (tensor.allocation_type == kTfLiteDynamic) { + has_unallocated_dynamic_tensor = true; + break; + } + } + if (has_unallocated_dynamic_tensor) { + break; + } + } + + // Allocate graph persistent outputs, e.g. RNN cell states, etc. 
+ for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) { + TfLiteNode& node = nodes_and_registration_[k].first; + + // Go through output tensors and allocate the persistent ones first. + TfLiteIntArray* node_outputs = node.outputs; + for (int i = 0; i < node_outputs->size; ++i) { + int tensor_index = node_outputs->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + TF_LITE_ENSURE_OK(&context_, + persistent_arena_.Allocate( + &context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + } + + // Go through the graph in execution order. + for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) { + TfLiteNode& node = nodes_and_registration_[k].first; + + // First allocate output tensors. + TfLiteIntArray* node_outputs = node.outputs; + for (int i = 0; i < node_outputs->size; ++i) { + int tensor_index = node_outputs->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + // Then the temporaries, in two passes. First allocate them all, them + // deallocate them. 
+ TfLiteIntArray* node_temporaries = node.temporaries; + for (int i = 0; i < node_temporaries->size; ++i) { + int tensor_index = node_temporaries->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + for (int i = 0; i < node_temporaries->size; ++i) { + int tensor_index = node_temporaries->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + allocs_and_refcounts_[tensor_index].count--; + if (tensor.allocation_type == kTfLiteArenaRw && + allocs_and_refcounts_[tensor_index].count == 0) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Deallocate(&context_, + allocs_and_refcounts_[tensor_index].alloc)); + } + } + + // Then process the node's inputs. + TfLiteIntArray* node_inputs = node.inputs; + for (int i = 0; i < node_inputs->size; ++i) { + int tensor_index = node_inputs->data[i]; + if (tensor_index == kOptionalTensor) { + continue; + } + TfLiteTensor& tensor = context_.tensors[tensor_index]; + + // Decrease reference count and deallocate if not needed anymore. + allocs_and_refcounts_[tensor_index].count--; + if (tensor.allocation_type == kTfLiteArenaRw && + allocs_and_refcounts_[tensor_index].count == 0) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Deallocate(&context_, + allocs_and_refcounts_[tensor_index].alloc)); + } + } + } + + // Resize the buffer and commit the arena. + TF_LITE_ENSURE_OK(&context_, arena_.Commit(&context_)); + TF_LITE_ENSURE_OK(&context_, persistent_arena_.Commit(&context_)); + + // Rewire the tensors to use the underlying arena buffer. 
+ for (int i = 0; i < context_.tensors_size; ++i) { + TfLiteTensor& tensor = context_.tensors[i]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.ResolveAlloc(&context_, allocs_and_refcounts_[i].alloc, + &tensor.data.raw)); + } + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + TF_LITE_ENSURE_OK( + &context_, + persistent_arena_.ResolveAlloc( + &context_, allocs_and_refcounts_[i].alloc, &tensor.data.raw)); + } + } + + invokable_ = true; + next_allocate_node_id_ = new_next_allocate_node_id; + return kTfLiteOk; +} + +namespace { +TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector& x) { + TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size()); + for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i]; + return lite; +} +} // namespace + +TfLiteStatus Interpreter::AllocateTensors() { + next_allocate_node_id_ = 0; + TF_LITE_ENSURE_OK(&context_, arena_.Clear()); + TF_LITE_ENSURE_OK(&context_, persistent_arena_.Clear()); + allocs_and_refcounts_.clear(); + return AllocateTensorsWhoseSizesAreKnown(); +} + +TfLiteStatus Interpreter::AddNodeWithParameters( + const std::vector& inputs, const std::vector& outputs, + const char* init_data, size_t init_data_size, void* builtin_data, + const TfLiteRegistration* registration, int* node_index) { + invokable_ = false; + + std::unique_ptr builtin_data_deleter(builtin_data, + free); + + TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("node inputs", inputs.data(), + inputs.size())); + TF_LITE_ENSURE_OK( + &context_, + CheckTensorIndices("node outputs", outputs.data(), outputs.size())); + + if (node_index) *node_index = nodes_and_registration_.size(); + nodes_and_registration_.resize(nodes_and_registration_.size() + 1); + auto& node_and_reg = nodes_and_registration_.back(); + TfLiteNode& node = node_and_reg.first; + if (node.inputs) TfLiteIntArrayFree(node.inputs); + if (node.outputs) TfLiteIntArrayFree(node.outputs); + if (node.temporaries) 
TfLiteIntArrayFree(node.temporaries); + + // NOTE, here we are not using move semantics yet, since our internal + // representation isn't std::vector, but in the future we would like to avoid + // copies, so we want the interface to take r-value references now. + node.inputs = convertVectorToTfLiteIntArray(inputs); + node.outputs = convertVectorToTfLiteIntArray(outputs); + node.temporaries = TfLiteIntArrayCreate(0); + if (init_data) { + node.user_data = OpInit(*registration, init_data, init_data_size); + } else { + node.user_data = + OpInit(*registration, + reinterpret_cast(builtin_data_deleter.get()), 0); + } + node.builtin_data = builtin_data_deleter.release(); + node_and_reg.second = *registration; + return kTfLiteOk; +} + +TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, + const std::vector& dims) { + // TODO(aselle): All bounds checks can be implemented as one-sided bounds + // checks by casting to unsigned for efficiency. Profile before doing this. + + TF_LITE_ENSURE(&context_, + tensor_index < context_.tensors_size && tensor_index >= 0); + invokable_ = false; + TfLiteIntArray* dims_lite = convertVectorToTfLiteIntArray(dims); + return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite); +} + +TfLiteStatus Interpreter::Invoke() { + if (!consistent_) { + ReportError(&context_, "Invoke called on model that is not consistent."); + return kTfLiteError; + } + if (!invokable_) { + ReportError(&context_, "Invoke called on model that is not ready."); + return kTfLiteError; + } + + TfLiteStatus status = kTfLiteOk; + if (nnapi_delegate_) { + if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) { + return kTfLiteError; + } + if (next_allocate_node_id_ == nodes_and_registration_.size()) { + TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this)); + return kTfLiteOk; + } else { + // TODO(aselle): In the future, we would like this to be an + // automatic tflite CPU fallback. 
+ ReportError(&context_, + "NNAPI was requested, but dependent sized tensors " + "being used.\n"); + return kTfLiteError; + } + } + + for (int i = 0; i < nodes_and_registration_.size(); i++) { + // Ensure we have allocated up to this node. The point of this is to + // allocate as much as possible before running any evaluation, but + // dynamic shapes can prevent this from being possible. + if (i >= next_allocate_node_id_) { + if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) { + return kTfLiteError; + } + } + TfLiteNode& node = nodes_and_registration_[i].first; + const TfLiteRegistration& registration = nodes_and_registration_[i].second; + if (OpInvoke(registration, &node) == kTfLiteError) { + status = kTfLiteError; + } + } + return status; +} + +TfLiteStatus Interpreter::ResizeTensor(TfLiteContext* context, + TfLiteTensor* tensor, + TfLiteIntArray* new_size) { + // Note here that context->impl_ is recovering the this pointer for an + // instance of Interpreter to call into the member function ResizeTensorImpl + // (this function is static). + return static_cast(context->impl_) + ->ResizeTensorImpl(tensor, new_size); +} + +void Interpreter::ReportErrorImpl(const char* format, va_list args) { + error_reporter_->Report(format, args); +} + +void Interpreter::ReportError(TfLiteContext* context, const char* format, ...) { + va_list args; + va_start(args, format); + auto* f = static_cast(context->impl_); + // Note here that context->impl_ is recovering the this pointer for an + // instance of Interpreter to call into the member function ReportErrorImpl + // (this function is static). 
+ f->ReportErrorImpl(format, args); + va_end(args); +} + +TfLiteStatus Interpreter::AddTensors(int tensors_to_add, + int* first_new_tensor_index) { + int base_index = tensors_.size(); + if (first_new_tensor_index) *first_new_tensor_index = base_index; + tensors_.resize(tensors_.size() + tensors_to_add); + for (int i = base_index; i < tensors_.size(); i++) { + memset(&tensors_[i], 0, sizeof(tensors_[i])); + } + context_.tensors = tensors_.data(); + context_.tensors_size = tensors_.size(); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::AddTensors(TfLiteContext* context, int tensors_to_add, + int* first_new_tensor_index) { + // Note here that context->impl_ is recovering the this pointer for an + // instance of Interpreter to call into the member function AddTensors + // (this function is static). + return static_cast(context->impl_) + ->AddTensors(tensors_to_add, first_new_tensor_index); +} + +TfLiteStatus Interpreter::SetTensorParametersReadOnly( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization, + const char* buffer, size_t bytes, const Allocation* allocation) { + TF_LITE_ENSURE(&context_, + tensor_index < context_.tensors_size && tensor_index >= 0); + // For most tensors we know exactly how much memory is necessary so we can + // ensure the buffer is large enough. However, we need to skip string tensors + // because their sizes change with the contents of the individual strings. + if (type != kTfLiteString) { + size_t required_bytes; + TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(), + &required_bytes)); + TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes); + } + invokable_ = false; + TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims), + quantization, const_cast(buffer), bytes, + kTfLiteMmapRo, allocation, &context_.tensors[tensor_index]); + return kTfLiteOk; +} + +// Set description of inputs/outputs/data/fptrs for node `node_index`. 
+// This variant takes no external buffer: the data pointer is reset to null +// and the tensor is marked kTfLiteArenaRw (kTfLiteDynamic for strings); for +// non-string types the byte size is derived from `dims`. +TfLiteStatus Interpreter::SetTensorParametersReadWrite( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization) { + invokable_ = false; + TF_LITE_ENSURE(&context_, + tensor_index < context_.tensors_size && tensor_index >= 0); + size_t required_bytes = 0; + if (type != kTfLiteString) { + // These types will be allocated in our arena so we need to record how + // many bytes we will need based on the dimensions. String tensors are + // allocated dynamically and we can't know ahead of time how much space + // they will require. + TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(), + &required_bytes)); + } + TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims), + quantization, + /*buffer=*/nullptr, required_bytes, + type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw, + nullptr, &context_.tensors[tensor_index]); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, + TfLiteIntArray* new_size) { + // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. + if (tensor->allocation_type == kTfLiteArenaRw || + tensor->allocation_type == kTfLiteDynamic) { + if (tensor->type != kTfLiteString) { + size_t bytesRequired; + TfLiteStatus status = BytesRequired(tensor->type, new_size->data, + new_size->size, &bytesRequired); + if (status != kTfLiteOk) { + TfLiteIntArrayFree(new_size); + return kTfLiteError; + } + tensor->bytes = bytesRequired; + } + if (tensor->dims) TfLiteIntArrayFree(tensor->dims); + tensor->dims = new_size; + + if (tensor->allocation_type != kTfLiteDynamic) { + tensor->data.raw = nullptr; + } + } else { + // kTfLiteMmapRo tensors are stored in the flatbuffer and are therefore + // of fixed size. 
+ TfLiteIntArrayFree(new_size); + ReportError(&context_, "Attempting to resize a fixed-size tensor."); + return kTfLiteError; + } + return kTfLiteOk; +} + +void Interpreter::UseNNAPI(bool enable) { + // TODO(aselle): This is a workaround for finding if NNAPI exists. + // We also need to make sure getLibraryHandle() is renamed to be NNAPI + // prefixed. + if (!NNAPIExists()) enable = false; + if (!enable) { + nnapi_delegate_.reset(); + } else if (!nnapi_delegate_) { + nnapi_delegate_.reset(new NNAPIDelegate); + } +} + +void Interpreter::SetNumThreads(int num_threads) { + // TODO(ahentz): this forces us to link against gemmlowp even when the ops + // don't use it. We should implement some dynamic mechanism for this sort of + // library-specific initialization. + tflite::gemm_support::SetMaxNumThreads(&context_, num_threads); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..65c61e44bee48535f884a3afaddc691972f5e04b --- /dev/null +++ b/tensorflow/contrib/lite/interpreter.h @@ -0,0 +1,374 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Main abstraction controlling the tflite interpreter. +// See context.h for the API for defining operations (TfLiteRegistration). 
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ + +#include +#include +#include +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/simple_memory_arena.h" + +namespace tflite { + +// Map statically from a c++ type to a TfLiteType (used below for safe casts). +template +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteNoType; +} +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteInt32; +} +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteInt64; +} +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteFloat32; +} +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteUInt8; +} + +struct ArenaAllocRefCount { + ArenaAllocRefCount() : alloc(), count(0) {} + + ArenaAlloc alloc; + int count; +}; + +// Forward declare since NNAPIDelegate uses Interpreter. +class NNAPIDelegate; + +// An interpreter for a graph of nodes that input and output from tensors. +// Each node of the graph processes a set of input tensors and produces a +// set of output Tensors. All inputs/output tensors are referenced by index. +// +// Usage: +// +// -- Create basic model +// Interpreter foo(2, 1); +// foo.SetTensorParametersReadWrite(0, ...); +// foo.SetTensorParametersReadOnly(1, ...); +// foo.SetNodeParameters(0, ...) +// +// -- Resize input array to 1 length. +// foo.ResizeInputTensor(0, 1); +// foo.AllocateTensors(); +// -- Install array data +// foo.typed_tensor(0)[0] = 3; +// foo.Invoke(); +// foo.typed_tensor(0)[0] = 4; +// foo.Invoke(); +// -- Resize input array and set data. +// foo.ResizeInputTensor(0, 2); +// foo.AllocateTensors(); +// foo.typed_tensor(0)[0] = 4; +// foo.typed_tensor(0)[1] = 8; +// foo.Invoke(); +// + +class Interpreter { + public: + // Instantiate an interpreter. 
All errors associated with reading and + // processing this model will be forwarded to the error_reporter object. + // + // Note, if error_reporter is nullptr, then a default StderrReporter is + // used. + explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter()); + + ~Interpreter(); + + Interpreter(const Interpreter&) = delete; + Interpreter& operator=(const Interpreter&) = delete; + + // Functions to build interpreter + + // Provide a list of tensor indexes that are inputs to the model. + // Each index is bounds-checked; an invalid index clears the consistent_ flag of the + // interpreter. + TfLiteStatus SetInputs(std::vector inputs); + + // Provide a list of tensor indexes that are outputs to the model. + // Each index is bounds-checked; an invalid index clears the consistent_ flag of the + // interpreter. + TfLiteStatus SetOutputs(std::vector outputs); + + // Adds a node with the given parameters and returns the index of the new + // node in `node_index` (optionally). Interpreter will take ownership of + // `builtin_data` and destroy it with `free()`. Ownership of 'init_data' + // remains with the caller. + TfLiteStatus AddNodeWithParameters(const std::vector& inputs, + const std::vector& outputs, + const char* init_data, + size_t init_data_size, void* builtin_data, + const TfLiteRegistration* registration, + int* node_index = nullptr); + + // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries. + // The value pointed to by `first_new_tensor_index` will be set to the + // index of the first new tensor if `first_new_tensor_index` is non-null. + TfLiteStatus AddTensors(int tensors_to_add, + int* first_new_tensor_index = nullptr); + + // Set description of inputs/outputs/data/fptrs for node `node_index`. + // This variant assumes an external buffer has been allocated of size + // bytes. The lifetime of buffer must be ensured to be greater or equal + // to Interpreter. 
+ TfLiteStatus SetTensorParametersReadOnly( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization, + const char* buffer, size_t bytes, const Allocation* allocation = nullptr); + + // Set description of inputs/outputs/data/fptrs for node `node_index`. + // This variant takes no external buffer: the tensor's storage is managed + // by the interpreter (arena-allocated, or dynamic for string tensors), + // so no `buffer`/`bytes` arguments are needed. + TfLiteStatus SetTensorParametersReadWrite( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization); + + // Functions to access tensor data + + // Read only access to list of inputs. + const std::vector& inputs() const { return inputs_; } + + // Return the name of a given input. The given index must be in the range + // [0, inputs().size()). + const char* GetInputName(int index) const { + return context_.tensors[inputs_[index]].name; + } + + // Read only access to list of outputs. + const std::vector& outputs() const { return outputs_; } + + // Return the name of a given output. The given index must be in the range + // [0, outputs().size()). + const char* GetOutputName(int index) const { + return context_.tensors[outputs_[index]].name; + } + + // Return the number of tensors in the model. + int tensors_size() const { return context_.tensors_size; } + + // Return the number of ops in the model. + int nodes_size() const { return nodes_and_registration_.size(); } + + // Get a tensor data structure. + // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this + // read/write access to structure + TfLiteTensor* tensor(int tensor_index) { + if (tensor_index >= context_.tensors_size || tensor_index < 0) + return nullptr; + return &context_.tensors[tensor_index]; + } + + // Get a pointer to an operation and registration data structure if in bounds. 
+ // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this + // read/write access to structure + const std::pair* node_and_registration( + int node_index) { + if (node_index >= nodes_and_registration_.size() || node_index < 0) + return nullptr; + return &nodes_and_registration_[node_index]; + } + + // Perform a checked cast to the appropriate tensor type. + template + T* typed_tensor(int tensor_index) { + if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { + if (tensor_ptr->type == typeToTfLiteType()) { + return reinterpret_cast(tensor_ptr->data.raw); + } + } + return nullptr; + } + + // Return a pointer into the data of a given input tensor. The given index + // must be between 0 and inputs().size(). + template + T* typed_input_tensor(int index) { + return typed_tensor(inputs_[index]); + } + + // Return a pointer into the data of a given output tensor. The given index + // must be between 0 and outputs().size(). + template + T* typed_output_tensor(int index) { + return typed_tensor(outputs_[index]); + } + + // Change the dimensionality of a given tensor. Note, this is only acceptable + // for tensor indices that are inputs. + // Returns status of failure or success. + // TODO(aselle): Consider implementing ArraySlice equivalent to make this + // more adept at accepting data without an extra copy. Use absl::ArraySlice + // if our partners determine that dependency is acceptable. + TfLiteStatus ResizeInputTensor(int tensor_index, + const std::vector& dims); + + // Update allocations for all tensors. This will redim dependent tensors using + // the input tensor dimensionality as given. This is relatively expensive. + // If you know that your sizes are not changing, you need not call this. + + // Returns status of success or failure. + TfLiteStatus AllocateTensors(); + + // Invoke the interpreter (run the whole graph in dependency order). + // + // NOTE: It is possible that the interpreter is not in a ready state + // to evaluate (i.e. 
if a ResizeTensor() has been performed without an + // AllocateTensors(). + // Returns status of success or failure. + TfLiteStatus Invoke(); + + // Enable or disable the NN API (true to enable) + void UseNNAPI(bool enable); + + // Set the number of threads available to the interpreter. + void SetNumThreads(int num_threads); + + private: + // Give 'op_reg' a chance to initialize itself using the contents of + // 'buffer'. + void* OpInit(const TfLiteRegistration& op_reg, const char* buffer, + size_t length) { + if (op_reg.init == nullptr) return nullptr; + return op_reg.init(&context_, buffer, length); + } + + // Let 'op_reg' release any memory it might have allocated via 'OpInit'. + void OpFree(const TfLiteRegistration& op_reg, void* buffer) { + if (op_reg.free == nullptr) return; + if (buffer) { + op_reg.free(&context_, buffer); + } + } + + // Prepare the given 'node' for execution. + TfLiteStatus OpPrepare(const TfLiteRegistration& op_reg, TfLiteNode* node) { + if (op_reg.prepare == nullptr) return kTfLiteOk; + return op_reg.prepare(&context_, node); + } + + // Invoke the operator represented by 'node'. + TfLiteStatus OpInvoke(const TfLiteRegistration& op_reg, TfLiteNode* node) { + if (op_reg.invoke == nullptr) return kTfLiteError; + return op_reg.invoke(&context_, node); + } + + // Allocate tensors whose sizes are known in order of nodes. Discontinue when + // we encounter a node that has a dynamic output tensor. + TfLiteStatus AllocateTensorsWhoseSizesAreKnown(); + + // Tensors needed by the interpreter. Use `AddTensors` to add more blank + // tensor entries. Note, `tensors_.data()` needs to be synchronized to the + // `context_` whenever this std::vector is reallocated. Currently this + // only happens in `AddTensors()`. + std::vector tensors_; + + // Check if an array of tensor indices are valid with respect to the Tensor + // array. + // NOTE: this changes consistent_ to be false if indices are out of bounds. 
+ TfLiteStatus CheckTensorIndices(const char* label, const int* indices, + int length); + + // Compute the number of bytes required to represent a tensor with dimensions + // specified by the array dims (of length dims_size). Returns the status code + // and bytes. + TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size, + size_t* bytes); + + // Implementation of a request to resize a tensor. + TfLiteStatus ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size); + + // Report a detailed error string (will be printed to stderr). + // TODO(aselle): allow user of class to provide alternative destinations. + void ReportErrorImpl(const char* format, va_list args); + + // Entry point for C node plugin API to request that a tensor be resized. + static TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor, + TfLiteIntArray* new_size); + // Entry point for C node plugin API to report an error. + static void ReportError(TfLiteContext* context, const char* format, ...); + + // Entry point for C node plugin API to add new tensors. + static TfLiteStatus AddTensors(TfLiteContext* context, int tensors_to_add, + int* first_new_tensor_index); + + // A pure C data structure used to communicate with the pure C plugin + // interface. To avoid copying tensor metadata, this is also the definitive + // structure to store tensors. + TfLiteContext context_; + + // Node inputs/outputs are stored in TfLiteNode and TfLiteRegistration stores + // function pointers to actual implementation. + std::vector> + nodes_and_registration_; + + // Raw memory buffer that is allocated for all temporary and graph outputs + // that are declared kTfLiteArenaRw. + SimpleMemoryArena arena_; + + // Raw memory buffer that is allocated for persistent tensors that are + // declared as kTfLiteArenaRwPersistent. + SimpleMemoryArena persistent_arena_; + + // Stores allocation and reference counts of all tensors. 
+ std::vector allocs_and_refcounts_; + + // Whether the model is consistent. That is to say if the inputs and outputs + // of every node and the global inputs and outputs are valid indexes into + // the tensor array. + bool consistent_ = true; + + // Whether the model is safe to invoke (if any errors occurred this + // will be false). + bool invokable_ = false; + + // Array of indices representing the tensors that are inputs to the + // interpreter. + std::vector inputs_; + + // Array of indices representing the tensors that are outputs to the + // interpreter. + std::vector outputs_; + + // The error reporter delegate that tflite will forward errors to. + ErrorReporter* error_reporter_; + + // Next node to allocate output tensors. + // During Invoke(), Interpreter will allocate input tensors first, which are + // known to be fixed size. Then it will allocate outputs from as many nodes + // as possible. When there is a node that produces a dynamic sized tensor, + // Interpreter will stop allocating tensors, set the value of next allocate + // node id, and execute the node to generate the output tensor before + // continuing to allocate successors. This process repeats until all nodes are executed. + // NOTE: this relies on the order of nodes that is in topological order. + int next_allocate_node_id_; + + // Whether to delegate to NN API + std::unique_ptr nnapi_delegate_; +}; + +} // namespace tflite +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..edff2109430c6e1ec6c481619ed7772237a3301d --- /dev/null +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -0,0 +1,526 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/interpreter.h" +#include +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace { + +// Make an interpreter that has no tensors and no nodes +TEST(BasicInterpreter, ZeroInterpreter) { + Interpreter interpreter; + interpreter.SetInputs({}); + interpreter.SetOutputs({}); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +// Test various error conditions. +TEST(BasicInterpreter, InvokeInvalidModel) { + Interpreter interpreter; + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +// Test size accesser functions. +TEST(BasicInterpreter, TestSizeFunctions) { + Interpreter interpreter; + int base_index; + ASSERT_EQ(interpreter.nodes_size(), 0); + ASSERT_EQ(interpreter.tensors_size(), 0); + ASSERT_EQ(interpreter.AddTensors(2, &base_index), kTfLiteOk); + ASSERT_EQ(interpreter.tensors_size(), 2); + ASSERT_EQ(base_index, 0); + ASSERT_EQ(interpreter.AddTensors(3, &base_index), kTfLiteOk); + ASSERT_EQ(interpreter.tensors_size(), 5); + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.tensors_size(), 6); + ASSERT_EQ(base_index, 2); +} + +// Test if invalid indices make a model inconsistent (and conversely if +// valid indices keep a model consistent). 
+TEST(BasicInterpreter, InconsistentModel) { + // Invalid inputs + { + Interpreter interpreter; + ASSERT_NE(interpreter.SetInputs({5}), kTfLiteOk); + ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter.inputs(), std::vector()); + } + // Invalid outputs + { + Interpreter interpreter; + ASSERT_NE(interpreter.SetOutputs({5}), kTfLiteOk); + ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter.outputs(), std::vector()); + } + // Invalid node inputs + { + Interpreter interpreter; + TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr}; + ASSERT_NE(interpreter.AddNodeWithParameters({3}, {0}, nullptr, 0, nullptr, + ®istration), + kTfLiteOk); + ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + } + // Valid inputs and outputs and a node with valid inputs and outputs + { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr}; + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, + ®istration), + kTfLiteOk); + } +} + +// Make an interpreter that has one tensor but no ops +TEST(BasicInterpreter, CheckAllocate) { + struct { + TfLiteType type; + size_t size; + } cases[] = { + {kTfLiteFloat32, sizeof(float)}, + {kTfLiteInt32, sizeof(int32_t)}, + {kTfLiteUInt8, sizeof(uint8_t)}, + {kTfLiteInt64, sizeof(int64_t)}, + }; + + for (auto test : cases) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({}); + TfLiteQuantizationParams quant; + + interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant); + interpreter.SetTensorParametersReadWrite(1, test.type, "", 
{4}, quant); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.tensor(0)->bytes, 3 * test.size); + ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(1)->bytes, 4 * test.size); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); + } +} + +TEST(BasicInterpreter, CheckResize) { + const float floats[] = {-3., -4.}; + const int32_t int32s[] = {-3, -4}; + const uint8_t uint8s[] = {3, 4}; + const int64_t int64s[] = {6, -7}; + + struct { + TfLiteType type; + size_t size; + const char* array; + } cases[] = { + {kTfLiteFloat32, sizeof(float), reinterpret_cast(floats)}, + {kTfLiteInt32, sizeof(int32_t), reinterpret_cast(int32s)}, + {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast(uint8s)}, + {kTfLiteInt64, sizeof(int64_t), reinterpret_cast(int64s)}, + }; + + for (auto test : cases) { + Interpreter interpreter; + + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({}); + TfLiteQuantizationParams quant; + + ASSERT_EQ( + interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadOnly( + 1, test.type, "", {2}, quant, test.array, 2 * test.size), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(0, {1, 2}), kTfLiteOk); + // Resizing a mmapped tensor is not allowed and should produce error. + ASSERT_NE(interpreter.ResizeInputTensor(1, {3}), kTfLiteOk); + // Set the tensor to be mmapped but with a buffer size that is insufficient + // to match the dimensionality. + ASSERT_NE(interpreter.SetTensorParametersReadOnly( + 1, test.type, "", {2}, quant, test.array, 1 * test.size), + kTfLiteOk); + // Allocating should work since we should have our last correct array + // values in place. 
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + } +} + +TEST(BasicInterpreter, CheckAlignment) { + struct { + TfLiteType type; + } cases[] = { + {kTfLiteFloat32}, + {kTfLiteInt32}, + {kTfLiteUInt8}, + {kTfLiteInt64}, + }; + + for (auto test : cases) { + Interpreter interpreter; + + ASSERT_EQ(interpreter.AddTensors(4), kTfLiteOk); + + for (int i = 0; i < 4; i++) { + TfLiteQuantizationParams quant; + interpreter.SetTensorParametersReadWrite(i, test.type, "", {2 * i + 1}, + quant); + } + interpreter.AllocateTensors(); + for (int i = 0; i < 4; i++) { + const TfLiteTensor& tensor = *interpreter.tensor(i); + ASSERT_EQ(reinterpret_cast(tensor.data.raw) % 4, 0); + } + } +} + +TEST(BasicInterpreter, CheckArenaAllocation) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(10), kTfLiteOk); + + TfLiteQuantizationParams quant; + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + std::vector sizes{2048, 4096, 1023, 2047, 1021, + 2047, 1023, 2046, 1021, 2048}; + for (int i = 0; i < sizes.size(); ++i) { + interpreter.SetTensorParametersReadWrite(i, kTfLiteUInt8, "", {sizes[i]}, + quant); + } + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({9, 4}); + interpreter.AddNodeWithParameters({0, 1}, {2, 3}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({2, 1}, {4, 5}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({4, 3}, {6, 7}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({6, 5}, {8}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({8, 7}, {9}, nullptr, 0, nullptr, ®); + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw); + ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw); + + ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw); + ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw); + ASSERT_LT(interpreter.tensor(0)->data.raw, 
interpreter.tensor(1)->data.raw); + + ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(3)->data.raw); + + ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw); +} + +TEST(BasicInterpreter, BufferAccess) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + // Verify we get a valid pointer. + ASSERT_NE(interpreter.typed_tensor(0), nullptr); + // Verify incorrect pointer will not be returned. + ASSERT_EQ(interpreter.typed_tensor(0), nullptr); + // Verify that raw c interface ptr matches safe interface. 
+ ASSERT_EQ(interpreter.typed_tensor(0), interpreter.tensor(0)->data.f); +} + +TEST(BasicInterpreter, NoOpInterpreter) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + ASSERT_EQ(interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 2, 3}), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +TEST(BasicInterpreter, OneOpInterpreter) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk); + + TfLiteQuantizationParams quantized; + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "in1", + {3}, quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "out0", + {3}, quantized), + kTfLiteOk); + + ASSERT_EQ(interpreter.GetInputName(0), "in1"); + ASSERT_EQ(interpreter.GetOutputName(0), "out0"); + + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.init = [](TfLiteContext* context, const char*, size_t) -> void* { + auto* first_new_tensor = new int; + context->AddTensors(context, 2, first_new_tensor); + return first_new_tensor; + }; + reg.free = [](TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); + }; + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + auto* first_new_tensor = reinterpret_cast(node->user_data); + + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]]; + + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor1, newSize)); + + 
TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(2); + for (int i = 0; i < 2; ++i) { + node->temporaries->data[i] = *(first_new_tensor) + i; + } + + auto setup_temporary = [&](int id) { + TfLiteTensor* tmp = &context->tensors[id]; + tmp->type = kTfLiteFloat32; + tmp->allocation_type = kTfLiteArenaRw; + return context->ResizeTensor(context, tmp, + TfLiteIntArrayCopy(tensor0->dims)); + }; + TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[0])); + TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[1])); + + return kTfLiteOk; + }; + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + + auto populate = [&](int id) { + TfLiteTensor* t = &context->tensors[id]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + t->data.f[i] = a0->data.f[i]; + } + }; + + populate(node->outputs->data[0]); + populate(node->temporaries->data[0]); + populate(node->temporaries->data[1]); + return kTfLiteOk; + }; + ASSERT_EQ( + interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®), + kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +// Forcefully divides tensor allocation in three steps: one before invocation +// and two more at invocation time. This happens because we use string tensors +// and their sizes can't be determined until invocation time. +TEST(BasicInterpreter, ThreeStepAllocate) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(5), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({4}), kTfLiteOk); + + TfLiteQuantizationParams quantized; + char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'A', 'B', 'C'}; + // Read only string tensor. 
+ ASSERT_EQ(interpreter.SetTensorParametersReadOnly(0, kTfLiteString, "", {1}, + quantized, data, 15), + kTfLiteOk); + // Read-write string tensor. + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteString, "", {1}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(2, kTfLiteInt32, "", {1}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(3, kTfLiteString, "", {1}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(4, kTfLiteInt32, "", {1}, + quantized), + kTfLiteOk); + + // String-in String-out node. + TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr}; + reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + DynamicBuffer buf; + StringRef str_ref = GetString(a0, 0); + buf.AddString(str_ref); + buf.WriteToTensor(a1); + return kTfLiteOk; + }; + + // String-in Int-out node. 
+ TfLiteRegistration reg_len = {nullptr, nullptr, nullptr, nullptr}; + reg_len.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1); + outputSize->data[0] = 1; + return context->ResizeTensor(context, output, outputSize); + }; + reg_len.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + a1->data.i32[0] = a0->bytes; + return kTfLiteOk; + }; + + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, + ®_copy), + kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({1}, {2}, nullptr, 0, nullptr, + ®_len), + kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {3}, nullptr, 0, nullptr, + ®_copy), + kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({3}, {4}, nullptr, 0, nullptr, + ®_len), + kTfLiteOk); + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + ASSERT_EQ(interpreter.tensor(0)->bytes, 15); + ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(1)->bytes, 15); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(3)->bytes, 15); + ASSERT_NE(interpreter.tensor(4)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(2)->bytes, 4); + ASSERT_EQ(interpreter.tensor(2)->data.i32[0], 15); + ASSERT_EQ(interpreter.tensor(4)->bytes, 4); + ASSERT_EQ(interpreter.tensor(4)->data.i32[0], 15); +} + +TEST(BasicInterpreter, AllocateTwice) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk); + + TfLiteQuantizationParams quantized; + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + quantized), + kTfLiteOk); + 
ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, + quantized), + kTfLiteOk); + + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + return context->ResizeTensor(context, tensor1, newSize); + }; + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + a1->data.f[i] = a0->data.f[i]; + } + return kTfLiteOk; + }; + ASSERT_EQ( + interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®), + kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + char* old_tensor0_ptr = interpreter.tensor(0)->data.raw; + char* old_tensor1_ptr = interpreter.tensor(1)->data.raw; + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(old_tensor0_ptr, interpreter.tensor(0)->data.raw); + ASSERT_EQ(old_tensor1_ptr, interpreter.tensor(1)->data.raw); +} + +struct TestErrorReporter : public ErrorReporter { + int Report(const char* format, va_list args) override { + char buffer[1024]; + int size = vsnprintf(buffer, sizeof(buffer), format, args); + all_reports += buffer; + calls++; + return size; + } + int calls = 0; + std::string all_reports; +}; + +TEST(BasicInterpreter, TestNullErrorReporter) { + TestErrorReporter reporter; + Interpreter interpreter; +} + +TEST(BasicInterpreter, TestCustomErrorReporter) { + TestErrorReporter reporter; + Interpreter interpreter(&reporter); + ASSERT_NE(interpreter.Invoke(), 
kTfLiteOk); + ASSERT_EQ(reporter.all_reports, "Invoke called on model that is not ready."); + ASSERT_EQ(reporter.calls, 1); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { +#ifdef OS_LINUX + FLAGS_logtostderr = true; +#endif + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/ios_makefile.inc b/tensorflow/contrib/lite/ios_makefile.inc new file mode 100644 index 0000000000000000000000000000000000000000..bcff7ed9889e95c13294b6cf0d0f4788991a04df --- /dev/null +++ b/tensorflow/contrib/lite/ios_makefile.inc @@ -0,0 +1,47 @@ +# Settings for iOS. +ifeq ($(TARGET), IOS) + BUILD_FOR_IOS_SIMULATOR := false + ifeq ($(IOS_ARCH), x86_64) + BUILD_FOR_IOS_SIMULATOR := true + endif + ifeq ($(IOS_ARCH), i386) + BUILD_FOR_IOS_SIMULATOR := true + endif + ifeq ($(BUILD_FOR_IOS_SIMULATOR), true) + IPHONEOS_PLATFORM := $(shell xcrun --sdk iphonesimulator \ + --show-sdk-platform-path) + IPHONEOS_SYSROOT := $(shell xcrun --sdk iphonesimulator \ + --show-sdk-path) + else + IPHONEOS_PLATFORM := $(shell xcrun --sdk iphoneos --show-sdk-platform-path) + IPHONEOS_SYSROOT := $(shell xcrun --sdk iphoneos --show-sdk-path) + endif + IOS_SDK_VERSION := $(shell xcrun --sdk iphoneos --show-sdk-version) + MIN_SDK_VERSION := 9.0 + # Override IOS_ARCH with armv7, armv7s, arm64, i386, or x86_64. 
+ IOS_ARCH := x86_64 + CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ + -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -fembed-bitcode \ + -Wno-c++11-narrowing \ + -mno-thumb \ + -fno-exceptions \ + -isysroot \ + ${IPHONEOS_SYSROOT} \ + -arch $(IOS_ARCH) \ + -O3 + CCFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ + -fembed-bitcode \ + -mno-thumb \ + -isysroot \ + ${IPHONEOS_SYSROOT} \ + -arch $(IOS_ARCH) \ + -O3 + LDFLAGS := -fembed-bitcode \ + -miphoneos-version-min=${MIN_SDK_VERSION} \ + -arch $(IOS_ARCH) + OBJDIR := $(OBJDIR)ios_$(IOS_ARCH)/ + LIBDIR := $(LIBDIR)ios_$(IOS_ARCH)/ + BINDIR := $(BINDIR)ios_$(IOS_ARCH)/ + DEPDIR := $(DEPDIR)ios_$(IOS_ARCH)/ +endif diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..1de28eb52ddb458df0be0a8f9ef453f7caf68654 --- /dev/null +++ b/tensorflow/contrib/lite/java/BUILD @@ -0,0 +1,150 @@ +# Description: +# TensorFlow Lite Java API. + +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary") + +android_library( + name = "tensorflowlite", + srcs = glob( + [ + "src/main/java/org/tensorflow/lite/*.java", + ], + ), + visibility = ["//visibility:public"], + deps = [ + ":tflite_runtime", + "@javax_validation", + ], +) + +android_library( + name = "tensorflowlite_java", + srcs = glob( + [ + "src/main/java/org/tensorflow/lite/*.java", + ], + ), + visibility = ["//visibility:public"], + deps = [ + "@javax_validation", + ], +) + +java_library( + name = "tensorflowlitelib", + srcs = glob( + [ + "src/main/java/org/tensorflow/lite/*.java", + ], + ), + javacopts = JAVACOPTS, + visibility = ["//visibility:public"], + deps = [ + ":libtensorflowlite_jni.so", + "//tensorflow/contrib/lite/java/src/main/native", + "@javax_validation", + ], +) + +java_test( + name = 
"TensorFlowLiteTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.TensorFlowLiteTest", + deps = [ + ":libtensorflowlite_jni.so", + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +java_test( + name = "DataTypeTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/DataTypeTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.DataTypeTest", + deps = [ + ":libtensorflowlite_jni.so", + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +java_test( + name = "NativeInterpreterWrapperTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java"], + data = [ + "src/testdata/add.bin", + "src/testdata/int32.bin", + "src/testdata/int64.bin", + "src/testdata/invalid_model.bin", + "src/testdata/uint8.bin", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest", + deps = [ + ":libtensorflowlite_jni.so", + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +java_test( + name = "TensorTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/TensorTest.java"], + data = [ + "src/testdata/add.bin", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.TensorTest", + deps = [ + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +filegroup( + name = "libtensorflowlite_jni", + srcs = select({ + "//conditions:default": [":libtensorflowlite_jni.so"], + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "tflite_runtime", + srcs = ["libtensorflowlite_jni.so"], + visibility = ["//visibility:public"], +) + +tflite_jni_binary( + name = "libtensorflowlite_jni.so", + deps = [ + "//tensorflow/contrib/lite/java/src/main/native", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility 
= ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/demo/.gitignore b/tensorflow/contrib/lite/java/demo/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..39fb081a42a86ccf8f9cf99dbccc8bdf7c828bce --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/.gitignore @@ -0,0 +1,9 @@ +*.iml +.gradle +/local.properties +/.idea/workspace.xml +/.idea/libraries +.DS_Store +/build +/captures +.externalNativeBuild diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2e818f728ef208d30b0eeb27ffd7e3fa0c7c1a2d --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/README.md @@ -0,0 +1,46 @@ +# TF Lite Android App + +## Building from Source with Bazel + +1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel): + + 1. [Install Bazel and Android Prerequisites](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites). + It's easiest with Android Studio. + + - You'll need at least SDK version 23. + - Make sure to install the latest version of Bazel. Some distributions + ship with Bazel 0.5.4, which is too old. + - Bazel requires Android Build Tools `26.0.1` or higher. + - **Bazel is incompatible with NDK revisions 15 and above,** with revision + 16 being a compile-breaking change. [Download an older version manually + instead of using the SDK Manager.](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites) + - You also need to install the Android Support Repository, available + through Android Studio under `Android SDK Manager -> SDK Tools -> + Android Support Repository`. + + 2. 
[Edit your `WORKSPACE`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#edit-workspace) + to add SDK and NDK targets. + + NOTE: As long as you have the SDK and NDK installed, the `./configure` + script will create these rules for you. Answer "Yes" when the script asks + to automatically configure the `./WORKSPACE`. + + - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that + you have installed. + - By default, Android Studio will install the SDK to `~/Android/Sdk` and + the NDK to `~/Android/Sdk/ndk-bundle` (but the NDK should be a manual + download until Bazel supports NDK 16. See bullet points under (1)). + +2. Build the app with Bazel. The demo needs C++11: + + ```shell + bazel build -c opt --cxxopt='--std=c++11' \ + //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo + ``` + +3. Install the demo on a + [debug-enabled device](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install): + + ```shell + adb install bazel-bin/tensorflow/contrib/lite/java/demo/app/src/main/TfLiteCameraDemo.apk + ``` diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..b76eaad8bb91224805d16b3d6f7c3274c9feb90c --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/build.gradle @@ -0,0 +1,58 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion "26.0.1" + defaultConfig { + applicationId "android.example.com.tflitecamerademo" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + + // Remove this block. 
+ jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'com.android.support:appcompat-v7:25.2.0' + compile 'com.android.support.constraint:constraint-layout:1.0.2' + compile 'com.android.support:design:25.2.0' + compile 'com.android.support:support-annotations:25.3.1' + compile 'com.android.support:support-v13:25.2.0' + + compile 'org.tensorflow:tensorflow-lite:+' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..ba63dce5d9a7192a2c3c4c5561333d39a3ecc024 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..654fa9d6d2799fc3cafa3e0e042cb2a5746bf2c5 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD @@ -0,0 +1,41 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +android_binary( + name = "TfLiteCameraDemo", + srcs = glob(["java/**/*.java"]), + assets = [ + 
"@tflite_mobilenet//:labels.txt", + "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", + ], + assets_dir = "", + custom_package = "com.example.android.tflitecamerademo", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + # In some platforms we don't have an Android SDK/NDK and this target + # can't be built. We need to prevent the build system from trying to + # use the target in that case. + tags = ["manual"], + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..dd0cd6c98ff878e9c41875cab74c12191cadb173 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD @@ -0,0 +1,24 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files( + glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt new file mode 100644 index 0000000000000000000000000000000000000000..fe811239d8e2989de19fecabb1ebb0c9dddac514 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt @@ -0,0 +1,1001 @@ +background +tench +goldfish 
+great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle 
+bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug 
+ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain 
mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike 
+mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne 
+tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java new file mode 100644 index 0000000000000000000000000000000000000000..f2045906599218871b51a752dcbb3eeb23b8f085 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java @@ -0,0 +1,72 @@ +/* Copyright 2017 
The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.content.Context; +import android.util.AttributeSet; +import android.view.TextureView; + +/** A {@link TextureView} that can be adjusted to a specified aspect ratio. */ +public class AutoFitTextureView extends TextureView { + + private int mRatioWidth = 0; + private int mRatioHeight = 0; + + public AutoFitTextureView(Context context) { + this(context, null); + } + + public AutoFitTextureView(Context context, AttributeSet attrs) { + this(context, attrs, 0); + } + + public AutoFitTextureView(Context context, AttributeSet attrs, int defStyle) { + super(context, attrs, defStyle); + } + + /** + * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio + * calculated from the parameters. Note that the actual sizes of parameters don't matter, that is, + * calling setAspectRatio(2, 3) and setAspectRatio(4, 6) make the same result. 
+ * + * @param width Relative horizontal size + * @param height Relative vertical size + */ + public void setAspectRatio(int width, int height) { + if (width < 0 || height < 0) { + throw new IllegalArgumentException("Size cannot be negative."); + } + mRatioWidth = width; + mRatioHeight = height; + requestLayout(); + } + + @Override + protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) { + super.onMeasure(widthMeasureSpec, heightMeasureSpec); + int width = MeasureSpec.getSize(widthMeasureSpec); + int height = MeasureSpec.getSize(heightMeasureSpec); + if (0 == mRatioWidth || 0 == mRatioHeight) { + setMeasuredDimension(width, height); + } else { + if (width < height * mRatioWidth / mRatioHeight) { + setMeasuredDimension(width, width * mRatioHeight / mRatioWidth); + } else { + setMeasuredDimension(height * mRatioWidth / mRatioHeight, height); + } + } + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java new file mode 100644 index 0000000000000000000000000000000000000000..74737a8b883d23684220dd32bbd7a9e8ab4b2123 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java @@ -0,0 +1,708 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; +import android.app.AlertDialog; +import android.app.Dialog; +import android.app.DialogFragment; +import android.app.Fragment; +import android.content.Context; +import android.content.DialogInterface; +import android.content.pm.PackageInfo; +import android.content.pm.PackageManager; +import android.content.res.Configuration; +import android.graphics.Bitmap; +import android.graphics.ImageFormat; +import android.graphics.Matrix; +import android.graphics.Point; +import android.graphics.RectF; +import android.graphics.SurfaceTexture; +import android.hardware.camera2.CameraAccessException; +import android.hardware.camera2.CameraCaptureSession; +import android.hardware.camera2.CameraCharacteristics; +import android.hardware.camera2.CameraDevice; +import android.hardware.camera2.CameraManager; +import android.hardware.camera2.CaptureRequest; +import android.hardware.camera2.CaptureResult; +import android.hardware.camera2.TotalCaptureResult; +import android.hardware.camera2.params.StreamConfigurationMap; +import android.media.ImageReader; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.support.annotation.NonNull; +import android.support.v13.app.FragmentCompat; +import android.support.v4.content.ContextCompat; +import android.util.Log; +import android.util.Size; +import android.view.LayoutInflater; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewGroup; +import android.widget.TextView; +import android.widget.Toast; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +/** Basic fragments for the 
Camera. */ +public class Camera2BasicFragment extends Fragment + implements FragmentCompat.OnRequestPermissionsResultCallback { + + /** Tag for the {@link Log}. */ + private static final String TAG = "TfLiteCameraDemo"; + + private static final String FRAGMENT_DIALOG = "dialog"; + + private static final String HANDLE_THREAD_NAME = "CameraBackground"; + + private static final int PERMISSIONS_REQUEST_CODE = 1; + + private final Object lock = new Object(); + private boolean runClassifier = false; + private boolean checkedPermissions = false; + private TextView textView; + private ImageClassifier classifier; + + /** Max preview width that is guaranteed by Camera2 API */ + private static final int MAX_PREVIEW_WIDTH = 1920; + + /** Max preview height that is guaranteed by Camera2 API */ + private static final int MAX_PREVIEW_HEIGHT = 1080; + + /** + * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link + * TextureView}. + */ + private final TextureView.SurfaceTextureListener surfaceTextureListener = + new TextureView.SurfaceTextureListener() { + + @Override + public void onSurfaceTextureAvailable(SurfaceTexture texture, int width, int height) { + openCamera(width, height); + } + + @Override + public void onSurfaceTextureSizeChanged(SurfaceTexture texture, int width, int height) { + configureTransform(width, height); + } + + @Override + public boolean onSurfaceTextureDestroyed(SurfaceTexture texture) { + return true; + } + + @Override + public void onSurfaceTextureUpdated(SurfaceTexture texture) {} + }; + + /** ID of the current {@link CameraDevice}. */ + private String cameraId; + + /** An {@link AutoFitTextureView} for camera preview. */ + private AutoFitTextureView textureView; + + /** A {@link CameraCaptureSession } for camera preview. */ + private CameraCaptureSession captureSession; + + /** A reference to the opened {@link CameraDevice}. 
*/ + private CameraDevice cameraDevice; + + /** The {@link android.util.Size} of camera preview. */ + private Size previewSize; + + /** {@link CameraDevice.StateCallback} is called when {@link CameraDevice} changes its state. */ + private final CameraDevice.StateCallback stateCallback = + new CameraDevice.StateCallback() { + + @Override + public void onOpened(@NonNull CameraDevice currentCameraDevice) { + // This method is called when the camera is opened. We start camera preview here. + cameraOpenCloseLock.release(); + cameraDevice = currentCameraDevice; + createCameraPreviewSession(); + } + + @Override + public void onDisconnected(@NonNull CameraDevice currentCameraDevice) { + cameraOpenCloseLock.release(); + currentCameraDevice.close(); + cameraDevice = null; + } + + @Override + public void onError(@NonNull CameraDevice currentCameraDevice, int error) { + cameraOpenCloseLock.release(); + currentCameraDevice.close(); + cameraDevice = null; + Activity activity = getActivity(); + if (null != activity) { + activity.finish(); + } + } + }; + + /** An additional thread for running tasks that shouldn't block the UI. */ + private HandlerThread backgroundThread; + + /** A {@link Handler} for running tasks in the background. */ + private Handler backgroundHandler; + + /** An {@link ImageReader} that handles image capture. */ + private ImageReader imageReader; + + /** {@link CaptureRequest.Builder} for the camera preview */ + private CaptureRequest.Builder previewRequestBuilder; + + /** {@link CaptureRequest} generated by {@link #previewRequestBuilder} */ + private CaptureRequest previewRequest; + + /** A {@link Semaphore} to prevent the app from exiting before closing the camera. */ + private Semaphore cameraOpenCloseLock = new Semaphore(1); + + /** A {@link CameraCaptureSession.CaptureCallback} that handles events related to capture. 
*/ + private CameraCaptureSession.CaptureCallback captureCallback = + new CameraCaptureSession.CaptureCallback() { + + @Override + public void onCaptureProgressed( + @NonNull CameraCaptureSession session, + @NonNull CaptureRequest request, + @NonNull CaptureResult partialResult) {} + + @Override + public void onCaptureCompleted( + @NonNull CameraCaptureSession session, + @NonNull CaptureRequest request, + @NonNull TotalCaptureResult result) {} + }; + + /** + * Displays {@code text} in {@link #textView} on the UI thread; no-op if getActivity() is null. + * + * @param text The message to show + */ + private void showToast(final String text) { + final Activity activity = getActivity(); + if (activity != null) { + activity.runOnUiThread( + new Runnable() { + @Override + public void run() { + textView.setText(text); + } + }); + } + } + + /** + * Chooses the camera preview size. + * + * Attempting to use too large a preview size could exceed the camera bus' bandwidth limitation, + * resulting in gorgeous previews but the storage of garbage capture data. + * + * Given {@code choices} of {@code Size}s supported by a camera, choose the smallest one that is + * at least as large as the respective texture view size, and that is at most as large as the + * respective max size, and whose aspect ratio matches with the specified value. If such size + * doesn't exist, choose the largest one that is at most as large as the respective max size, and + * whose aspect ratio matches with the specified value. 
+ * + * @param choices The list of sizes that the camera supports for the intended output class + * @param textureViewWidth The width of the texture view relative to sensor coordinate + * @param textureViewHeight The height of the texture view relative to sensor coordinate + * @param maxWidth The maximum width that can be chosen + * @param maxHeight The maximum height that can be chosen + * @param aspectRatio The aspect ratio + * @return The optimal {@code Size}, or {@code choices[0]} if none fits the ratio and max size + */ + private static Size chooseOptimalSize( + Size[] choices, + int textureViewWidth, + int textureViewHeight, + int maxWidth, + int maxHeight, + Size aspectRatio) { + + // Collect the supported resolutions that are at least as big as the preview Surface + List<Size> bigEnough = new ArrayList<>(); + // Collect the supported resolutions that are smaller than the preview Surface + List<Size> notBigEnough = new ArrayList<>(); + int w = aspectRatio.getWidth(); + int h = aspectRatio.getHeight(); + for (Size option : choices) { + if (option.getWidth() <= maxWidth + && option.getHeight() <= maxHeight + && option.getHeight() == option.getWidth() * h / w) { + if (option.getWidth() >= textureViewWidth && option.getHeight() >= textureViewHeight) { + bigEnough.add(option); + } else { + notBigEnough.add(option); + } + } + } + + // Pick the smallest of those big enough. If there is no one big enough, pick the + // largest of those not big enough. + if (!bigEnough.isEmpty()) { + return Collections.min(bigEnough, new CompareSizesByArea()); + } else if (!notBigEnough.isEmpty()) { + return Collections.max(notBigEnough, new CompareSizesByArea()); + } else { + Log.e(TAG, "Couldn't find any suitable preview size"); + return choices[0]; + } + } + + public static Camera2BasicFragment newInstance() { + return new Camera2BasicFragment(); + } + + /** Layout the preview and buttons. 
*/ + @Override + public View onCreateView( + LayoutInflater inflater, ViewGroup container, Bundle savedInstanceState) { + return inflater.inflate(R.layout.fragment_camera2_basic, container, false); + } + + /** Look up the camera preview view and the text view that displays results. */ + @Override + public void onViewCreated(final View view, Bundle savedInstanceState) { + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); + textView = (TextView) view.findViewById(R.id.text); + } + + /** Load the model and labels; a load failure is logged and leaves {@code classifier} null. */ + @Override + public void onActivityCreated(Bundle savedInstanceState) { + super.onActivityCreated(savedInstanceState); + try { + classifier = new ImageClassifier(getActivity()); + } catch (IOException e) { + Log.e(TAG, "Failed to initialize an image classifier.", e); + } + startBackgroundThread(); + } + + @Override + public void onResume() { + super.onResume(); + startBackgroundThread(); + + // When the screen is turned off and turned back on, the SurfaceTexture is already + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open + // a camera and start preview from here (otherwise, we wait until the surface is ready in + // the SurfaceTextureListener). + if (textureView.isAvailable()) { + openCamera(textureView.getWidth(), textureView.getHeight()); + } else { + textureView.setSurfaceTextureListener(surfaceTextureListener); + } + } + + @Override + public void onPause() { + closeCamera(); + stopBackgroundThread(); + super.onPause(); + } + + @Override + public void onDestroy() { + if (classifier != null) { classifier.close(); } // null when the constructor threw IOException + super.onDestroy(); + } + + /** + * Sets up member variables related to camera. 
+ * + * @param width The width of available size for camera preview + * @param height The height of available size for camera preview + */ + private void setUpCameraOutputs(int width, int height) { + Activity activity = getActivity(); + CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + for (String cameraId : manager.getCameraIdList()) { + CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); + + // We don't use a front facing camera in this sample. + Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING); + if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) { + continue; + } + + StreamConfigurationMap map = + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); + if (map == null) { + continue; + } + + // // For still image captures, we use the largest available size. + Size largest = + Collections.max( + Arrays.asList(map.getOutputSizes(ImageFormat.JPEG)), new CompareSizesByArea()); + imageReader = + ImageReader.newInstance( + largest.getWidth(), largest.getHeight(), ImageFormat.JPEG, /*maxImages*/ 2); + + // Find out if we need to swap dimension to get the preview size relative to sensor + // coordinate. 
+ int displayRotation = activity.getWindowManager().getDefaultDisplay().getRotation(); + // noinspection ConstantConditions + /* Orientation of the camera sensor */ + int sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION); + boolean swappedDimensions = false; + switch (displayRotation) { + case Surface.ROTATION_0: + case Surface.ROTATION_180: + if (sensorOrientation == 90 || sensorOrientation == 270) { + swappedDimensions = true; + } + break; + case Surface.ROTATION_90: + case Surface.ROTATION_270: + if (sensorOrientation == 0 || sensorOrientation == 180) { + swappedDimensions = true; + } + break; + default: + Log.e(TAG, "Display rotation is invalid: " + displayRotation); + } + + Point displaySize = new Point(); + activity.getWindowManager().getDefaultDisplay().getSize(displaySize); + int rotatedPreviewWidth = width; + int rotatedPreviewHeight = height; + int maxPreviewWidth = displaySize.x; + int maxPreviewHeight = displaySize.y; + + if (swappedDimensions) { + rotatedPreviewWidth = height; + rotatedPreviewHeight = width; + maxPreviewWidth = displaySize.y; + maxPreviewHeight = displaySize.x; + } + + if (maxPreviewWidth > MAX_PREVIEW_WIDTH) { + maxPreviewWidth = MAX_PREVIEW_WIDTH; + } + + if (maxPreviewHeight > MAX_PREVIEW_HEIGHT) { + maxPreviewHeight = MAX_PREVIEW_HEIGHT; + } + + previewSize = + chooseOptimalSize( + map.getOutputSizes(SurfaceTexture.class), + rotatedPreviewWidth, + rotatedPreviewHeight, + maxPreviewWidth, + maxPreviewHeight, + largest); + + // We fit the aspect ratio of TextureView to the size of preview we picked. 
+ int orientation = getResources().getConfiguration().orientation; + if (orientation == Configuration.ORIENTATION_LANDSCAPE) { + textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight()); + } else { + textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth()); + } + + this.cameraId = cameraId; + return; + } + } catch (CameraAccessException e) { + e.printStackTrace(); + } catch (NullPointerException e) { + // Currently an NPE is thrown when the Camera2API is used but not supported on the + // device this code runs. + ErrorDialog.newInstance(getString(R.string.camera_error)) + .show(getChildFragmentManager(), FRAGMENT_DIALOG); + } + } + + private String[] getRequiredPermissions() { + Activity activity = getActivity(); + try { + PackageInfo info = + activity + .getPackageManager() + .getPackageInfo(activity.getPackageName(), PackageManager.GET_PERMISSIONS); + String[] ps = info.requestedPermissions; + if (ps != null && ps.length > 0) { + return ps; + } else { + return new String[0]; + } + } catch (Exception e) { + return new String[0]; + } + } + + /** Opens the camera specified by {@link Camera2BasicFragment#cameraId}. 
*/ + private void openCamera(int width, int height) { + if (!checkedPermissions && !allPermissionsGranted()) { + FragmentCompat.requestPermissions(this, getRequiredPermissions(), PERMISSIONS_REQUEST_CODE); + return; + } else { + checkedPermissions = true; + } + setUpCameraOutputs(width, height); + configureTransform(width, height); + Activity activity = getActivity(); + CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) { + throw new RuntimeException("Time out waiting to lock camera opening."); + } + manager.openCamera(cameraId, stateCallback, backgroundHandler); + } catch (CameraAccessException e) { + e.printStackTrace(); + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera opening.", e); + } + } + + private boolean allPermissionsGranted() { + for (String permission : getRequiredPermissions()) { + if (ContextCompat.checkSelfPermission(getActivity(), permission) + != PackageManager.PERMISSION_GRANTED) { + return false; + } + } + return true; + } + + @Override + public void onRequestPermissionsResult( + int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + } + + /** Closes the current {@link CameraDevice}. */ + private void closeCamera() { + try { + cameraOpenCloseLock.acquire(); + if (null != captureSession) { + captureSession.close(); + captureSession = null; + } + if (null != cameraDevice) { + cameraDevice.close(); + cameraDevice = null; + } + if (null != imageReader) { + imageReader.close(); + imageReader = null; + } + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera closing.", e); + } finally { + cameraOpenCloseLock.release(); + } + } + + /** Starts a background thread and its {@link Handler}. 
*/ + private void startBackgroundThread() { + backgroundThread = new HandlerThread(HANDLE_THREAD_NAME); + backgroundThread.start(); + backgroundHandler = new Handler(backgroundThread.getLooper()); + synchronized (lock) { + runClassifier = true; + } + backgroundHandler.post(periodicClassify); + } + + /** Stops the background thread and its {@link Handler}. */ + private void stopBackgroundThread() { + backgroundThread.quitSafely(); + try { + backgroundThread.join(); + backgroundThread = null; + backgroundHandler = null; + synchronized (lock) { + runClassifier = false; + } + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + /** Takes photos and classify them periodically. */ + private Runnable periodicClassify = + new Runnable() { + @Override + public void run() { + synchronized (lock) { + if (runClassifier) { + classifyFrame(); + } + } + backgroundHandler.post(periodicClassify); + } + }; + + /** Creates a new {@link CameraCaptureSession} for camera preview. */ + private void createCameraPreviewSession() { + try { + SurfaceTexture texture = textureView.getSurfaceTexture(); + assert texture != null; + + // We configure the size of default buffer to be the size of camera preview we want. + texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight()); + + // This is the output Surface we need to start preview. + Surface surface = new Surface(texture); + + // We set up a CaptureRequest.Builder with the output Surface. + previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW); + previewRequestBuilder.addTarget(surface); + + // Here, we create a CameraCaptureSession for camera preview. 
+ cameraDevice.createCaptureSession( + Arrays.asList(surface), + new CameraCaptureSession.StateCallback() { + + @Override + public void onConfigured(@NonNull CameraCaptureSession cameraCaptureSession) { + // The camera is already closed + if (null == cameraDevice) { + return; + } + + // When the session is ready, we start displaying the preview. + captureSession = cameraCaptureSession; + try { + // Auto focus should be continuous for camera preview. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AF_MODE, + CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE); + + // Finally, we start displaying the camera preview. + previewRequest = previewRequestBuilder.build(); + captureSession.setRepeatingRequest( + previewRequest, captureCallback, backgroundHandler); + } catch (CameraAccessException e) { + e.printStackTrace(); + } + } + + @Override + public void onConfigureFailed(@NonNull CameraCaptureSession cameraCaptureSession) { + showToast("Failed"); + } + }, + null); + } catch (CameraAccessException e) { + e.printStackTrace(); + } + } + + /** + * Configures the necessary {@link android.graphics.Matrix} transformation to `textureView`. This + * method should be called after the camera preview size is determined in setUpCameraOutputs and + * also the size of `textureView` is fixed. 
+ * + * @param viewWidth The width of `textureView` + * @param viewHeight The height of `textureView` + */ + private void configureTransform(int viewWidth, int viewHeight) { + Activity activity = getActivity(); + if (null == textureView || null == previewSize || null == activity) { + return; + } + int rotation = activity.getWindowManager().getDefaultDisplay().getRotation(); + Matrix matrix = new Matrix(); + RectF viewRect = new RectF(0, 0, viewWidth, viewHeight); + RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth()); + float centerX = viewRect.centerX(); + float centerY = viewRect.centerY(); + if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) { + bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY()); + matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL); + float scale = + Math.max( + (float) viewHeight / previewSize.getHeight(), + (float) viewWidth / previewSize.getWidth()); + matrix.postScale(scale, scale, centerX, centerY); + matrix.postRotate(90 * (rotation - 2), centerX, centerY); + } else if (Surface.ROTATION_180 == rotation) { + matrix.postRotate(180, centerX, centerY); + } + textureView.setTransform(matrix); + } + + /** Classifies a frame from the preview stream. */ + private void classifyFrame() { + if (classifier == null || getActivity() == null || cameraDevice == null) { + showToast("Uninitialized Classifier or invalid context."); + return; + } + Bitmap bitmap = + textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y); + String textToShow = classifier.classifyFrame(bitmap); + bitmap.recycle(); + showToast(textToShow); + } + + /** Compares two {@code Size}s based on their areas. 
*/ + private static class CompareSizesByArea implements Comparator { + + @Override + public int compare(Size lhs, Size rhs) { + // We cast here to ensure the multiplications won't overflow + return Long.signum( + (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight()); + } + } + + /** Shows an error message dialog. */ + public static class ErrorDialog extends DialogFragment { + + private static final String ARG_MESSAGE = "message"; + + public static ErrorDialog newInstance(String message) { + ErrorDialog dialog = new ErrorDialog(); + Bundle args = new Bundle(); + args.putString(ARG_MESSAGE, message); + dialog.setArguments(args); + return dialog; + } + + @Override + public Dialog onCreateDialog(Bundle savedInstanceState) { + final Activity activity = getActivity(); + return new AlertDialog.Builder(activity) + .setMessage(getArguments().getString(ARG_MESSAGE)) + .setPositiveButton( + android.R.string.ok, + new DialogInterface.OnClickListener() { + @Override + public void onClick(DialogInterface dialogInterface, int i) { + activity.finish(); + } + }) + .create(); + } + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..e7161ddb26b379f9dcf6addefa585ccf6431c055 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; +import android.os.Bundle; + +/** Main {@code Activity} class for the Camera app. */ +public class CameraActivity extends Activity { + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_camera); + if (null == savedInstanceState) { + getFragmentManager() + .beginTransaction() + .replace(R.id.container, Camera2BasicFragment.newInstance()) + .commit(); + } + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java new file mode 100644 index 0000000000000000000000000000000000000000..e7bad4637041d003c1e507d81c0c30404c587653 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; +import android.content.res.AssetFileDescriptor; +import android.graphics.Bitmap; +import android.os.SystemClock; +import android.util.Log; +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import org.tensorflow.lite.Interpreter; + +/** Classifies images with Tensorflow Lite. */ +public class ImageClassifier { + + /** Tag for the {@link Log}. */ + private static final String TAG = "TfLiteCameraDemo"; + + /** Name of the model file stored in Assets. */ + private static final String MODEL_PATH = "mobilenet_quant_v1_224.tflite"; + + /** Name of the label file stored in Assets. */ + private static final String LABEL_PATH = "labels.txt"; + + /** Number of results to show in the UI. */ + private static final int RESULTS_TO_SHOW = 3; + + /** Dimensions of inputs. */ + private static final int DIM_BATCH_SIZE = 1; + + private static final int DIM_PIXEL_SIZE = 3; + + static final int DIM_IMG_SIZE_X = 224; + static final int DIM_IMG_SIZE_Y = 224; + + /* Preallocated buffers for storing image data in. 
*/ + private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y]; + + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + private Interpreter tflite; + + /** Labels corresponding to the output of the vision model. */ + private List labelList; + + /** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */ + private ByteBuffer imgData = null; + + /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */ + private byte[][] labelProbArray = null; + + private PriorityQueue> sortedLabels = + new PriorityQueue<>( + RESULTS_TO_SHOW, + new Comparator>() { + @Override + public int compare(Map.Entry o1, Map.Entry o2) { + return (o1.getValue()).compareTo(o2.getValue()); + } + }); + + /** Initializes an {@code ImageClassifier}. */ + ImageClassifier(Activity activity) throws IOException { + tflite = new Interpreter(loadModelFile(activity)); + labelList = loadLabelList(activity); + imgData = + ByteBuffer.allocateDirect( + DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE); + imgData.order(ByteOrder.nativeOrder()); + labelProbArray = new byte[1][labelList.size()]; + Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); + } + + /** Classifies a frame from the preview stream. */ + String classifyFrame(Bitmap bitmap) { + if (tflite == null) { + Log.e(TAG, "Image classifier has not been initialized; Skipped."); + return "Uninitialized Classifier."; + } + convertBitmapToByteBuffer(bitmap); + // Here's where the magic happens!!! + long startTime = SystemClock.uptimeMillis(); + tflite.run(imgData, labelProbArray); + long endTime = SystemClock.uptimeMillis(); + Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime)); + String textToShow = printTopKLabels(); + textToShow = Long.toString(endTime - startTime) + "ms" + textToShow; + return textToShow; + } + + /** Closes tflite to release resources. 
*/ + public void close() { + tflite.close(); + tflite = null; + } + + /** Reads label list from Assets. */ + private List loadLabelList(Activity activity) throws IOException { + List labelList = new ArrayList(); + BufferedReader reader = + new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH))); + String line; + while ((line = reader.readLine()) != null) { + labelList.add(line); + } + reader.close(); + return labelList; + } + + /** Memory-map the model file in Assets. */ + private MappedByteBuffer loadModelFile(Activity activity) throws IOException { + AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } + + /** Writes Image data into a {@code ByteBuffer}. */ + private void convertBitmapToByteBuffer(Bitmap bitmap) { + if (imgData == null) { + return; + } + imgData.rewind(); + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); + // Convert the image to floating point. + int pixel = 0; + long startTime = SystemClock.uptimeMillis(); + for (int i = 0; i < DIM_IMG_SIZE_X; ++i) { + for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) { + final int val = intValues[pixel++]; + imgData.put((byte) ((val >> 16) & 0xFF)); + imgData.put((byte) ((val >> 8) & 0xFF)); + imgData.put((byte) (val & 0xFF)); + } + } + long endTime = SystemClock.uptimeMillis(); + Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime)); + } + + /** Prints top-K labels, to be shown in UI as the results. 
*/ + private String printTopKLabels() { + for (int i = 0; i < labelList.size(); ++i) { + sortedLabels.add( + new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f)); + if (sortedLabels.size() > RESULTS_TO_SHOW) { + sortedLabels.poll(); + } + } + String textToShow = ""; + final int size = sortedLabels.size(); + for (int i = 0; i < size; ++i) { + Map.Entry label = sortedLabels.poll(); + textToShow = "\n" + label.getKey() + ":" + Float.toString(label.getValue()) + textToShow; + } + return textToShow; + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..e0a70008b10b98162b4710385e21ac65333f1231 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..c22509d8dfccae14d9470e3042a9ed5b469ca2c9 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png new file mode 100644 index 0000000000000000000000000000000000000000..a84e3ef52c6dce90ccfa98f64db25fad7a8f0289 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png new file mode 100644 index 
0000000000000000000000000000000000000000..520c2dd100b092fad5987dc1b41575e1681b459c Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..d68af39186ca9cd2bc755cad8397467a11844a1d Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..1347b091983ebd9d3d58e29194b9335b6c138a2b Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..15e419b7ccd88651bd21dac36853a827fc4075b8 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..fd933333b71590608d91201aad29553f9b365b6a Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png 
b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..342ce34e1663960d8d7050a9be57face3571d336 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml new file mode 100644 index 0000000000000000000000000000000000000000..a84f1bbfa0cb48a3fc335c9bc4aa7d8e93d20e75 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml @@ -0,0 +1,50 @@ + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml new file mode 100644 index 0000000000000000000000000000000000000000..286e549c6569cef4b7a9e46f9c73e6f43b6d3045 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml @@ -0,0 +1,22 @@ + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml new file mode 100644 index 0000000000000000000000000000000000000000..15305c436e0d997af15a326ab4027ea713ed8098 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml @@ -0,0 +1,45 @@ + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml new file mode 100644 index 0000000000000000000000000000000000000000..22074a2bdbaf60efff64d98a0788ef797a966f80 --- /dev/null +++ 
b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml @@ -0,0 +1,24 @@ + + + + + + + @dimen/margin_huge + @dimen/margin_medium + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml new file mode 100644 index 0000000000000000000000000000000000000000..03d1974183dd645178c07d247d61b83d067806be --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml @@ -0,0 +1,25 @@ + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml new file mode 100644 index 0000000000000000000000000000000000000000..8c1ea66f28907ac211f355f4220ff4582cfb31eb --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml @@ -0,0 +1,22 @@ + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml new file mode 100644 index 0000000000000000000000000000000000000000..0a71dbd0e8010f5e3a176de1f7e8321331289f7c --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml @@ -0,0 +1,30 @@ + + + + + TfLiteCameraDemo + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml new file mode 100644 index 0000000000000000000000000000000000000000..4b75d2b2bda0f95166d0442ebae19cedcad162d8 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml @@ -0,0 +1,19 @@ + + + + #cc4285f4 + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml new file 
mode 100644 index 0000000000000000000000000000000000000000..a08ec3eb629250a727cec49a822375fe5569f455 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml @@ -0,0 +1,24 @@ + + + Picture + Info + This sample needs camera permission. + This device doesn\'t support Camera2 API. + NN:On + NN:Off + Use NNAPI + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml new file mode 100644 index 0000000000000000000000000000000000000000..3f3bdfb49480e779c108cd15da854ae82a118d52 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml @@ -0,0 +1,18 @@ + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/build.gradle b/tensorflow/contrib/lite/java/demo/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..b78a0b86c939620b6f05483ce45c4d3ef0ef595e --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/build.gradle @@ -0,0 +1,23 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. + +buildscript { + repositories { + jcenter() + } + dependencies { + classpath 'com.android.tools.build:gradle:2.3.1' + + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + jcenter() + } +} + +task clean(type: Delete) { + delete rootProject.buildDir +} diff --git a/tensorflow/contrib/lite/java/demo/gradle.properties b/tensorflow/contrib/lite/java/demo/gradle.properties new file mode 100644 index 0000000000000000000000000000000000000000..aac7c9b4614ccfde6c721f24994cf30885a791d0 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradle.properties @@ -0,0 +1,17 @@ +# Project-wide Gradle settings. + +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. 
+ +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html + +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx1536m + +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000000000000000000000000000000000..13372aef5e24af05341d49695ee84e5f9b594659 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar differ diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000000000000000000000000000000000..fa7a38a0e43eecd1e7292dd49efa79a5d0742e2a --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +#Thu Sep 28 09:01:41 PDT 2017 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-3.3-all.zip diff --git a/tensorflow/contrib/lite/java/demo/gradlew b/tensorflow/contrib/lite/java/demo/gradlew new file mode 100755 index 0000000000000000000000000000000000000000..9d82f78915133e1c35a6ea51252590fb38efac2f --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradlew @@ -0,0 +1,160 @@ +#!/usr/bin/env bash + +############################################################################## +## +## Gradle start up script for UN*X +## 
+############################################################################## + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS="" + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn ( ) { + echo "$*" +} + +die ( ) { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; +esac + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? 
-eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin, switch paths to Windows format before running java +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=$((i+1)) + done + case $i in + (0) set -- ;; + (1) set -- "$args0" ;; + (2) set -- "$args0" "$args1" ;; + (3) set -- "$args0" "$args1" "$args2" ;; + (4) set -- "$args0" "$args1" "$args2" "$args3" ;; + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + (8) set -- "$args0" "$args1" "$args2" "$args3" 
"$args4" "$args5" "$args6" "$args7" ;; + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules +function splitJvmOpts() { + JVM_OPTS=("$@") +} +eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS +JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" + +exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" diff --git a/tensorflow/contrib/lite/java/demo/gradlew.bat b/tensorflow/contrib/lite/java/demo/gradlew.bat new file mode 100644 index 0000000000000000000000000000000000000000..8a0b282aa6885fb573c106b3551f7275c5f17e8e --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradlew.bat @@ -0,0 +1,90 @@ +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS= + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. 
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windowz variants + +if not "%OS%" == "Windows_NT" goto win9xME_args +if "%@eval[2+2]" == "4" goto 4NT_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* +goto execute + +:4NT_args +@rem Get arguments from the 4NT Shell from JP Software +set CMD_LINE_ARGS=%$ + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! 
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/tensorflow/contrib/lite/java/demo/settings.gradle b/tensorflow/contrib/lite/java/demo/settings.gradle new file mode 100644 index 0000000000000000000000000000000000000000..e7b4def49cb53d9aa04228dd3edb14c9e635e003 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/settings.gradle @@ -0,0 +1 @@ +include ':app' diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java new file mode 100644 index 0000000000000000000000000000000000000000..d63c299589d2e8ce1051a52d29b533ed126bbcf7 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +/** Type of elements in a {@link TfLiteTensor}. */ +enum DataType { + /** 32-bit single precision floating point. */ + FLOAT32(1), + + /** 32-bit signed integer. */ + INT32(2), + + /** 8-bit unsigned integer. */ + UINT8(3), + + /** 64-bit signed integer. */ + INT64(4), + + /** A {@link ByteBuffer}. 
*/ + BYTEBUFFER(999); + + private final int value; + + DataType(int value) { + this.value = value; + } + + /** Corresponding value of the kTfLite* enum in the TensorFlow Lite CC API. */ + int getNumber() { + return value; + } + + /** Converts an integer to the corresponding type. */ + static DataType fromNumber(int c) { + for (DataType t : values) { + if (t.value == c) { + return t; + } + } + throw new IllegalArgumentException( + "DataType " + c + " is not recognized in Java (version " + TensorFlowLite.version() + ")"); + } + + /** Returns byte size of the type. */ + int elemByteSize() { + switch (this) { + case FLOAT32: + return 4; + case INT32: + return 4; + case UINT8: + return 1; + case INT64: + return 8; + case BYTEBUFFER: + return 1; + } + throw new IllegalArgumentException("DataType " + this + " is not supported yet"); + } + + // Cached to avoid copying it + private static final DataType[] values = values(); +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java new file mode 100644 index 0000000000000000000000000000000000000000..dd883d69d2065236ee29012b9bde99972aefbcf7 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -0,0 +1,172 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite; + +import java.io.File; +import java.nio.MappedByteBuffer; +import java.util.HashMap; +import java.util.Map; +import javax.validation.constraints.NotNull; + +/** + * Driver class to drive model inference with TensorFlow Lite. + * + *

A {@code Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations + * are executed for model inference. + * + *

For example, if a model takes only one input and returns only one output: + * + *

{@code
+ * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
+ *   interpreter.run(input, output);
+ * }
+ * }
+ * + *

If a model takes multiple inputs or outputs: + * + *

{@code
+ * Object[] inputs = {input0, input1, ...};
+ * Map map_of_indices_to_outputs = new HashMap<>();
+ * float[][][] ith_output = new float[3][2][4];
+ * map_of_indices_to_outputs.put(i, ith_output);
+ * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
+ *   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
+ * }
+ * }
+ * + *

Orders of inputs and outputs are determined when converting TensorFlow model to TensorFlowLite + * model with Toco. + * + *

WARNING:Instances of a {@code Interpreter} is not thread-safe. A {@code + * Interpreter} owns resources that must be explicitly freed by invoking {@link #close()} + */ +public final class Interpreter implements AutoCloseable { + + /** + * Initializes a {@code Interpreter} + * + * @param modelFile: a File of a pre-trained TF Lite model. + */ + public Interpreter(@NotNull File modelFile) { + if (modelFile == null) { + return; + } + wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath()); + } + + /** + * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file. + * + *

The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code + * Interpreter}. + */ + public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer) { + wrapper = new NativeInterpreterWrapper(mappedByteBuffer); + } + + /** + * Runs model inference if the model takes only one input, and provides only one output. + * + * @param input an array or multidimensional array, or a {@link ByteBuffer} of primitive types + * including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large + * input data. When {@link ByteBuffer} is used, its content should remain unchanged until + * model inference is done. + * @param output a multidimensional array of output data. + */ + public void run(@NotNull Object input, @NotNull Object output) { + Object[] inputs = {input}; + Map outputs = new HashMap<>(); + outputs.put(0, output); + runForMultipleInputsOutputs(inputs, outputs); + } + + /** + * Runs model inference if the model takes multiple inputs, or returns multiple outputs. + * + * @param inputs an array of input data. The inputs should be in the same order as inputs of the + * model. Each input can be an array or multidimensional array, or a {@link ByteBuffer} of + * primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred + * way to pass large input data. When {@link ByteBuffer} is used, its content should remain + * unchanged until model inference is done. + * @param outputs a map mapping output indices to multidimensional arrays of output data. It only + * needs to keep entries for the outputs to be used. 
+ */ + public void runForMultipleInputsOutputs( + @NotNull Object[] inputs, @NotNull Map outputs) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + Tensor[] tensors = wrapper.run(inputs); + if (outputs == null || tensors == null || outputs.size() > tensors.length) { + throw new IllegalArgumentException("Outputs do not match with model outputs."); + } + final int size = tensors.length; + for (Integer idx : outputs.keySet()) { + if (idx == null || idx < 0 || idx >= size) { + throw new IllegalArgumentException( + String.format("Invalid index of output %d (should be in range [0, %d))", idx, size)); + } + tensors[idx].copyTo(outputs.get(idx)); + } + } + + /** + * Resizes idx-th input of the native model to the given dims. + * + *

IllegalArgumentException will be thrown if it fails to resize. + */ + public void resizeInput(int idx, @NotNull int[] dims) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + wrapper.resizeInput(idx, dims); + } + + /** + * Gets index of an input given the op name of the input. + * + *

IllegalArgumentException will be thrown if the op name does not exist in the model file used + * to initialize the {@link Interpreter}. + */ + public int getInputIndex(String opName) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + return wrapper.getInputIndex(opName); + } + + /** + * Gets index of an output given the op name of the output. + * + *

IllegalArgumentException will be thrown if the op name does not exist in the model file used + * to initialize the {@link Interpreter}. + */ + public int getOutputIndex(String opName) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + return wrapper.getOutputIndex(opName); + } + + /** Release resources associated with the {@code Interpreter}. */ + @Override + public void close() { + wrapper.close(); + wrapper = null; + } + + NativeInterpreterWrapper wrapper; +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..1939a078ad8031b99620773c9b91335c4e8f7b22 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -0,0 +1,276 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import java.lang.reflect.Array; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; +import java.util.HashMap; +import java.util.Map; + +/** + * A wrapper wraps native interpreter and controls model execution. + * + *

WARNING: Resources consumed by the {@code NativeInterpreterWrapper} object must be + * explicitly freed by invoking the {@link #close()} method when the {@code + * NativeInterpreterWrapper} object is no longer needed. + */ +final class NativeInterpreterWrapper implements AutoCloseable { + + NativeInterpreterWrapper(String modelPath) { + errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); + modelHandle = createModel(modelPath, errorHandle); + interpreterHandle = createInterpreter(modelHandle); + } + + /** + * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer}. The + * MappedByteBuffer should not be modified after the construction of a {@code + * NativeInterpreterWrapper}. + */ + NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) { + modelByteBuffer = mappedByteBuffer; + errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); + modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); + interpreterHandle = createInterpreter(modelHandle); + } + + /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ + @Override + public void close() { + delete(errorHandle, modelHandle, interpreterHandle); + errorHandle = 0; + modelHandle = 0; + interpreterHandle = 0; + modelByteBuffer = null; + inputsIndexes = null; + outputsIndexes = null; + } + + /** Sets inputs, runs model inference and returns outputs. */ + Tensor[] run(Object[] inputs) { + if (inputs == null || inputs.length == 0) { + throw new IllegalArgumentException("Invalid inputs. 
Inputs should not be null or empty."); + } + int[] dataTypes = new int[inputs.length]; + Object[] sizes = new Object[inputs.length]; + int[] numsOfBytes = new int[inputs.length]; + for (int i = 0; i < inputs.length; ++i) { + DataType dataType = dataTypeOf(inputs[i]); + dataTypes[i] = dataType.getNumber(); + if (dataType == DataType.BYTEBUFFER) { + ByteBuffer buffer = (ByteBuffer) inputs[i]; + if (buffer.order() != ByteOrder.nativeOrder()) { + throw new IllegalArgumentException( + "Invalid ByteBuffer. It shoud use ByteOrder.nativeOrder()."); + } + numsOfBytes[i] = buffer.limit(); + sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]); + } else if (isNonEmptyArray(inputs[i])) { + int[] dims = shapeOf(inputs[i]); + sizes[i] = dims; + numsOfBytes[i] = dataType.elemByteSize() * numElements(dims); + } else { + throw new IllegalArgumentException( + String.format( + "%d-th element of the %d inputs is not an array or a ByteBuffer.", + i, inputs.length)); + } + } + long[] outputsHandles = + run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs); + if (outputsHandles == null || outputsHandles.length == 0) { + throw new IllegalStateException("Interpreter has no outputs."); + } + Tensor[] outputs = new Tensor[outputsHandles.length]; + for (int i = 0; i < outputsHandles.length; ++i) { + outputs[i] = Tensor.fromHandle(outputsHandles[i]); + } + return outputs; + } + + /** Resizes dimensions of a specific input. */ + void resizeInput(int idx, int[] dims) { + resizeInput(interpreterHandle, errorHandle, idx, dims); + } + + void setUseNNAPI(boolean useNNAPI) { + useNNAPI(interpreterHandle, useNNAPI); + } + + /** Gets index of an input given its name. 
*/ + int getInputIndex(String name) { + if (inputsIndexes == null) { + String[] names = getInputNames(interpreterHandle); + inputsIndexes = new HashMap<>(); + if (names != null) { + for (int i = 0; i < names.length; ++i) { + inputsIndexes.put(names[i], i); + } + } + } + if (inputsIndexes.containsKey(name)) { + return inputsIndexes.get(name); + } else { + throw new IllegalArgumentException( + String.format( + "%s is not a valid name for any input. The indexes of the inputs are %s", + name, inputsIndexes.toString())); + } + } + + /** Gets index of an output given its name. */ + int getOutputIndex(String name) { + if (outputsIndexes == null) { + String[] names = getOutputNames(interpreterHandle); + outputsIndexes = new HashMap<>(); + if (names != null) { + for (int i = 0; i < names.length; ++i) { + outputsIndexes.put(names[i], i); + } + } + } + if (outputsIndexes.containsKey(name)) { + return outputsIndexes.get(name); + } else { + throw new IllegalArgumentException( + String.format( + "%s is not a valid name for any output. The indexes of the outputs are %s", + name, outputsIndexes.toString())); + } + } + + static int numElements(int[] shape) { + if (shape == null) { + return 0; + } + int n = 1; + for (int i = 0; i < shape.length; i++) { + n *= shape[i]; + } + return n; + } + + static boolean isNonEmptyArray(Object o) { + return (o != null && o.getClass().isArray() && Array.getLength(o) != 0); + } + + /** Returns the type of the data. 
*/ + static DataType dataTypeOf(Object o) { + if (o != null) { + Class c = o.getClass(); + while (c.isArray()) { + c = c.getComponentType(); + } + if (float.class.equals(c)) { + return DataType.FLOAT32; + } else if (int.class.equals(c)) { + return DataType.INT32; + } else if (byte.class.equals(c)) { + return DataType.UINT8; + } else if (long.class.equals(c)) { + return DataType.INT64; + } else if (ByteBuffer.class.isInstance(o)) { + return DataType.BYTEBUFFER; + } + } + throw new IllegalArgumentException("cannot resolve DataType of " + o.getClass().getName()); + } + + /** Returns the shape of an object as an int array. */ + static int[] shapeOf(Object o) { + int size = numDimensions(o); + int[] dimensions = new int[size]; + fillShape(o, 0, dimensions); + return dimensions; + } + + static int numDimensions(Object o) { + if (o == null || !o.getClass().isArray()) { + return 0; + } + if (Array.getLength(o) == 0) { + throw new IllegalArgumentException("array lengths cannot be 0."); + } + return 1 + numDimensions(Array.get(o, 0)); + } + + static void fillShape(Object o, int dim, int[] shape) { + if (shape == null || dim == shape.length) { + return; + } + final int len = Array.getLength(o); + if (shape[dim] == 0) { + shape[dim] = len; + } else if (shape[dim] != len) { + throw new IllegalArgumentException( + String.format("mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); + } + for (int i = 0; i < len; ++i) { + fillShape(Array.get(o, i), dim + 1, shape); + } + } + + private static final int ERROR_BUFFER_SIZE = 512; + + private long errorHandle; + + private long interpreterHandle; + + private long modelHandle; + + private int inputSize; + + private MappedByteBuffer modelByteBuffer; + + private Map inputsIndexes; + + private Map outputsIndexes; + + private static native String[] getInputNames(long interpreterHandle); + + private static native String[] getOutputNames(long interpreterHandle); + + private static native void resizeInput( + long 
interpreterHandle, long errorHandle, int inputIdx, int[] dims); + + private static native void useNNAPI(long interpreterHandle, boolean state); + + private static native long createErrorReporter(int size); + + private static native long createModel(String modelPathOrBuffer, long errorHandle); + + private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle); + + private static native long createInterpreter(long modelHandle); + + private static native long[] run( + long interpreterHandle, + long errorHandle, + Object[] sizes, + int[] dtypes, + int[] numsOfBytes, + Object[] values); + + private static native void delete(long errorHandle, long modelHandle, long interpreterHandle); + + private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes); + + static { + TensorFlowLite.init(); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java new file mode 100644 index 0000000000000000000000000000000000000000..54ace6c63ce5bd1b38be744176d0378e3cc8a1d3 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite; + +import java.util.Arrays; + +/** + * A typed multi-dimensional array used in Tensorflow Lite. + * + *

The native handle of a {@code Tensor} belongs to {@code NativeInterpreterWrapper}, thus not + * needed to be closed here. + */ +final class Tensor { + + static Tensor fromHandle(long nativeHandle) { + return new Tensor(nativeHandle); + } + + /** Reads Tensor content into an array. */ + T copyTo(T dst) { + if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) { + throw new IllegalArgumentException( + String.format( + "Cannot convert an TensorFlowLite tensor with type %s to a Java object of " + + "type %s (which is compatible with the TensorFlowLite type %s)", + dtype, dst.getClass().getName(), NativeInterpreterWrapper.dataTypeOf(dst))); + } + int[] dstShape = NativeInterpreterWrapper.shapeOf(dst); + if (!Arrays.equals(dstShape, shapeCopy)) { + throw new IllegalArgumentException( + String.format( + "Shape of output target %s does not match with the shape of the Tensor %s.", + Arrays.toString(dstShape), Arrays.toString(shapeCopy))); + } + readMultiDimensionalArray(nativeHandle, dst); + return dst; + } + + final long nativeHandle; + final DataType dtype; + final int[] shapeCopy; + + private Tensor(long nativeHandle) { + this.nativeHandle = nativeHandle; + this.dtype = DataType.fromNumber(dtype(nativeHandle)); + this.shapeCopy = shape(nativeHandle); + } + + private static native int dtype(long handle); + + private static native int[] shape(long handle); + + private static native void readMultiDimensionalArray(long handle, Object value); + + static { + TensorFlowLite.init(); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java new file mode 100644 index 0000000000000000000000000000000000000000..711638a9f995ce270cd362b93a7bcfca990430dc --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +/** Static utility methods loading the TensorFlowLite runtime. */ +public final class TensorFlowLite { + + private static final String LIBNAME = "tensorflowlite_jni"; + + private TensorFlowLite() {} + + /** Returns the version of the underlying TensorFlowLite runtime. */ + public static native String version(); + + /** + * Load the TensorFlowLite runtime C library. + */ + static boolean init() { + try { + System.loadLibrary(LIBNAME); + return true; + } catch (UnsatisfiedLinkError e) { + System.err.println("TensorFlowLite: failed to load native library: " + e.getMessage()); + return false; + } + } + + static { + init(); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..68e6a0f57810f6d9675a5d1193601e43e172ab74 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java @@ -0,0 +1,17 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** Defines classes to load and execute TensorFlowLite models. */ +package org.tensorflow.lite; diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..15806d57c8ed7a45d2db9b80e2aab8e22349ee3e --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/BUILD @@ -0,0 +1,108 @@ +# Description: +# Java Native Interface (JNI) library intended for implementing the +# TensorFlow Lite Java API using the TensorFlow Lite CC library. + +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "native_framework_only", + srcs = [ + "exception_jni.cc", + "nativeinterpreterwrapper_jni.cc", + "tensor_jni.cc", + "tensorflow_lite_jni.cc", + ] + select({ + # The Android toolchain makes "jni.h" available in the include path. + # For non-Android toolchains, generate jni.h and jni_md.h. 
+ "//tensorflow:android": [], + "//conditions:default": [ + ":jni.h", + ":jni_md.h", + ], + }), + hdrs = [ + "exception_jni.h", + "nativeinterpreterwrapper_jni.h", + "tensor_jni.h", + "tensorflow_lite_jni.h", + ], + copts = tflite_copts(), + includes = select({ + "//tensorflow:android": [], + "//conditions:default": ["."], + }), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + ], + alwayslink = 1, +) + +# Silly rules to make +# #include +# in the source headers work +# (in combination with the "includes" attribute of the tf_cuda_library rule +# above. Not needed when using the Android toolchain). +# +# Inspired from: +# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD +# but hopefully there is a simpler alternative to this. +genrule( + name = "copy_jni_h", + srcs = ["@bazel_tools//tools/jdk:jni_header"], + outs = ["jni.h"], + cmd = "cp -f $< $@", +) + +genrule( + name = "copy_jni_md_h", + srcs = select({ + "//tensorflow:darwin": ["@bazel_tools//tools/jdk:jni_md_header-darwin"], + "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"], + }), + outs = ["jni_md.h"], + cmd = "cp -f $< $@", +) + +# This includes all ops. If you want a smaller binary, you should copy and +# modify builtin_ops_jni.cc. You should then link your binary against both +# ":native_framework_only" and your own version of ":native_builtin_ops". 
+cc_library( + name = "native", + srcs = [ + "builtin_ops_jni.cc", + ], + copts = tflite_copts(), + deps = [ + ":native_framework_only", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], + alwayslink = 1, +) + +exports_files( + [ + "version_script.lds", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc b/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..cce356370fa770de3e44438f08470077fb07c04c --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/register.h" + +namespace tflite { + +// The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in +// the tflite namespace. This one instantiates a BuiltinOpResolver, with all the +// builtin ops. For smaller binary sizes users should avoid linking this in, and +// should provide a custom make CreateOpResolver() instead. 
+std::unique_ptr CreateOpResolver() { // NOLINT + return std::unique_ptr( + new tflite::ops::builtin::BuiltinOpResolver()); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc b/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..1578c9e3ddd034ad9ce17c8c3ae6c942258e2a55 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" + +const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException"; +const char kIllegalStateException[] = "java/lang/IllegalStateException"; +const char kNullPointerException[] = "java/lang/NullPointerException"; +const char kIndexOutOfBoundsException[] = "java/lang/IndexOutOfBoundsException"; +const char kUnsupportedOperationException[] = + "java/lang/UnsupportedOperationException"; + +void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) 
{ + va_list args; + va_start(args, fmt); + const size_t max_msg_len = 512; + auto* message = static_cast(malloc(max_msg_len)); + if (vsnprintf(message, max_msg_len, fmt, args) >= 0) { + env->ThrowNew(env->FindClass(clazz), message); + } else { + env->ThrowNew(env->FindClass(clazz), ""); + } + free(message); + va_end(args); +} + +BufferErrorReporter::BufferErrorReporter(JNIEnv* env, int limit) { + buffer_ = new char[limit]; + if (!buffer_) { + throwException(env, kNullPointerException, + "Malloc of BufferErrorReporter to hold %d char failed.", + limit); + return; + } + start_idx_ = 0; + end_idx_ = limit - 1; +} + +BufferErrorReporter::~BufferErrorReporter() { delete[] buffer_; } + +int BufferErrorReporter::Report(const char* format, va_list args) { + int size = 0; + if (start_idx_ < end_idx_) { + size = vsnprintf(buffer_ + start_idx_, end_idx_ - start_idx_, format, args); + } + start_idx_ += size; + return size; +} + +const char* BufferErrorReporter::CachedErrorMessage() { return buffer_; } diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.h b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..3ffff052df73c5cb21bb6522d31dc615c38f7d1f --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_ + +#include +#include "tensorflow/contrib/lite/error_reporter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +extern const char kIllegalArgumentException[]; +extern const char kIllegalStateException[]; +extern const char kNullPointerException[]; +extern const char kIndexOutOfBoundsException[]; +extern const char kUnsupportedOperationException[]; + +void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...); + +class BufferErrorReporter : public tflite::ErrorReporter { + public: + BufferErrorReporter(JNIEnv* env, int limit); + virtual ~BufferErrorReporter(); + int Report(const char* format, va_list args) override; + const char* CachedErrorMessage(); + + private: + char* buffer_; + int start_idx_ = 0; + int end_idx_ = 0; +}; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc6462eb5466e14769f94c5103984f5201b4b8dc --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -0,0 +1,446 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h" + +namespace { + +const int kByteBufferValue = 999; +const int kBufferSize = 256; + +tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to Interpreter."); + return nullptr; + } + return reinterpret_cast(handle); +} + +tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, "Invalid handle to model."); + return nullptr; + } + return reinterpret_cast(handle); +} + +BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to ErrorReporter."); + return nullptr; + } + return reinterpret_cast(handle); +} + +std::vector convertJIntArrayToVector(JNIEnv* env, jintArray inputs) { + int size = static_cast(env->GetArrayLength(inputs)); + std::vector outputs(size, 0); + jint* ptr = env->GetIntArrayElements(inputs, nullptr); + if (ptr == nullptr) { + throwException(env, kIllegalArgumentException, + "Empty dimensions of input array."); + return {}; + } + for (int i = 0; i < size; ++i) { + outputs[i] = ptr[i]; + } + env->ReleaseIntArrayElements(inputs, ptr, JNI_ABORT); + return outputs; +} + +bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; } + +TfLiteType resolveDataType(jint data_type) { + switch (data_type) { + case 1: + return kTfLiteFloat32; + case 2: + return kTfLiteInt32; + case 3: + return kTfLiteUInt8; + case 4: + return kTfLiteInt64; + default: + return kTfLiteNoType; + } +} + +void printDims(char* buffer, int max_size, int* dims, int num_dims) { + if (max_size <= 0) return; + 
buffer[0] = '?'; + int size = 1; + for (int i = 1; i < num_dims; ++i) { + if (max_size > size) { + int written_size = + snprintf(buffer + size, max_size - size, ",%d", dims[i]); + if (written_size < 0) return; + size += written_size; + } + } +} + +TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter, + const int input_size, jintArray data_types, + jintArray nums_of_bytes, jobjectArray values, + jobjectArray sizes) { + if (input_size != interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Expected num of inputs is %d but got %d", + interpreter->inputs().size(), input_size); + return kTfLiteError; + } + if (input_size != env->GetArrayLength(data_types) || + input_size != env->GetArrayLength(nums_of_bytes) || + input_size != env->GetArrayLength(values)) { + throwException(env, kIllegalArgumentException, + "Arrays in arguments should be of the same length, but got " + "%d sizes, %d data_types, %d nums_of_bytes, and %d values", + input_size, env->GetArrayLength(data_types), + env->GetArrayLength(nums_of_bytes), + env->GetArrayLength(values)); + return kTfLiteError; + } + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + TfLiteTensor* target = interpreter->tensor(input_idx); + jintArray dims = + static_cast(env->GetObjectArrayElement(sizes, i)); + int num_dims = static_cast(env->GetArrayLength(dims)); + if (target->dims->size != num_dims) { + throwException(env, kIllegalArgumentException, + "%d-th input should have %d dimensions, but found %d " + "dimensions", + i, target->dims->size, num_dims); + return kTfLiteError; + } + jint* ptr = env->GetIntArrayElements(dims, nullptr); + for (int j = 1; j < num_dims; ++j) { + if (target->dims->data[j] != ptr[j]) { + std::unique_ptr expected_dims(new char[kBufferSize]); + std::unique_ptr obtained_dims(new char[kBufferSize]); + printDims(expected_dims.get(), kBufferSize, target->dims->data, + num_dims); + printDims(obtained_dims.get(), kBufferSize, 
ptr, num_dims); + throwException(env, kIllegalArgumentException, + "%d-th input dimension should be [%s], but found [%s]", + i, expected_dims.get(), obtained_dims.get()); + env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); + return kTfLiteError; + } + } + env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); + env->DeleteLocalRef(dims); + if (env->ExceptionCheck()) return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter, + int input_size, jobjectArray sizes) { + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + jintArray dims = + static_cast(env->GetObjectArrayElement(sizes, i)); + TfLiteStatus status = interpreter->ResizeInputTensor( + input_idx, convertJIntArrayToVector(env, dims)); + if (status != kTfLiteOk) { + return status; + } + env->DeleteLocalRef(dims); + if (env->ExceptionCheck()) return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter, + int input_size, jintArray data_types, + jintArray nums_of_bytes, jobjectArray values) { + jint* data_type = env->GetIntArrayElements(data_types, nullptr); + jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr); + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + TfLiteTensor* target = interpreter->tensor(input_idx); + jobject value = env->GetObjectArrayElement(values, i); + bool is_byte_buffer = isByteBuffer(data_type[i]); + if (is_byte_buffer) { + writeByteBuffer(env, value, &(target->data.raw), + static_cast(num_bytes[i])); + } else { + TfLiteType type = resolveDataType(data_type[i]); + if (type != target->type) { + throwException(env, kIllegalArgumentException, + "DataType (%d) of input data does not match with the " + "DataType (%d) of model inputs.", + type, target->type); + return kTfLiteError; + } + writeMultiDimensionalArray(env, value, target->type, target->dims->size, + &(target->data.raw), + 
static_cast(num_bytes[i])); + } + env->DeleteLocalRef(value); + if (env->ExceptionCheck()) return kTfLiteError; + } + env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT); + env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT); + return kTfLiteOk; +} + +} // namespace + +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + jclass string_class = env->FindClass("java/lang/String"); + if (string_class == nullptr) { + throwException(env, kUnsupportedOperationException, + "Can not find java/lang/String class to get input names."); + return nullptr; + } + size_t size = interpreter->inputs().size(); + jobjectArray names = static_cast( + env->NewObjectArray(size, string_class, env->NewStringUTF(""))); + for (int i = 0; i < size; ++i) { + env->SetObjectArrayElement(names, i, + env->NewStringUTF(interpreter->GetInputName(i))); + } + return names; +} + +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + jclass string_class = env->FindClass("java/lang/String"); + if (string_class == nullptr) { + throwException(env, kUnsupportedOperationException, + "Can not find java/lang/String class to get output names."); + return nullptr; + } + size_t size = interpreter->outputs().size(); + jobjectArray names = static_cast( + env->NewObjectArray(size, string_class, env->NewStringUTF(""))); + for (int i = 0; i < size; ++i) { + env->SetObjectArrayElement( + names, i, env->NewStringUTF(interpreter->GetOutputName(i))); + } + return names; +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, + jclass 
clazz, + jlong handle, + jboolean state) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return; + interpreter->UseNNAPI(static_cast(state)); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter( + JNIEnv* env, jclass clazz, jint size) { + BufferErrorReporter* error_reporter = + new BufferErrorReporter(env, static_cast(size)); + return reinterpret_cast(error_reporter); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( + JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return 0; + const char* path = env->GetStringUTFChars(model_file, nullptr); + auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter); + if (!model) { + throwException(env, kIllegalArgumentException, + "Contents of %s does not encode a valid TensorFlowLite " + "model: %s", + path, error_reporter->CachedErrorMessage()); + env->ReleaseStringUTFChars(model_file, path); + return 0; + } + env->ReleaseStringUTFChars(model_file, path); + return reinterpret_cast(model.release()); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( + JNIEnv* env, jclass /*clazz*/, jobject model_buffer, jlong error_handle) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return 0; + const char* buf = + static_cast(env->GetDirectBufferAddress(model_buffer)); + jlong capacity = env->GetDirectBufferCapacity(model_buffer); + auto model = tflite::FlatBufferModel::BuildFromBuffer( + buf, static_cast(capacity), error_reporter); + if (!model) { + throwException(env, kIllegalArgumentException, + "MappedByteBuffer does not encode a valid TensorFlowLite " + "model: %s", + 
error_reporter->CachedErrorMessage()); + return 0; + } + return reinterpret_cast(model.release()); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( + JNIEnv* env, jclass clazz, jlong model_handle) { + tflite::FlatBufferModel* model = convertLongToModel(env, model_handle); + if (model == nullptr) return 0; + auto resolver = ::tflite::CreateOpResolver(); + std::unique_ptr interpreter; + tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter); + return reinterpret_cast(interpreter.release()); +} + +// Sets inputs, runs inference, and returns outputs as long handles. +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, + jobjectArray values) { + tflite::Interpreter* interpreter = + convertLongToInterpreter(env, interpreter_handle); + if (interpreter == nullptr) return nullptr; + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return nullptr; + const int input_size = env->GetArrayLength(sizes); + // validates inputs + TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types, + nums_of_bytes, values, sizes); + if (status != kTfLiteOk) return nullptr; + // resizes inputs + status = resizeInputs(env, interpreter, input_size, sizes); + if (status != kTfLiteOk) { + throwException(env, kNullPointerException, "Can not resize the input: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // allocates memory + status = interpreter->AllocateTensors(); + if (status != kTfLiteOk) { + throwException(env, kNullPointerException, + "Can not allocate memory for the given inputs: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // sets inputs + status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes, + 
values); + if (status != kTfLiteOk) return nullptr; + // runs inference + if (interpreter->Invoke() != kTfLiteOk) { + throwException(env, kIllegalArgumentException, + "Failed to run on the given Interpreter: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // returns outputs + const std::vector& results = interpreter->outputs(); + if (results.empty()) { + throwException(env, kIllegalArgumentException, + "The Interpreter does not have any outputs."); + return nullptr; + } + jlongArray outputs = env->NewLongArray(results.size()); + size_t size = results.size(); + for (int i = 0; i < size; ++i) { + TfLiteTensor* source = interpreter->tensor(results[i]); + jlong output = reinterpret_cast(source); + env->SetLongArrayRegion(outputs, i, 1, &output); + } + return outputs; +} + +JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( + JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + const int idx = static_cast(input_idx); + if (input_idx >= interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Out of range: Failed to get %d-th input out of %d inputs", + input_idx, interpreter->inputs().size()); + return nullptr; + } + TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]); + int size = target->dims->size; + int expected_num_bytes = elementByteSize(target->type); + for (int i = 0; i < size; ++i) { + expected_num_bytes *= target->dims->data[i]; + } + if (num_bytes != expected_num_bytes) { + throwException(env, kIllegalArgumentException, + "Failed to get input dimensions. 
%d-th input should have" + " %d bytes, but found %d bytes.", + idx, expected_num_bytes, num_bytes); + return nullptr; + } + jintArray outputs = env->NewIntArray(size); + env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0])); + return outputs; +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jint input_idx, jintArray dims) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return; + tflite::Interpreter* interpreter = + convertLongToInterpreter(env, interpreter_handle); + if (interpreter == nullptr) return; + const int idx = static_cast(input_idx); + if (idx < 0 || idx >= interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Can not resize %d-th input for a model having %d inputs.", + idx, interpreter->inputs().size()); + } + TfLiteStatus status = interpreter->ResizeInputTensor( + interpreter->inputs()[idx], convertJIntArrayToVector(env, dims)); + if (status != kTfLiteOk) { + throwException(env, kIllegalArgumentException, + "Failed to resize %d-th input: %s", idx, + error_reporter->CachedErrorMessage()); + } +} + +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete( + JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle, + jlong interpreter_handle) { + if (interpreter_handle != 0) { + delete convertLongToInterpreter(env, interpreter_handle); + } + if (model_handle != 0) { + delete convertLongToModel(env, model_handle); + } + if (error_handle != 0) { + delete convertLongToErrorReporter(env, error_handle); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..430886b7cc04a356d1826843acc1bbebf4189bf7 --- 
/dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -0,0 +1,151 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_ + +#include +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" +#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +// This is to be provided at link-time by a library. 
+extern std::unique_ptr CreateOpResolver(); +} // namespace tflite + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (J)[Ljava/lang/Object; + */ +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (J)[Ljava/lang/Object; + */ +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JZ) + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, + jclass clazz, + jlong handle, + jboolean state); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (I)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter( + JNIEnv* env, jclass clazz, jint size); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (Ljava/lang/String;J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( + JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (Ljava/lang/Object;J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( + JNIEnv* env, jclass clazz, jobject model_buffer, jlong error_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( + JNIEnv* env, jclass clazz, jlong model_handle); + +/* + * Class: 
org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;)[J + */ +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, + jobjectArray values); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JII)[I + * + * It gets input dimensions if num_bytes matches number of bytes required by + * the input, else returns null and throws IllegalArgumentException. + */ +JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( + JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JJI[I) + * + * It resizes dimensions of a input. + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jint input_idx, jintArray dims); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JJJ) + */ +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete( + JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle, + jlong interpreter_handle); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..65126e78a3003f8a69c69326124d613e878c0f9d --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc @@ -0,0 +1,242 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h" +#include +#include +#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" + +namespace { + +TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to TfLiteTensor."); + return nullptr; + } + return reinterpret_cast(handle); +} + +size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, + void* dst, size_t dst_size) { + jarray array = static_cast(object); + const int num_elements = env->GetArrayLength(array); + size_t to_copy = num_elements * elementByteSize(type); + if (to_copy > dst_size) { + throwException(env, kIllegalStateException, + "cannot write Java array of %d bytes to Tensor of %d bytes", + to_copy, dst_size); + return 0; + } + switch (type) { + case kTfLiteFloat32: { + jfloatArray a = static_cast(array); + jfloat* values = env->GetFloatArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseFloatArrayElements(a, values, JNI_ABORT); + return to_copy; + } + case kTfLiteInt32: { + jintArray a = static_cast(array); + jint* values = env->GetIntArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseIntArrayElements(a, values, JNI_ABORT); + return to_copy; + } + case kTfLiteInt64: { + jlongArray a = static_cast(array); 
+ jlong* values = env->GetLongArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseLongArrayElements(a, values, JNI_ABORT); + return to_copy; + } + case kTfLiteUInt8: { + jbyteArray a = static_cast(array); + jbyte* values = env->GetByteArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseByteArrayElements(a, values, JNI_ABORT); + return to_copy; + } + default: { + throwException(env, kUnsupportedOperationException, + "TensorFlowLite currently supports float (32 bits), " + "int (32 bits), byte (8 bits), and long (64 bits), " + "support for other types (DataType %d in this case) will " + "be added in the future", + kTfLiteFloat32, type); + return 0; + } + } +} + +size_t readOneDimensionalArray(JNIEnv* env, TfLiteType data_type, + const void* src, size_t src_size, jarray dst) { + const int len = env->GetArrayLength(dst); + const size_t size = len * elementByteSize(data_type); + if (size > src_size) { + throwException( + env, kIllegalStateException, + "cannot fill a Java array of %d bytes with a Tensor of %d bytes", size, + src_size); + return 0; + } + switch (data_type) { + case kTfLiteFloat32: { + jfloatArray float_array = static_cast(dst); + env->SetFloatArrayRegion(float_array, 0, len, + static_cast(src)); + return size; + } + case kTfLiteInt32: { + jintArray int_array = static_cast(dst); + env->SetIntArrayRegion(int_array, 0, len, static_cast(src)); + return size; + } + case kTfLiteInt64: { + jlongArray long_array = static_cast(dst); + env->SetLongArrayRegion(long_array, 0, len, + static_cast(src)); + return size; + } + case kTfLiteUInt8: { + jbyteArray byte_array = static_cast(dst); + env->SetByteArrayRegion(byte_array, 0, len, + static_cast(src)); + return size; + } + default: { + throwException(env, kIllegalStateException, "invalid DataType(%d)", + data_type); + } + } + return 0; +} + +size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src, + size_t src_size, int dims_left, jarray dst) { + if 
(dims_left == 1) { + return readOneDimensionalArray(env, data_type, src, src_size, dst); + } else { + jobjectArray ndarray = static_cast(dst); + int len = env->GetArrayLength(ndarray); + size_t size = 0; + for (int i = 0; i < len; ++i) { + jarray row = static_cast(env->GetObjectArrayElement(ndarray, i)); + size += readMultiDimensionalArray(env, data_type, src + size, + src_size - size, dims_left - 1, row); + env->DeleteLocalRef(row); + if (env->ExceptionCheck()) return size; + } + return size; + } +} + +} // namespace + +size_t elementByteSize(TfLiteType data_type) { + // The code in this file makes the assumption that the + // TensorFlow TF_DataTypes and the Java primitive types + // have the same byte sizes. Validate that: + switch (data_type) { + case kTfLiteFloat32: + static_assert(sizeof(jfloat) == 4, + "Java float not compatible with kTfLiteFloat"); + return 4; + case kTfLiteInt32: + static_assert(sizeof(jint) == 4, + "Java int not compatible with kTfLiteInt"); + return 4; + case kTfLiteUInt8: + static_assert(sizeof(jbyte) == 1, + "Java byte not compatible with kTfLiteUInt8"); + return 1; + case kTfLiteInt64: + static_assert(sizeof(jlong) == 8, + "Java long not compatible with kTfLiteInt64"); + return 8; + default: + return 0; + } +} + +size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) { + char* buf = static_cast(env->GetDirectBufferAddress(object)); + if (!buf) { + throwException(env, kIllegalArgumentException, + "Input ByteBuffer is not a direct buffer"); + return 0; + } + *dst = buf; + return dst_size; +} + +size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, + int dims_left, char** dst, int dst_size) { + if (dims_left <= 1) { + return writeOneDimensionalArray(env, src, type, *dst, dst_size); + } else { + jobjectArray ndarray = static_cast(src); + int len = env->GetArrayLength(ndarray); + size_t sz = 0; + for (int i = 0; i < len; ++i) { + jobject row = env->GetObjectArrayElement(ndarray, i); + char* 
next_dst = *dst + sz; + sz += writeMultiDimensionalArray(env, row, type, dims_left - 1, &next_dst, + dst_size - sz); + env->DeleteLocalRef(row); + if (env->ExceptionCheck()) return sz; + } + return sz; + } +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject value) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return; + int num_dims = tensor->dims->size; + if (num_dims == 0) { + throwException(env, kIllegalArgumentException, + "copyTo() is not meant for scalar Tensors."); + return; + } + readMultiDimensionalArray(env, tensor->type, tensor->data.raw, tensor->bytes, + num_dims, static_cast(value)); +} + +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, + jclass clazz, + jlong handle) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return 0; + return static_cast(tensor->type); +} + +JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return nullptr; + int num_dims = tensor->dims->size; + jintArray result = env->NewIntArray(num_dims); + jint* dims = env->GetIntArrayElements(result, nullptr); + for (int i = 0; i < num_dims; ++i) { + dims[i] = static_cast(tensor->dims->data[i]); + } + env->ReleaseIntArrayElements(result, dims, 0); + return result; +} diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..3a4910dcc3a719fbb9f365dae693423de768349c --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_ + +#include +#include "tensorflow/contrib/lite/context.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/* + * Class: org_tensorflow_lite_TfLiteTensor + * Method: + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_TfLiteTensor + * Method: + * Signature: (J)[I + */ +JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_TfLiteTensor + * Method: + * Signature: (JLjava/lang/Object;) + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject value); + +/* + * Finds the size of each data type. + */ +size_t elementByteSize(TfLiteType data_type); + +/* + * Writes data of a ByteBuffer into dest. + */ +size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size); + +/* + * Writes a multi-dimensional array into dest. 
+ */ +size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, + int dims_left, char** dst, int dst_size); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..2e7f2f56921b871a6ace2b6cb984fcd185a4d2ab --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc @@ -0,0 +1,26 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h" +#include "tensorflow/contrib/lite/version.h" + +JNIEXPORT jstring JNICALL +Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv* env, jclass /*clazz*/) { + char buf[64]; + snprintf(buf, sizeof(buf), "%d", TFLITE_SCHEMA_VERSION); + return env->NewStringUTF(buf); +} diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..65f8341149287f151f7e51fe04d9525bf119164e --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/* + * Class: org_tensorflow_lite_TensorFlowLite + * Method: version + * Signature: ()Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL +Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv*, jclass); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/version_script.lds b/tensorflow/contrib/lite/java/src/main/native/version_script.lds new file mode 100644 index 0000000000000000000000000000000000000000..38c93dda730550070f28b59297c5191a9615ed7b --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/version_script.lds @@ -0,0 +1,11 @@ +VERS_1.0 { + # Export JNI symbols. + global: + Java_*; + JNI_OnLoad; + JNI_OnUnload; + + # Hide everything else. + local: + *; +}; diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java new file mode 100644 index 0000000000000000000000000000000000000000..cebc9442008e10e7674cf7b1dc58e633fef4ba39 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.DataType}. */ +@RunWith(JUnit4.class) +public final class DataTypeTest { + + @Test + public void testElemByteSize() { + assertThat(DataType.FLOAT32.elemByteSize()).isEqualTo(4); + assertThat(DataType.INT32.elemByteSize()).isEqualTo(4); + assertThat(DataType.UINT8.elemByteSize()).isEqualTo(1); + assertThat(DataType.INT64.elemByteSize()).isEqualTo(8); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java new file mode 100644 index 0000000000000000000000000000000000000000..424b3de6c97672e310c54230a7ac1204f46d9ac8 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -0,0 +1,221 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import java.io.File; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.Interpreter}. */ +@RunWith(JUnit4.class) +public final class InterpreterTest { + + private static final File MODEL_FILE = + new File("tensorflow/contrib/lite/java/src/testdata/add.bin"); + + private static final File MOBILENET_MODEL_FILE = + new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin"); + + @Test + public void testInterpreter() throws Exception { + Interpreter interpreter = new Interpreter(MODEL_FILE); + assertThat(interpreter).isNotNull(); + interpreter.close(); + } + + @Test + public void testRunWithMappedByteBufferModel() throws Exception { + Path path = MODEL_FILE.toPath(); + FileChannel fileChannel = + (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); + MappedByteBuffer mappedByteBuffer = + fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size()); + Interpreter interpreter = new Interpreter(mappedByteBuffer); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + 
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + fileChannel.close(); + } + + @Test + public void testRun() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + Float[] oneD = {1.23f, 6.54f, 7.81f}; + Float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + Float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + Float[][][][] fourD = {threeD, threeD}; + Float[][][][] parsedOutputs = new Float[2][8][8][3]; + try { + interpreter.run(fourD, parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of [[[[Ljava.lang.Float;"); + } + interpreter.close(); + } + + @Test + public void testRunWithBoxedInputs() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + } + + @Test + public void testRunForMultipleInputsOutputs() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + interpreter.runForMultipleInputsOutputs(inputs, outputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 
23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + } + + @Test + public void testMobilenetRun() { + // Create a gray image. + float[][][][] img = new float[1][224][224][3]; + for (int i = 0; i < 224; ++i) { + for (int j = 0; j < 224; ++j) { + img[0][i][j][0] = 0.5f; + img[0][i][j][1] = 0.5f; + img[0][i][j][2] = 0.5f; + } + } + + // Allocate memory to receive the output values. + float[][] labels = new float[1][1001]; + + Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + interpreter.run(img, labels); + interpreter.close(); + + assertThat(labels[0]) + .usingExactEquality() + .containsNoneOf(new float[] {Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY}); + } + + @Test + public void testRunWithWrongInputType() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + int[] oneD = {4, 3, 9}; + int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + int[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + try { + interpreter.run(fourD, parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "DataType (2) of input data does not match with the DataType (1) of model inputs."); + } + interpreter.close(); + } + + @Test + public void testRunWithWrongOutputType() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + int[][][][] parsedOutputs = new int[2][8][8][3]; + try { + interpreter.run(fourD, parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Cannot convert an TensorFlowLite tensor with type " + + "FLOAT32 
to a Java object of type [[[[I (which is compatible with the" + + " TensorFlowLite type INT32)"); + } + interpreter.close(); + } + + @Test + public void testGetInputIndex() { + Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + try { + interpreter.getInputIndex("WrongInputName"); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "WrongInputName is not a valid name for any input. The indexes of the inputs" + + " are {input=0}"); + } + int index = interpreter.getInputIndex("input"); + assertThat(index).isEqualTo(0); + } + + @Test + public void testGetOutputIndex() { + Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + try { + interpreter.getOutputIndex("WrongOutputName"); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "WrongOutputName is not a valid name for any output. The indexes of the outputs" + + " are {MobilenetV1/Predictions/Softmax=0}"); + } + int index = interpreter.getOutputIndex("MobilenetV1/Predictions/Softmax"); + assertThat(index).isEqualTo(0); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java new file mode 100644 index 0000000000000000000000000000000000000000..9a6894f49c0b7278511717d2671648c6d1763e00 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -0,0 +1,406 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.NativeInterpreterWrapper}. */ +@RunWith(JUnit4.class) +public final class NativeInterpreterWrapperTest { + + private static final String FLOAT_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/add.bin"; + + private static final String INT_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/int32.bin"; + + private static final String LONG_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/int64.bin"; + + private static final String BYTE_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/uint8.bin"; + + private static final String INVALID_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/invalid_model.bin"; + + @Test + public void testConstructor() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + assertThat(wrapper).isNotNull(); + wrapper.close(); + } + + @Test + public void testConstructorWithInvalidModel() { + try { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("Model provided has model identifier ' is ', should be 'TFL3'"); + } + } + + @Test + public void 
testRunWithFloat() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, -6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + float[][][][] parsedOutputs = new float[2][8][8][3]; + outputs[0].copyTo(parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, -19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + wrapper.close(); + } + + @Test + public void testRunWithInt() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH); + int[] oneD = {3, 7, -4}; + int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + int[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + int[][][][] parsedOutputs = new int[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + int[] outputOneD = parsedOutputs[0][0][0]; + int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4}; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithLong() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH); + long[] oneD = {-892834092L, 923423L, 2123918239018L}; + long[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + long[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + long[][][][] parsedOutputs = new long[2][4][4][12]; + 
outputs[0].copyTo(parsedOutputs); + long[] outputOneD = parsedOutputs[0][0][0]; + long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L, + -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L}; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithByte() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); + byte[] oneD = {(byte) 0xe0, 0x4f, (byte) 0xd0}; + byte[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + byte[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + byte[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + int[] inputDims = {2, 8, 8, 3}; + wrapper.resizeInput(0, inputDims); + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + byte[][][][] parsedOutputs = new byte[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + byte[] outputOneD = parsedOutputs[0][0][0]; + byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0}; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithByteBufferHavingBytes() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); + ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 8 * 8 * 3); + bbuf.order(ByteOrder.nativeOrder()); + bbuf.rewind(); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + bbuf.put((byte) 0xe0); + bbuf.put((byte) 0x4f); + bbuf.put((byte) 0xd0); + } + } + } + Object[] inputs = {bbuf}; + int[] inputDims = {2, 8, 8, 3}; + wrapper.resizeInput(0, inputDims); + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + byte[][][][] parsedOutputs = new byte[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + byte[] outputOneD = parsedOutputs[0][0][0]; + 
byte[] expected = { + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0 + }; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithByteBufferHavingFloats() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + ByteBuffer bbuf = ByteBuffer.allocateDirect(4 * 8 * 8 * 3 * 4); + bbuf.order(ByteOrder.nativeOrder()); + bbuf.rewind(); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + bbuf.putFloat(1.23f); + bbuf.putFloat(-6.54f); + bbuf.putFloat(7.81f); + } + } + } + Object[] inputs = {bbuf}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes"); + } + int[] inputDims = {4, 8, 8, 3}; + wrapper.resizeInput(0, inputDims); + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + float[][][][] parsedOutputs = new float[4][8][8][3]; + outputs[0].copyTo(parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, -19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + wrapper.close(); + } + + @Test + public void testRunWithByteBufferHavingWrongSize() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); + ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3); + bbuf.order(ByteOrder.nativeOrder()); + Object[] inputs = {bbuf}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Failed to get input dimensions. 
0-th input should have 192 bytes, but found 336 bytes."); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputType() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + int[] oneD = {4, 3, 9}; + int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + int[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "DataType (2) of input data does not match with the DataType (1) of model inputs."); + } + wrapper.close(); + } + + @Test + public void testRunAfterClose() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + wrapper.close(); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter."); + } + } + + @Test + public void testRunWithEmptyInputs() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + try { + Object[] inputs = {}; + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("Invalid inputs. 
Inputs should not be null or empty."); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputSize() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD, fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2"); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputNumOfDims() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + Object[] inputs = {threeD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("0-th input should have 4 dimensions, but found 3 dimensions"); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputDims() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]"); + } + wrapper.close(); + } + + @Test + public void testNumElements() { + int[] shape = {2, 3, 4}; + int num = NativeInterpreterWrapper.numElements(shape); + assertThat(num).isEqualTo(24); 
+ shape = null; + num = NativeInterpreterWrapper.numElements(shape); + assertThat(num).isEqualTo(0); + } + + @Test + public void testIsNonEmtpyArray() { + assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse(); + assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse(); + int[] emptyArray = {}; + assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse(); + int[] validArray = {9, 5, 2, 1}; + assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue(); + } + + @Test + public void testDataTypeOf() { + float[] testEmtpyArray = {}; + DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[] testFloatArray = {0.783f, 0.251f}; + dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; + dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + try { + double[] testDoubleArray = {0.783, 0.251}; + NativeInterpreterWrapper.dataTypeOf(testDoubleArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); + } + try { + Float[] testBoxedArray = {0.783f, 0.251f}; + NativeInterpreterWrapper.dataTypeOf(testBoxedArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); + } + } + + @Test + public void testNumDimensions() { + int scalar = 1; + assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0); + int[][] array = {{2, 4}, {1, 9}}; + assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2); + try { + int[] emptyArray = {}; + NativeInterpreterWrapper.numDimensions(emptyArray); + fail(); + } catch (IllegalArgumentException e) { + 
assertThat(e).hasMessageThat().contains("array lengths cannot be 0."); + } + } + + @Test + public void testFillShape() { + int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}}; + int num = NativeInterpreterWrapper.numDimensions(array); + int[] shape = new int[num]; + NativeInterpreterWrapper.fillShape(array, 0, shape); + assertThat(num).isEqualTo(3); + assertThat(shape[0]).isEqualTo(2); + assertThat(shape[1]).isEqualTo(3); + assertThat(shape[2]).isEqualTo(1); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java new file mode 100644 index 0000000000000000000000000000000000000000..665c937cb60ad957c0030c01eb57899754c80bf8 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.TensorFlowLite}. 
*/ +@RunWith(JUnit4.class) +public final class TensorFlowLiteTest { + + @Test + public void testVersion() { + assertThat(TensorFlowLite.version()).isEqualTo("3"); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java new file mode 100644 index 0000000000000000000000000000000000000000..94b6632bb8dd7117bf4074da1939bd23ce732efd --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.Tensor}. 
*/ +@RunWith(JUnit4.class) +public final class TensorTest { + + private static final String MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/add.bin"; + + private NativeInterpreterWrapper wrapper; + private long nativeHandle; + + @Before + public void setUp() { + wrapper = new NativeInterpreterWrapper(MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + nativeHandle = outputs[0].nativeHandle; + } + + @After + public void tearDown() { + wrapper.close(); + } + + @Test + public void testFromHandle() throws Exception { + Tensor tensor = Tensor.fromHandle(nativeHandle); + assertThat(tensor).isNotNull(); + int[] expectedShape = {2, 8, 8, 3}; + assertThat(tensor.shapeCopy).isEqualTo(expectedShape); + assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32); + } + + @Test + public void testCopyTo() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + float[][][][] parsedOutputs = new float[2][8][8][3]; + tensor.copyTo(parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + + @Test + public void testCopyToWrongType() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + int[][][][] parsedOutputs = new int[2][8][8][3]; + try { + tensor.copyTo(parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Cannot convert an TensorFlowLite tensor with type " + + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite " + + "type INT32)"); + } + } + + @Test + public void testCopyToWrongShape() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + float[][][][] parsedOutputs = new 
float[1][8][8][3]; + try { + tensor.copyTo(parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Shape of output target [1, 8, 8, 3] does not match " + + "with the shape of the Tensor [2, 8, 8, 3]."); + } + } +} diff --git a/tensorflow/contrib/lite/java/src/testdata/add.bin b/tensorflow/contrib/lite/java/src/testdata/add.bin new file mode 100644 index 0000000000000000000000000000000000000000..aef0fe3d82c9d92dc444076d3b46e05af1923f46 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/add.bin differ diff --git a/tensorflow/contrib/lite/java/src/testdata/float32.bin b/tensorflow/contrib/lite/java/src/testdata/float32.bin new file mode 100644 index 0000000000000000000000000000000000000000..30b1264ca152740e1607651ce6cbc2a548319bc3 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/float32.bin differ diff --git a/tensorflow/contrib/lite/java/src/testdata/int32.bin b/tensorflow/contrib/lite/java/src/testdata/int32.bin new file mode 100644 index 0000000000000000000000000000000000000000..f6f3cf607a249e096921b12d848c4055a37d1168 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/int32.bin differ diff --git a/tensorflow/contrib/lite/java/src/testdata/int64.bin b/tensorflow/contrib/lite/java/src/testdata/int64.bin new file mode 100644 index 0000000000000000000000000000000000000000..c12aa41ca7be49b30db291a25156bd20cbab21a9 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/int64.bin differ diff --git a/tensorflow/contrib/lite/java/src/testdata/invalid_model.bin b/tensorflow/contrib/lite/java/src/testdata/invalid_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..8156ac741cbc0aa32e6d867ad09b5e6be8451868 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/testdata/invalid_model.bin @@ -0,0 +1 @@ +This is an invalid model. 
\ No newline at end of file diff --git a/tensorflow/contrib/lite/java/src/testdata/uint8.bin b/tensorflow/contrib/lite/java/src/testdata/uint8.bin new file mode 100644 index 0000000000000000000000000000000000000000..f06c5cf58462ce56b012d163fb208329874f83ad Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/uint8.bin differ diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..2b4f37bc6cfe1dbc0c178a56b892f545e8ad4f3b --- /dev/null +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD @@ -0,0 +1,30 @@ +# Description: +# Internal helper function to test TF Lite API. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +android_library( + name = "testhelper", + srcs = glob( + [ + "*.java", + ], + ), + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite_java", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java new file mode 100644 index 0000000000000000000000000000000000000000..8660cabf709e6531a5667a16e5cf43a93c7135bd --- /dev/null +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +/** A helper class for internal tests. */ +public class TestHelper { + + /** + * Turns on/off NNAPI of an {@code Interpreter}. + * + * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code + * IllegalArgumentException} will be thrown. + * @param useNNAPI a boolean value indicating to turn on or off NNAPI. + */ + public static void setUseNNAPI(Interpreter interpreter, boolean useNNAPI) { + if (interpreter != null && interpreter.wrapper != null) { + interpreter.wrapper.setUseNNAPI(useNNAPI); + } else { + throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI."); + } + } +} diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..83eb7f2cb84d5cc1c2af1f388502c056425476b8 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -0,0 +1,422 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +tf_cc_test( + name = "optional_tensor_test", + size = "small", + srcs = ["optional_tensor_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "test_util", + testonly = 1, + srcs = 
["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/testing:util", + "//tensorflow/core:lib", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "gemm_support", + srcs = [ + "gemm_support.cc", + ], + hdrs = [ + "gemm_support.h", + ], + copts = tflite_copts(), + deps = [ + ":op_macros", + "//tensorflow/contrib/lite:context", + "@gemmlowp//:gemmlowp", + ], +) + +cc_library( + name = "activation_functor", + hdrs = [ + "activation_functor.h", + ], + deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + ], +) + +cc_library( + name = "op_macros", + hdrs = [ + "op_macros.h", + ], +) + +cc_library( + name = "builtin_ops", + srcs = [ + "activations.cc", + "add.cc", + "basic_rnn.cc", + "concatenation.cc", + "conv.cc", + "depthwise_conv.cc", + "embedding_lookup.cc", + "embedding_lookup_sparse.cc", + "fully_connected.cc", + "hashtable_lookup.cc", + "kernel_util.cc", + "l2norm.cc", + "local_response_norm.cc", + "lsh_projection.cc", + "lstm.cc", + "mul.cc", + "pad.cc", + "pooling.cc", + "register.cc", + "reshape.cc", + "resize_bilinear.cc", + "skip_gram.cc", + "space_to_depth.cc", + "svdf.cc", + ], + hdrs = [ + "kernel_util.h", + "padding.h", + "register.h", + ], + # Suppress warnings that are introduced by Eigen Tensor. 
+ copts = tflite_copts() + [ + "-Wno-error=reorder", + ] + select({ + "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"], + "//conditions:default": [ + ], + }), + deps = [ + ":activation_functor", + ":op_macros", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/kernels/internal:optimized", + "//tensorflow/contrib/lite/kernels/internal:optimized_base", + "//tensorflow/contrib/lite/kernels/internal:quantization_util", + "//tensorflow/contrib/lite/kernels/internal:reference", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:round", + "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "@farmhash_archive//:farmhash", + ], +) + +tf_cc_test( + name = "activations_test", + size = "small", + srcs = ["activations_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "add_test", + size = "small", + srcs = ["add_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "concatenation_test", + size = "small", + srcs = ["concatenation_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "conv_test", + size = "small", + srcs = ["conv_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "depthwise_conv_test", + size = "small", + srcs = ["depthwise_conv_test.cc"], + deps = [ 
+ ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "basic_rnn_test", + size = "small", + srcs = ["basic_rnn_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "l2norm_test", + size = "small", + srcs = ["l2norm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "mul_test", + size = "small", + srcs = ["mul_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "pad_test", + size = "small", + srcs = ["pad_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "reshape_test", + size = "small", + srcs = ["reshape_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "resize_bilinear_test", + size = "small", + srcs = ["resize_bilinear_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "svdf_test", + size = "small", + srcs = ["svdf_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "embedding_lookup_test", + size = "small", + srcs = 
["embedding_lookup_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "embedding_lookup_sparse_test", + size = "small", + srcs = ["embedding_lookup_sparse_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "fully_connected_test", + size = "small", + srcs = ["fully_connected_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "local_response_norm_test", + size = "small", + srcs = ["local_response_norm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "pooling_test", + size = "small", + srcs = ["pooling_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "softmax_test", + size = "small", + srcs = ["softmax_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "lsh_projection_test", + size = "small", + srcs = ["lsh_projection_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "hashtable_lookup_test", + size = "small", + srcs = ["hashtable_lookup_test.cc"], + deps = [ + ":builtin_ops", + 
"//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "lstm_test", + size = "small", + srcs = ["lstm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "skip_gram_test", + size = "small", + srcs = ["skip_gram_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "space_to_depth_test", + size = "small", + srcs = ["space_to_depth_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..cfb3369e991a474315424423fe655ba214edabbc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/activation_functor.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ + +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" + +namespace tflite { + +// Dynamic (non-fused) activation functor. perhaps it is worth having +// template instantiation? +// TODO(aselle): Make this more efficient by pulling the switch to conv_eval +// using template inlining. +class ActivationFunctor { + public: + explicit ActivationFunctor(TfLiteFusedActivation act) : act_(act) {} + + float operator()(float a) const { + switch (act_) { + case kTfLiteActNone: + return a; + case kTfLiteActRelu: + return a < 0.f ? 0.f : a; + case kTfLiteActRelu6: + return std::max(0.f, std::min(a, 6.f)); + case kTfLiteActTanh: + return std::tanh(a); + case kTfLiteActSigmoid: + return 1.0f / (1.0f + std::exp(-a)); + default: + // TODO(aselle): More informative fatal error! + exit(1); + } + } + + private: + TfLiteFusedActivation act_; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ab60a33e5e2ff61bae5f4c6db85ab9c47a391bc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -0,0 +1,389 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace activations { + +struct OpData { + int32_t input_multiplier = 0; + int input_left_shift = 0; + int32_t input_range_radius = 0; + int diff_min = 0; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). 
+ return new OpData; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + TF_LITE_ENSURE(context, output->params.scale == 1. 
/ 256); + + static constexpr int kInputIntegerBits = 4; + + const double input_real_multiplier = + input->params.scale * + static_cast(1 << (31 - kInputIntegerBits)); + + QuantizeMultiplierGreaterThanOne(input_real_multiplier, + &data->input_multiplier, + &data->input_left_shift); + data->input_range_radius = + CalculateInputRadius(kInputIntegerBits, data->input_left_shift); + } + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + TF_LITE_ENSURE(context, + NumDimensions(input) == 2 || NumDimensions(input) == 4); + + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + TF_LITE_ENSURE(context, output->params.scale == 1. 
/ 256); + + static const int kScaledDiffIntegerBits = 5; + + tflite::PreprocessSoftmaxScaling( + params->beta, input->params.scale, kScaledDiffIntegerBits, + &data->input_multiplier, &data->input_left_shift); + data->diff_min = -1.0 * tflite::CalculateInputRadius( + kScaledDiffIntegerBits, data->input_left_shift); + } + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = std::max(0.f, *in); + return kTfLiteOk; + } + break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) { + *out = std::min(std::max(-1.f, *in), 1.f); + } + return kTfLiteOk; + } break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = 
std::min(std::max(0.f, *in), 6.f); + return kTfLiteOk; + } + break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = std::tanh(*in); + return kTfLiteOk; + } + break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +// Sigmoid is also known as "Logistic". +TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = 1.f / (1.f + std::exp(-*in)); + break; + } + case kTfLiteUInt8: { + optimized_ops::Logistic( + GetTensorData(input), GetTensorDims(input), + input->params.zero_point, data->input_range_radius, + data->input_multiplier, data->input_left_shift, + GetTensorData(output), GetTensorDims(output)); + break; + } + default: + context->ReportError(context, "Only float32 and uint8_t supported currently."); + return kTfLiteError; + } + return kTfLiteOk; +} + +// Takes a 2D tensor and performs softmax along the second dimension.
+void Softmax2DFloat(TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + float* in = input->data.f; + float* out = output->data.f; + TF_LITE_ASSERT(input_size > 0); + + // For each batch + for (int b = 0; b < batch_size; b++) { + // Find the max coeff. + float max_coeff = in[0]; + for (int i = 1; i < input_size; i++) { + if (in[i] > max_coeff) max_coeff = in[i]; + } + + // Compute the normalized sum of exps. + float exp_sum = 0.0; + for (int i = 0; i < input_size; i++) { + out[i] = std::exp((in[i] - max_coeff) * params->beta); + exp_sum += out[i]; + } + + // Divide by the sum of exps. + float reciprocal_sum_exp = 1.f / exp_sum; + for (int i = 0; i < input_size; i++) { + out[i] *= reciprocal_sum_exp; + } + + // Advance in and out pointers for the next batch. + in += input_size; + out += input_size; + } +} + +void Softmax2DQuantized(TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + // TODO(ahentz): this is arguably a dirty trick. Since the implementation + // always traverses the last dimension of a 4D tensor, we will pretend our 2D + // tensor is 4D in a special way. We will convert a (X, Y) shape into a (X, + // 1, 1, Y) shape. + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + optimized_ops::Softmax(GetTensorData(input), + GetTensorDims({batch_size, 1, 1, input_size}), + data->input_multiplier, data->input_left_shift, + data->diff_min, GetTensorData(output), + GetTensorDims({batch_size, 1, 1, input_size})); +} + +// Takes a 4D tensor and performs softmax along the fourth dimension.
+void Softmax4DFloat(TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + params->beta, GetTensorData(output), + GetTensorDims(output)); +} + +void Softmax4DQuantized(TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + data->input_multiplier, data->input_left_shift, + data->diff_min, GetTensorData(output), + GetTensorDims(output)); +} + +TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + + // TODO(ahentz): consider an implementation that works for many (all?) + // dimensions. + switch (input->type) { + case kTfLiteFloat32: { + if (NumDimensions(input) == 2) { + Softmax2DFloat(input, output, params); + return kTfLiteOk; + } + if (NumDimensions(input) == 4) { + Softmax4DFloat(input, output, params); + return kTfLiteOk; + } + context->ReportError(context, + "Only 2D and 4D tensors supported currently."); + return kTfLiteError; + } + case kTfLiteUInt8: { + if (NumDimensions(input) == 2) { + Softmax2DQuantized(input, output, params, data); + return kTfLiteOk; + } + if (NumDimensions(input) == 4) { + Softmax4DQuantized(input, output, params, data); + return kTfLiteOk; + } + context->ReportError(context, + "Only 2D and 4D tensors supported currently."); + return kTfLiteError; + } + default: + context->ReportError(context, + "Only float32 and uint8_t supported currently."); + return kTfLiteError; + } +} + +} // namespace activations + +TfLiteRegistration* Register_RELU() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::GenericPrepare, + activations::ReluEval}; + return &r; +} + +TfLiteRegistration* 
Register_RELU1() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::GenericPrepare, + activations::Relu1Eval}; + return &r; +} + +TfLiteRegistration* Register_RELU6() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::GenericPrepare, + activations::Relu6Eval}; + return &r; +} + +TfLiteRegistration* Register_TANH() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::GenericPrepare, + activations::TanhEval}; + return &r; +} + +TfLiteRegistration* Register_LOGISTIC() { + static TfLiteRegistration r = {activations::Init, activations::Free, + activations::SigmoidPrepare, + activations::SigmoidEval}; + return &r; +} + +TfLiteRegistration* Register_SOFTMAX() { + static TfLiteRegistration r = {activations::Init, activations::Free, + activations::SoftmaxPrepare, + activations::SoftmaxEval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..33ca56e745c043efd12b851af14f273fb273d577 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -0,0 +1,323 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseActivationsOpModel : public SingleOpModel { + public: + // Most activations don't take any options, so this constructor works for + // them. + BaseActivationsOpModel(BuiltinOperator type, TensorData input) { + input_ = AddInput(input); + if (input.type == TensorType_UINT8) { + output_ = AddOutput({input.type, {}, 0, 0, 1. / 256}); + } else { + output_ = AddOutput({input.type, {}}); + } + SetBuiltinOp(type, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(input_)}); + } + + // A dedicated constructor for SOFTMAX, which takes some options. + BaseActivationsOpModel(float softmax_beta, TensorData input) { + input_ = AddInput(input); + if (input.type == TensorType_UINT8) { + output_ = AddOutput({input.type, {}, 0, 0, 1. / 256}); + } else { + output_ = AddOutput({input.type, {}}); + } + SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, + CreateSoftmaxOptions(builder_, softmax_beta).Union()); + BuildInterpreter({GetShape(input_)}); + } + + protected: + int input_; + int output_; +}; + +class FloatActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +// TODO(ahentz): I don't quite understand the tradeoffs in the quantized +// implementation of sigmoid and softmax, but a tolerance of twice the output +// scale seems reasonable. We might want to change this if we have a better +// theoretical bound. +const float kQuantizedTolerance = 2 * (1. 
/ 256); + +class QuantizedActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(FloatActivationsOpTest, Relu) { + FloatActivationsOpModel m(BuiltinOperator_RELU, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 10, 1, // + })); +} + +TEST(FloatActivationsOpTest, Relu1) { + FloatActivationsOpModel m(BuiltinOperator_RELU1, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -2.0, 1.1, -0.1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -1.0, 1.0, -0.1, // + })); +} + +TEST(FloatActivationsOpTest, Relu6) { + FloatActivationsOpModel m(BuiltinOperator_RELU6, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 6, 1, // + })); +} + +TEST(FloatActivationsOpTest, Tanh) { + FloatActivationsOpModel m(BuiltinOperator_TANH, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0, -0.9999877, 0.9640275, 0.999329, // + 0.99505475, -0.9640275, 1, 0.7615941, // + }))); +} + +TEST(FloatActivationsOpTest, Sigmoid) { + FloatActivationsOpModel m(BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), 
ElementsAreArray(ArrayFloatNear({ + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }))); +} + +TEST(QuantizedActivationsOpTest, Sigmoid) { + QuantizedActivationsOpModel m( + BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }, + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188})); +} + +TEST(FloatActivationsOpTest, Softmax4D) { + FloatActivationsOpModel m(0.1, + /*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }))); + + // Same input, but a different shape. + FloatActivationsOpModel m2(0.1, + /*input=*/{TensorType_FLOAT32, {4, 1, 1, 2}}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }))); +} + +TEST(QuantizedActivationsOpTest, Softmax4D) { + QuantizedActivationsOpModel m( + 0.1, + /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. 
+ QuantizedActivationsOpModel m2( + 0.1, + /*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + +TEST(FloatActivationsOpTest, Softmax2D) { + FloatActivationsOpModel m(0.1, + /*input=*/{TensorType_FLOAT32, {2, 4}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }))); + + // Same input, but a different shape. + FloatActivationsOpModel m2(0.1, + /*input=*/{TensorType_FLOAT32, {4, 2}}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }))); +} + +TEST(QuantizedActivationsOpTest, Softmax2D) { + QuantizedActivationsOpModel m(0.1, + /*input=*/{TensorType_UINT8, {2, 4}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. 
+ QuantizedActivationsOpModel m2(0.1, + /*input=*/{TensorType_UINT8, {4, 2}, -10, 10}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e10a249abac3ba19cf107e055aa71d1eee00122 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace add { + +// This file has three implementation of Add. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); + for (int i = 0; i < NumDimensions(input1); ++i) { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), + SizeOfDimension(input2, i)); + } + + TF_LITE_ENSURE_EQ(context, input1->type, output->type); + TF_LITE_ENSURE_EQ(context, input2->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); + return context->ResizeTensor(context, output, output_size); +} + +template +void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + 
CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_ADD(type) \ + type::Add(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops); + } else { + TF_LITE_ADD(optimized_ops); + } +#undef TF_LITE_ADD +} + +template +void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + auto input1_offset = -input1->params.zero_point; + auto input2_offset = -input2->params.zero_point; + auto output_offset = output->params.zero_point; + const int left_shift = 20; + const double twice_max_input_scale = + 2 * std::max(input1->params.scale, input2->params.scale); + const double real_input1_multiplier = + input1->params.scale / twice_max_input_scale; + const double real_input2_multiplier = + input2->params.scale / twice_max_input_scale; + const double real_output_multiplier = + twice_max_input_scale / ((1 << left_shift) * output->params.scale); + + int32 input1_multiplier; + int input1_shift; + QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier, + &input1_shift); + int32 input2_multiplier; + int input2_shift; + QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier, + &input2_shift); + int32 output_multiplier; + int output_shift; + QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier, + &output_shift); + + int32 output_activation_min, output_activation_max; + CalculateActivationRangeUint8(params->activation, output, + &output_activation_min, &output_activation_max); + +#define TF_LITE_ADD(type) \ + type::BroadcastAdd( \ + left_shift, GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, input1_multiplier, input1_shift, \ + 
GetTensorData(input2), GetTensorDims(input2), input2_offset, \ + input2_multiplier, input2_shift, output_offset, output_multiplier, \ + output_shift, output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)); + + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops); + } else { + TF_LITE_ADD(optimized_ops); + } +#undef TF_LITE_ADD +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { + EvalAddFloat(context, node, params, input1, input2, output); + } else if (output->type == kTfLiteUInt8) { + EvalAddQuantized(context, node, params, input1, input2, + output); + } else { + context->ReportError(context, + "Inputs and outputs not all float|unit8 types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace add + +TfLiteRegistration* Register_ADD_REF() { + static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + add::Eval}; + return &r; +} + +TfLiteRegistration* Register_ADD_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + add::Eval}; + return &r; +} + +TfLiteRegistration* Register_ADD_NEON_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + add::Eval}; + return &r; +} + +TfLiteRegistration* Register_ADD() { +#ifdef USE_NEON + return Register_ADD_NEON_OPT(); +#else + return Register_ADD_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ddf45bb576755d57d50c9e6e01bf50f15612c56d --- /dev/null +++ 
b/tensorflow/contrib/lite/kernels/add_test.cc @@ -0,0 +1,170 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseAddOpModel : public SingleOpModel { + public: + BaseAddOpModel(const TensorData& input, const TensorData& output, + ActivationFunctionType activation_type) { + input1_ = AddInput(input); + input2_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, + CreateAddOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + protected: + int input1_; + int input2_; + int output_; +}; + +class FloatAddOpModel : public BaseAddOpModel { + public: + using BaseAddOpModel::BaseAddOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedAddOpModel : public BaseAddOpModel { + public: + using BaseAddOpModel::BaseAddOpModel; + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), 
GetZeroPoint(output_)); + } +}; + +// for quantized Add, the error shouldn't exceed 2*step +float GetTolerance(int min, int max) { + float kQuantizedStep = (max - min) / 255.0; + float kQuantizedTolerance = 2.0 * kQuantizedStep; + return kQuantizedTolerance; +} + +TEST(FloatAddOpModel, NoActivation) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + +TEST(FloatAddOpModel, ActivationRELU1) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.0, 0.4, 1.0, 1.0})); +} + +TEST(FloatAddOpModel, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-1.9, 0.4, 1.0, 1.3, 2.2, 2.1})) + << "With shape number " << i; + } +} + +TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector> inputs1 = { + {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}}; + std::vector> inputs2 = { + {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}}; + std::vector> results = { + {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedAddOpModel 
m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + +TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU1) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, + {-0.8, 0.2, 0.7, 0.3}}; + std::vector> inputs2 = {{0.6, 0.4, 0.9, -0.8}, + {0.6, 0.4, -0.8, 0.5}}; + std::vector> results = {{-0.2, 0.6, 1.0, -0.1}, + {-0.2, 0.6, -0.1, 0.8}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}, + ActivationFunctionType_RELU1); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + +TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { + float kQuantizedTolerance = GetTolerance(-3.0, 3.0); + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, + {TensorType_UINT8, {}, -3.0, 3.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), {0.1, 0.3, 0.3, 0.5, 1.1, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-1.9, 0.5, 1.0, 1.3, 2.2, 2.1}, + kQuantizedTolerance))) + << "With shape number " << i; + } +} + +} // namespace +} // namespace tflite +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + 
::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc new file mode 100644 index 0000000000000000000000000000000000000000..3cee43c68b2a0af5a3fd84b33a980b74bb8f0cb4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace rnn { + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kRecurrentWeightsTensor = 2; +constexpr int kBiasTensor = 3; +constexpr int KHiddenStateTensor = 0; +constexpr int kOutputTensor = 1; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. 
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + const int batch_size = input->dims->data[0]; + const int num_units = input_weights->dims->data[0]; + TF_LITE_ASSERT_EQ(input->dims->data[1], input_weights->dims->data[1]); + TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); + + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[KHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Resize state. + TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2); + hidden_state_size_array->data[0] = batch_size; + hidden_state_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state, + hidden_state_size_array)); + + // Mark hidden state as a persistent tensor. + hidden_state->allocation_type = kTfLiteArenaRwPersistent; + + // Resize output. 
+ TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); + output_size_array->data[0] = batch_size; + output_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, + output_size_array)); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[KHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Initialize the pointer bias. + const float* bias_ptr = bias->data.f; + + const int batch_size = input->dims->data[0]; + const int num_units = input_weights->dims->data[0]; + const int input_size = input->dims->data[1]; + const int input_weights_stride = input_weights->dims->data[1]; + const int recurrent_weights_stride = recurrent_weights->dims->data[1]; + + // For each batch + for (int b = 0; b < batch_size; b++) { + // Initialize the pointer to input, output and bias. + const float* input_ptr_batch = input->data.f + b * input_size; + float* output_ptr_batch = output->data.f + b * num_units; + float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; + + // Initialize input_weights and recurrent_weights. 
+ const float* input_weights_ptr = input_weights->data.f; + const float* recurrent_weights_ptr = recurrent_weights->data.f; + + // Output = bias + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = bias_ptr[o]; + } + + // Output += input * input_weights + for (int o = 0; o < num_units; o++) { + for (int i = 0; i < input_size; i++) { + output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; + } + input_weights_ptr += input_weights_stride; + } + + // Output += recurrent_weights * hidden_state + for (int o = 0; o < num_units; o++) { + for (int h = 0; h < num_units; h++) { + output_ptr_batch[o] += + hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; + } + recurrent_weights_ptr += recurrent_weights_stride; + } + + // Output = activation(Output) and update hidden_state + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = + (ActivationFunctor(params->activation))(output_ptr_batch[o]); + hidden_state_ptr_batch[o] = output_ptr_batch[o]; + } + } + + return kTfLiteOk; +} + +} // namespace rnn + +TfLiteRegistration* Register_RNN() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + rnn::Prepare, rnn::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5ecccb985e91238f1183c8f94a2b5f468758ce55 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc @@ -0,0 +1,267 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite RNN op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +static float rnn_input[] = { + 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, + 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, + -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, + 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, + 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, + 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, + -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, + -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, + 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, + 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, + 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, + -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, + 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, + -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, + -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, + -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, + 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, + -0.90300065, 0.62870288, -0.23068321, 
0.27531278, -0.095755219, + -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, + 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, + -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, + 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, + 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, + 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, + -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, + 0.93455386, -0.6324693, -0.083922029}; + +static float rnn_golden_output[] = { + 0.496726, 0, 0.965996, 0, 0.0584254, 0, + 0, 0.12315, 0, 0, 0.612266, 0.456601, + 0, 0.52286, 1.16099, 0.0291232, + + 0, 0, 0.524901, 0, 0, 0, + 0, 1.02116, 0, 1.35762, 0, 0.356909, + 0.436415, 0.0355727, 0, 0, + + 0, 0, 0, 0.262335, 0, 0, + 0, 1.33992, 0, 2.9739, 0, 0, + 1.31914, 2.66147, 0, 0, + + 0.942568, 0, 0, 0, 0.025507, 0, + 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, + 0.8158, 1.21805, 0.586239, 0.25427, + + 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, + 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, + 0, 1.22031, 1.30117, 0.495867, + + 0.222187, 0, 0.72725, 0, 0.767003, 0, + 0, 0.147835, 0, 0, 0, 0.608758, + 0.469394, 0.00720298, 0.927537, 0, + + 0.856974, 0.424257, 0, 0, 0.937329, 0, + 0, 0, 0.476425, 0, 0.566017, 0.418462, + 0.141911, 0.996214, 1.13063, 0, + + 0.967899, 0, 0, 0, 0.0831304, 0, + 0, 1.00378, 0, 0, 0, 1.44818, + 1.01768, 0.943891, 0.502745, 0, + + 0.940135, 0, 0, 0, 0, 0, + 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, + 1.30225, 1.59644, 0.70222, 0, + + 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, + 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, + 0.0454298, 0.300267, 0.562784, 0.395095, + + 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, + 0, 0, 0, 0.735363, 0.0759267, 1.91017, + 0.941888, 0, 0, 0, + + 0, 0, 1.5909, 0, 0, 0, + 0, 0.5755, 0, 0.184687, 0, 1.56296, + 0.625285, 0, 0, 0, + + 0, 0, 0.0857888, 0, 0, 0, + 0, 0.488383, 0.252786, 0, 0, 0, + 1.02817, 1.85665, 0, 0, + + 0.00981836, 0, 
1.06371, 0, 0, 0, + 0, 0, 0, 0.290445, 0.316406, 0, + 0.304161, 1.25079, 0.0707152, 0, + + 0.986264, 0.309201, 0, 0, 0, 0, + 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, + 0.524981, 1.92076, 2.07013, 0.333244, + + 0.415153, 0.210318, 0, 0, 0, 0, + 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, + 0.628881, 3.58099, 1.49974, 0 +}; + +class RNNOpModel : public SingleOpModel { + public: + RNNOpModel(int batches, int units, int size) + : batches_(batches), units_(units), input_size_(size) { + input_ = AddInput(TensorType_FLOAT32); + weights_ = AddInput(TensorType_FLOAT32); + recurrent_weights_ = AddInput(TensorType_FLOAT32); + bias_ = AddInput(TensorType_FLOAT32); + hidden_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RNN, BuiltinOptions_RNNOptions, + CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); + BuildInterpreter({{batches_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetRecurrentWeights(std::initializer_list f) { + PopulateTensor(recurrent_weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + void ResetHiddenState() { + const int zero_buffer_size = units_ * batches_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(hidden_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + private: + int input_; + int weights_; + int 
recurrent_weights_; + int bias_; + int hidden_state_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +TEST(FullyConnectedOpTest, BlackBoxTest) { + RNNOpModel rnn(2, 16, 8); + rnn.SetWeights( + {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}); + + rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, + -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, + 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, + -0.37609905}); + + rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}); + + rnn.ResetHiddenState(); + const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / + (rnn.input_size() * rnn.num_batches()); + + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(rnn.input_size(), batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output + i * rnn.num_units(); + float* golden_end = golden_start + rnn.num_units(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc new file mode 100644 index 0000000000000000000000000000000000000000..9e7a1233dac0f3cd02dc386f9d194597f38ca3b8 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -0,0 +1,200 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace concatenation { + +// This file has two implementations of Concatenation. +enum KernelType { + kReference, + kGenericOptimized, +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + int axis = params->axis; + int num_inputs = node->inputs->size; + + // The number of dimensions of the input tensors must match, and all + // dimensions except 'axis' must be equal. + TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]]; + TfLiteType input_type = t0->type; + TF_LITE_ENSURE(context, axis >= 0); + TF_LITE_ENSURE(context, axis < t0->dims->size); + + // TODO(ahentz): These are limitations of our implementation that could be + // removed with a bit of effort. 
+ TF_LITE_ENSURE(context, t0->dims->size <= 4); + TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); + TF_LITE_ENSURE(context, + input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + + // Output dimensions will match input dimensions, except 'axis', which + // will be the sum of inputs + int sum_axis = t0->dims->data[axis]; + for (int i = 1; i < num_inputs; ++i) { + TfLiteTensor* t = &context->tensors[node->inputs->data[i]]; + TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size); + TF_LITE_ENSURE_EQ(context, t->type, input_type); + if (input_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, t->params.zero_point, t0->params.zero_point); + TF_LITE_ENSURE_EQ(context, t->params.scale, t0->params.scale); + } + for (int d = 0; d < t0->dims->size; ++d) { + if (d == axis) { + sum_axis += t->dims->data[axis]; + } else { + TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]); + } + } + } + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size); + for (int d = 0; d < t0->dims->size; ++d) { + output_size->data[d] = (d == axis) ? 
sum_axis : t0->dims->data[d]; + } + + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TF_LITE_ENSURE_EQ(context, output->type, input_type); + if (input_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, + t0->params.zero_point); + TF_LITE_ENSURE_EQ(context, output->params.scale, t0->params.scale); + } + + return context->ResizeTensor(context, output, output_size); +} + +template +class VectorOfInputs { + public: + VectorOfInputs(const TfLiteContext& context, const TfLiteIntArray& inputs) { + int num_inputs = inputs.size; + + all_data_.reserve(num_inputs); + all_dims_.reserve(num_inputs); + all_dims_ptr_.reserve(num_inputs); + + for (int i = 0; i < num_inputs; ++i) { + TfLiteTensor* input = &context.tensors[inputs.data[i]]; + all_data_.push_back(GetTensorData(input)); + all_dims_.push_back(GetTensorDims(input)); + } + + // Taking the pointer from inside a std::vector is only OK if the vector is + // never modified, so we populate all_dims in the previous loop and then we + // are free to grab iterators here. + for (int i = 0; i < num_inputs; ++i) { + all_dims_ptr_.push_back(&all_dims_[i]); + } + } + const T* const* data() const { return all_data_.data(); } + const Dims<4>* const* dims() const { return all_dims_ptr_.data(); } + + private: + std::vector all_data_; + std::vector> all_dims_; + std::vector*> all_dims_ptr_; +}; + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + +// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should +// allocate and populate these during Prepare(). +// TODO(ycling): Activation function parameter is ignored. For now we don't have +// a model with a Concatenation with fused activation function. 
+#define TF_LITE_CONCATENATION(type, scalar) \ + VectorOfInputs all_inputs(*context, *node->inputs); \ + type::Concatenation( \ + RemapDim(NumDimensions(output), params->axis), all_inputs.data(), \ + all_inputs.dims(), node->inputs->size, GetTensorData(output), \ + GetTensorDims(output)) + + switch (output->type) { // Already know in/outtypes are same. + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, float); + } else { + TF_LITE_CONCATENATION(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, uint8_t); + } else { + TF_LITE_CONCATENATION(optimized_ops, uint8_t); + } + break; + default: + context->ReportError(context, + "Only float32 and uint8 are currently supported."); + return kTfLiteError; + } + +#undef TF_LITE_CONCATENATION + + return kTfLiteOk; +} + +#undef TF_LITE_MACRO_DISPATCH + +} // namespace concatenation + +TfLiteRegistration* Register_CONCATENATION_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, concatenation::Prepare, + concatenation::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, concatenation::Prepare, + concatenation::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONCATENATION() { + // TODO(ahentz): It turns out the two versions of Concatenation are almost + // identical, so we should consider removing one. + return Register_CONCATENATION_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/contrib/lite/kernels/concatenation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..499856a93cbbfbf9aa1a326912e52ce32bbbdf83 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/concatenation_test.cc @@ -0,0 +1,162 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseConcatenationOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, axis, input + // dimensions. 
+ BaseConcatenationOpModel(const TensorData& input_template, int axis, + int num_inputs) { + std::vector> all_input_shapes; + for (int i = 0; i < num_inputs; ++i) { + all_input_shapes.push_back(input_template.shape); + AddInput(input_template); + } + output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min, + input_template.max}); + SetBuiltinOp( + BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions, + CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE) + .Union()); + BuildInterpreter(all_input_shapes); + } + + protected: + int output_; +}; + +class ConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + void SetInput(int index, std::initializer_list data) { + PopulateTensor(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + void SetInput(int index, std::initializer_list data) { + QuantizeAndPopulate(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(ConcatenationOpTest, ThreeDimensionalOneInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7})); +} + +TEST(ConcatenationOpTest, OneTrivialInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {1}}, /*axis=*/0, + /*num_inputs=*/1); + m0.SetInput(0, {5.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ::testing::ElementsAre(5)); +} + +TEST(ConcatenationOpTest, TwoDimensionalOneInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 3}}, /*axis=*/0, + /*num_inputs=*/1); + 
m0.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(ConcatenationOpTest, TwoInputsTwoAxis) { + // We will concatenate two tensors along different dimensions. + auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 3}}, /*axis=*/0, + /*num_inputs=*/2); + m0.SetInput(0, tensor0); + m0.SetInput(1, tensor1); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + + ConcatenationOpModel m1({TensorType_FLOAT32, {2, 3}}, /*axis=*/1, + /*num_inputs=*/2); + m1.SetInput(0, tensor0); + m1.SetInput(1, tensor1); + m1.Invoke(); + EXPECT_THAT(m1.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); +} + +TEST(ConcatenationOpTest, FourInputs) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2, + /*num_inputs=*/4); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + })); +} + +TEST(ConcatenationOpTest, FourInputsQuantized) { + QuantizedConcatenationOpModel m0({TensorType_UINT8, {2, 1, 2}, -12.7, 12.8}, + /*axis=*/2, + /*num_inputs=*/4); + + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + }))); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ + 137, 157, 138, 158, 139, 159, 
140, 160, // + 167, 197, 168, 198, 169, 199, 170, 200, // + })); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc new file mode 100644 index 0000000000000000000000000000000000000000..c75c04baeac2ce53c6261d677dca8d72fafa0da5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -0,0 +1,425 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace conv { + +// This file has three implementations of Conv. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +struct OpData { + // IDs are the arbitrary identifiers used by TF Lite to identify and access + // memory buffers. + int im2col_id; + int hwcn_weights_id; + + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; + // Indexes are the offset to the memory buffer in the array used to keep track + // of the allocated temporaries. 
+ int32_t im2col_index; + int32_t hwcn_weights_index; + bool need_hwcn_weights; + bool have_weights_been_transposed; + bool need_im2col; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to use as scratch space for im2col, and + // to carry information from Prepare() to Eval(). + auto* data = new OpData; + context->AddTensors(context, 1, &data->im2col_id); + context->AddTensors(context, 1, &data->hwcn_weights_id); + gemm_support::IncrementUsageCounter(context); + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + gemm_support::DecrementUsageCounter(context); + delete reinterpret_cast(buffer); +} + +// Naive implementation of transpose for floats. Could be optimized to be more +// cache friendly, but for now it's a one-time cost on first run, and we would +// prefer to remove the need to do this at all eventually. +void TransposeFloatTensor(TfLiteTensor* input, TfLiteTensor* output) { + const int rows = output->dims->data[1]; + const int cols = output->dims->data[0]; + const float* input_data = GetTensorData(input); + float* output_data = GetTensorData(output); + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + const float in_value = input_data[i * cols + j]; + output_data[j * rows + i] = in_value; + } + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + bool hasBias = node->inputs->size == 3; + // Check number of inputs/outputs + TF_LITE_ENSURE(context, hasBias || node->inputs->size == 2); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* filter = &context->tensors[node->inputs->data[1]]; + // Check 
dimensionality of input, filter + TF_LITE_ENSURE_EQ(context, input->dims->size, 4); + TF_LITE_ENSURE_EQ(context, filter->dims->size, 4); + // Check input channels matching filter + TF_LITE_ENSURE_EQ(context, input->dims->data[3], filter->dims->data[3]); + + // Check types. (We assume that UINT8 refers to quantized tensors) + TfLiteType data_type = input->type; + TF_LITE_ENSURE(context, + data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8); + TF_LITE_ENSURE_EQ(context, output->type, data_type); + TF_LITE_ENSURE_EQ(context, filter->type, data_type); + + TfLiteTensor* bias = nullptr; + + // TODO(ahentz): At this point the optimized versions require 'bias'. We can + // either change that or document that convolution requires it. + TF_LITE_ENSURE(context, hasBias); + + if (hasBias) { + bias = &context->tensors[node->inputs->data[2]]; + if (data_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); + } else { + TF_LITE_ENSURE_EQ(context, bias->type, data_type); + } + TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, bias->dims->data[0], filter->dims->data[0]); + } + + int channels_out = filter->dims->data[0]; + int width = input->dims->data[2]; + int height = input->dims->data[1]; + int filter_width = filter->dims->data[2]; + int filter_height = filter->dims->data[1]; + int batches = input->dims->data[0]; + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + auto computeOutSize = [padding](int imageSize, int filterSize, + int stride) -> int { + return padding == kTfLitePaddingSame + ? (imageSize + stride - 1) / stride + : padding == kTfLitePaddingValid + ? 
(imageSize - filterSize + stride) / stride + : 0; + }; + + int outWidth = computeOutSize(width, filter_width, params->stride_width); + int outHeight = computeOutSize(height, filter_height, params->stride_height); + + data->padding.height = + ComputePadding(params->stride_height, height, filter_height, outHeight); + data->padding.width = + ComputePadding(params->stride_width, width, filter_width, outWidth); + + TF_LITE_ENSURE(context, hasBias); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. + if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, + &data->output_shift); + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = batches; + output_size->data[1] = outHeight; + output_size->data[2] = outWidth; + output_size->data[3] = channels_out; + auto output_status = context->ResizeTensor(context, output, output_size); + + if (output_status != kTfLiteOk) return output_status; + + // We don't always need to allocate im2col. It is only used in some versions + // of the optimized Conv. This test just mimics something that happens inside + // optimized_ops.h, in order to avoid a DCHECK(!im2col_data). + data->need_im2col = + (params->stride_width != 1 || params->stride_height != 1 || + filter_width != 1 || filter_height != 1); + // If we're using the optimized multithreaded EigenTensor implementation of + // convolution, it expects the filter weights to be transposed compared to + // the normal TF Lite buffer format. 
Typical TF Lite weights are + // [filter_count, filter_height, filter_width, input_depth], but for the float + // implementation we need them as [filter_height, filter_width, input_depth, + // filter_count]. We get to that format by transposing, and create a temporary + // buffer to store the results. + // This path is only used for float processing, so only create the buffer if + // we're running with that data type. + data->need_hwcn_weights = (data_type == kTfLiteFloat32); + + int temporaries_count = 0; + if (data->need_im2col) { + data->im2col_index = temporaries_count; + ++temporaries_count; + } + if (data->need_hwcn_weights) { + data->hwcn_weights_index = temporaries_count; + ++temporaries_count; + } + + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(temporaries_count); + + if (data->need_im2col) { + node->temporaries->data[data->im2col_index] = data->im2col_id; + + TfLiteIntArray* im2col_size = TfLiteIntArrayCreate(4); + + int input_depth = input->dims->data[3]; + im2col_size->data[0] = output_size->data[0]; + im2col_size->data[1] = output_size->data[1]; + im2col_size->data[2] = output_size->data[2]; + im2col_size->data[3] = input_depth * filter_height * filter_width; + + TfLiteTensor* im2col = + &context->tensors[node->temporaries->data[data->im2col_index]]; + im2col->type = data_type; + im2col->allocation_type = kTfLiteArenaRw; + auto im2col_status = context->ResizeTensor(context, im2col, im2col_size); + if (im2col_status != kTfLiteOk) return im2col_status; + } + + if (data->need_hwcn_weights) { + node->temporaries->data[data->hwcn_weights_index] = data->hwcn_weights_id; + TfLiteIntArray* hwcn_weights_size = TfLiteIntArrayCreate(2); + + // Because we're treating the filter weights as a matrix when we do the + // transpose, we allocate the buffer with a two-dimensional shape, where one + // dimension is the number of elements in each filter, and the second is the + // total number of filters. 
+ int input_depth = input->dims->data[3]; + hwcn_weights_size->data[0] = (filter_height * filter_width * input_depth); + hwcn_weights_size->data[1] = channels_out; + + TfLiteTensor* hwcn_weights = + &context->tensors[node->temporaries->data[data->hwcn_weights_index]]; + hwcn_weights->type = data_type; + hwcn_weights->allocation_type = kTfLiteDynamic; + // Make sure we release any previous allocations before we reallocate. + // TODO(petewarden): Persistent arenas would be a better fit for this, but + // they aren't fully implemented yet. + if (hwcn_weights->data.raw) { + free(hwcn_weights->data.raw); + hwcn_weights->data.raw = nullptr; + } + auto hwcn_weights_status = + context->ResizeTensor(context, hwcn_weights, hwcn_weights_size); + if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status; + hwcn_weights->data.raw = static_cast(malloc(hwcn_weights->bytes)); + + // TODO(petewarden): If Resize() is called when the size hasn't actually + // changed, this will do extra redundant work. 
+ data->have_weights_been_transposed = false; + } + + return kTfLiteOk; +} + +template +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* filter, TfLiteTensor* bias, + TfLiteTensor* im2col, TfLiteTensor* hwcn_weights, + TfLiteTensor* output) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + + auto input_offset = -input->params.zero_point; + auto filter_offset = -filter->params.zero_point; + auto output_offset = output->params.zero_point; + + if (kernel_type == kReference) { + reference_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + output_offset, data->output_multiplier, data->output_shift, + data->output_activation_min, data->output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col), gemm_context); + } else { + optimized_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + output_offset, data->output_multiplier, data->output_shift, + data->output_activation_min, data->output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col), gemm_context); + } +} + +template +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col, + TfLiteTensor* hwcn_weights, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + 
CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); + + const float* filter_data; + if (data->need_hwcn_weights) { + filter_data = GetTensorData(hwcn_weights); + } else { + filter_data = GetTensorData(filter); + } + + if (kernel_type == kReference) { + reference_ops::Conv( + GetTensorData(input), GetTensorDims(input), filter_data, + GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, data->padding.width, + data->padding.height, output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + } else { + multithreaded_ops::Conv( + GetTensorData(input), GetTensorDims(input), filter_data, + GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, data->padding.width, + data->padding.height, params->padding, output_activation_min, + output_activation_max, GetTensorData(output), + GetTensorDims(output), GetTensorData(im2col), + GetTensorDims(im2col)); + } +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* filter = &context->tensors[node->inputs->data[1]]; + bool hasBias = node->inputs->size == 3; + TfLiteTensor* bias = + hasBias ? &context->tensors[node->inputs->data[2]] : nullptr; + TfLiteTensor* im2col = + data->need_im2col + ? &context->tensors[node->temporaries->data[data->im2col_index]] + : nullptr; + TfLiteTensor* hwcn_weights = + data->need_hwcn_weights + ? 
&context->tensors[node->temporaries->data[data->hwcn_weights_index]] + : nullptr; + + if (data->need_hwcn_weights && !data->have_weights_been_transposed) { + TransposeFloatTensor(filter, hwcn_weights); + data->have_weights_been_transposed = true; + } + + // TODO(aselle): Consider whether float conv and quantized conv should be + // separate ops to avoid dispatch overhead here. + switch (input->type) { // Already know in/outtypes are same. + case kTfLiteFloat32: + EvalFloat(context, node, params, data, input, filter, bias, + im2col, hwcn_weights, output); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, data, input, filter, + bias, im2col, hwcn_weights, output); + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace conv + +TfLiteRegistration* Register_CONVOLUTION_REF() { + static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, + conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT() { + static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, + conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONVOLUTION_NEON_OPT() { + static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, + conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONV_2D() { +#ifdef USE_NEON + return Register_CONVOLUTION_NEON_OPT(); +#else + return Register_CONVOLUTION_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1d0a81c3135625c07a3566f5f9a8e5401f0d4db7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/conv_test.cc @@ -0,0 +1,440 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseConvolutionOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, bias, padding types, + // stride values. + BaseConvolutionOpModel( + const TensorData& input, const TensorData& filter, + const TensorData& output, int stride_width = 2, int stride_height = 2, + enum Padding padding = Padding_VALID, + enum ActivationFunctionType activation = ActivationFunctionType_NONE) { + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[0]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. 
+ auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + if (input.type != TensorType_FLOAT32) { + // The following is required by quantized inference. It is the unittest's + // responsibility to make sure the output scale falls into the correct + // range. + CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); + } + + SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, + CreateConv2DOptions(builder_, padding, stride_width, + stride_height, activation) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +class ConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(ConvolutionOpTest, SimpleTestFloat32) { + ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32, {3, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + })); +} + +TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { + 
ConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 6, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, + /*stride_width=*/3, /*stride_height=*/1); + m.SetInput({ + 3, 2, 1, -1, -2, -3, // + 4, 3, 2, -2, -3, -4, // + 5, 4, 3, -3, -4, -5, // + }); + m.SetFilter({ + 1, 2, // + 3, 4, // + }); + m.SetBias({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 30, -24, // + 40, -34, // + })); +} + +TEST(ConvolutionOpTest, HandCalculatedFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_SAME; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // No bias for this test. + m.SetBias({0}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. 
+ // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261 + // (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121 + // This means we should end up with this matrix: + // | 105 | 150 | 183 | 95 | + // | 235 | 312 | 357 | 178 | + // | 187 | 234 | 261 | 121 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 312, 357, + 178, 187, 234, 261, 121})); +} + +TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_SAME; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // Bias is | 10 |.
+ m.SetBias({10}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)+10=115 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)+10=160 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)+10=193 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)+10=105 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)+10=245 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)+10=322 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)+10=367 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)+10=188 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)+10=197 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)+10=244 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)+10=271 + // (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)+10=131 + // This means we should end up with this matrix: + // | 115 | 160 | 193 | 105 | + // | 245 | 322 | 367 | 188 | + // | 197 | 244 | 271 | 131 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({115, 160, 193, 105, 245, 322, + 367, 188, 197, 244, 271, 131})); +} + +TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_SAME; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding, + ActivationFunctionType_RELU); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // |
9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // Bias is | -200 |. + m.SetBias({-200}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)-200=-95 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)-200=-50 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)-200=-17 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)-200=-105 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)-200=35 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)-200=112 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)-200=157 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)-200=-22 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)-200=-13 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)-200=34 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)-200=61 + // (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)-200=-79 + // All negative values are gated to zero by the Relu activation function.
+ // This means we should end up with this matrix: + // | 0 | 0 | 0 | 0 | + // | 35 | 112 | 157 | 0 | + // | 0 | 34 | 61 | 0 | + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 0, 0, 0, 35, 112, 157, 0, 0, 34, 61, 0})); +} + +TEST(ConvolutionOpTest, HandCalculatedValidFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_VALID; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // No bias for this test. + m.SetBias({0}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with no accesses outside + // the input because we're using the 'VALID' padding mode, giving a 2x1 + // output. 
+ // The calculations behind the expected output are: + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 + // This means we should end up with this matrix: + // | 312 | 357 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({312, 357})); +} + +class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + QuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// In these tests we set the input and output scales so that the results +// match exactly the 'non-quantized' version.
+TEST(ConvolutionOpTest, SimpleTestQuantized) { + QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + }, + 1e-5))); + // For good measure, let's also verify the quantized values: + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 145, 129, 132, // + 145, 129, 132, // + 144, 131, 130, // + 164, 131, 130, // + })); +} + +TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { + QuantizedConvolutionOpModel m({TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, + {TensorType_UINT8, {1, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}, + /*stride_width=*/3, /*stride_height=*/1); + m.SetInput({ + 3, 2, 1, -1, -2, -3, // + 4, 3, 2, -2, -3, -4, // + 5, 4, 3, -3, -4, -5, // + }); + m.SetFilter({ + 1, 2, // + 3, 4, // + }); + m.SetBias({-1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ + 30, -24, // + 40, -34, // + }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 157, 103, // + 167, 93, // + })); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc new file mode 100644 index 
0000000000000000000000000000000000000000..15dbfe08c82befcf001b9ed9a053528b5606053e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -0,0 +1,289 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace depthwise_conv { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +// This file has three implementation of DepthwiseConv. 
+enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). + return new OpData; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + // TODO(ahentz): we could use GetOptionalInputTensor() here, but we need to + // decide whether we are OK with optional tensors being completely absent, as + // opposed to having -1 as their index. + bool hasBias = NumInputs(node) == 3; + + TF_LITE_ENSURE(context, hasBias || NumInputs(node) == 2); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + TfLiteTensor* bias = nullptr; + + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 4); + + // The parameter 'depth_multiplier' is redundant, so we check here to make + // sure it is consistent with the given dimensions.
+ TF_LITE_ENSURE_EQ(context, + params->depth_multiplier * SizeOfDimension(input, 3), + SizeOfDimension(filter, 3)); + + const TfLiteType data_type = input->type; + TF_LITE_ENSURE(context, + data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8); + TF_LITE_ENSURE_EQ(context, output->type, data_type); + TF_LITE_ENSURE_EQ(context, filter->type, data_type); + + if (hasBias) { + bias = GetInput(context, node, kBiasTensor); + if (data_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); + } else { + TF_LITE_ENSURE_EQ(context, bias->type, data_type); + } + TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(filter, 3), + SizeOfDimension(bias, 0)); + } + + int channels_out = SizeOfDimension(filter, 3); + int width = SizeOfDimension(input, 2); + int height = SizeOfDimension(input, 1); + int filter_width = SizeOfDimension(filter, 2); + int filter_height = SizeOfDimension(filter, 1); + int batches = SizeOfDimension(input, 0); + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + auto compute_out_size = [padding](int imageSize, int filterSize, + int stride) -> int { + return padding == kTfLitePaddingSame + ? (imageSize + stride - 1) / stride + : padding == kTfLitePaddingValid + ? (imageSize - filterSize + stride) / stride + : 0; + }; + + int out_width = compute_out_size(width, filter_width, params->stride_width); + int out_height = + compute_out_size(height, filter_height, params->stride_height); + + data->padding.height = + ComputePadding(params->stride_height, height, filter_height, out_height); + data->padding.width = + ComputePadding(params->stride_width, width, filter_width, out_width); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. 
+ if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, + &data->output_shift); + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4); + outputSize->data[0] = batches; + outputSize->data[1] = out_height; + outputSize->data[2] = out_width; + outputSize->data[3] = channels_out; + return context->ResizeTensor(context, output, outputSize); +} + +template +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, TfLiteTensor* bias, + TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); + + void (*depthwise_conv)(const float*, const Dims<4>&, const float*, + const Dims<4>&, const float*, const Dims<4>&, int, int, + int, int, int, float, float, float*, const Dims<4>&); + if (kernel_type == kReference) { + depthwise_conv = &reference_ops::DepthwiseConv; + } else { + depthwise_conv = &optimized_ops::DepthwiseConv; + } + + depthwise_conv( + GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + params->depth_multiplier, output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output)); +} + +template +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + auto input_offset = 
-input->params.zero_point; + auto filter_offset = -filter->params.zero_point; + auto output_offset = output->params.zero_point; + + void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*, + const Dims<4>&, int32, const int32*, const Dims<4>&, + int, int, int, int, int, int32, int32, int, int32, + int32, uint8*, const Dims<4>&); + if (kernel_type == kReference) { + depthwise_conv = &reference_ops::DepthwiseConv; + } else { + depthwise_conv = &optimized_ops::DepthwiseConv; + } + + depthwise_conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + params->depth_multiplier, output_offset, data->output_multiplier, + data->output_shift, data->output_activation_min, + data->output_activation_max, GetTensorData(output), + GetTensorDims(output)); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + TfLiteTensor* bias = + (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr; + + // TODO(aselle): Consider whether float conv and quantized conv should be + // separate ops to avoid dispatch overhead here. + switch (input->type) { // Already know in/out types are same. 
+ case kTfLiteFloat32: + EvalFloat(context, node, params, data, input, filter, bias, + output); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, data, input, filter, + bias, output); + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace depthwise_conv + +TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF() { + static TfLiteRegistration r = { + depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare, + depthwise_conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT() { + static TfLiteRegistration r = { + depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare, + depthwise_conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_NEON_OPT() { + static TfLiteRegistration r = { + depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare, + depthwise_conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEPTHWISE_CONV_2D() { +#ifdef USE_NEON + return Register_DEPTHWISE_CONVOLUTION_NEON_OPT(); +#else + return Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1439c8bce14ad127ed68dc54991aed8b8bb39383 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc @@ -0,0 +1,186 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseDepthwiseConvolutionOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, bias, padding types, + // stride values. + BaseDepthwiseConvolutionOpModel(const TensorData& input, + const TensorData& filter, + const TensorData& output) { + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[3]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + if (input.type != TensorType_FLOAT32) { + // The following is required by quantized inference. It is the unittest's + // responsibility to make sure the output scale falls into the correct + // range. 
+ CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); + } + + int input_depth = GetShape(input_)[3]; + int output_depth = GetShape(filter_)[3]; + int depth_mul = output_depth / input_depth; + + SetBuiltinOp( + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOptions_DepthwiseConv2DOptions, + CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul, + ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +class DepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel { + public: + using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel; + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(DepthwiseConvolutionOpTest, SimpleTest) { + DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}}, + {TensorType_FLOAT32, {1, 2, 2, 4}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }); + m.SetFilter({ + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }); + m.SetBias({1, 2, 3, 4}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 71, -34, 99, -20, // + 91, -26, 127, -4, // + })); +} + +class QuantizedDepthwiseConvolutionOpModel + : public BaseDepthwiseConvolutionOpModel { + public: + using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + QuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + 
QuantizeAndPopulate(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// In this test we set the input and output scales so that the results match +// exactly the 'non-quantized' version. +TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) { + QuantizedDepthwiseConvolutionOpModel m( + {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64}, + {TensorType_UINT8, {1, 2, 2, 4}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); + + m.SetInput({ + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }); + m.SetFilter({ + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }); + m.SetBias({1, 2, 3, 4}); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 71, -34, 99, -20, // + 91, -26, 127, -4, // + }, + 1e-5))); + // For good measure, let's also verify the quantized values: + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 198, 93, 226, 107, // + 218, 101, 254, 123, // + })); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e8cb396d43a58f94b08eb8dd8b05d16fd74fd2f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Op that looks up items from a matrix. +// +// Input: +// Tensor[0]: Row number to lookup, dim.size == 1, int32 +// Tensor[1]: 2-dimensional matrix of multi-dimensional items +// dim.size >= 2, any data type. +// first dimension is row, second dimension is column. +// +// Output: +// Output.dim[0] == Tensor[0].dim[0], num of lookups +// Output.dim[1] == Tensor[1].dim[1], num of items per row +// Each item in output is a raw bytes copy of corresponding item in input. +// When indices are out of bounds, the op will not succeed.
+//
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace embedding_lookup {
+
+// Validates the node and sizes the output tensor.
+//
+// Requires exactly two inputs and one output. Input 0 (lookup) must be a 1-D
+// int32 tensor; input 1 (value) must have at least two dimensions. The output
+// is resized to the shape of `value` with its first dimension replaced by the
+// number of lookups. Returns the status reported by ResizeTensor, or an error
+// status if any check fails.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  TfLiteTensor* lookup = GetInput(context, node, 0);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
+  TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
+
+  TfLiteTensor* value = GetInput(context, node, 1);
+  TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
+
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value));
+
+  // Dimension 0 counts the lookups; every remaining dimension is copied
+  // straight from `value`.
+  outputSize->data[0] = SizeOfDimension(lookup, 0);
+  for (int i = 1; i < NumDimensions(value); i++) {
+    outputSize->data[i] = SizeOfDimension(value, i);
+  }
+  return context->ResizeTensor(context, output, outputSize);
+}
+
+// Copies one row of `value` into the output for every entry of `lookup`.
+//
+// Rows are copied as raw bytes, so the element type of `value` is irrelevant
+// here. Reports an error and returns kTfLiteError as soon as a lookup index
+// falls outside [0, number of rows).
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* lookup = GetInput(context, node, 0);
+  TfLiteTensor* value = GetInput(context, node, 1);
+
+  const int row_size = SizeOfDimension(value, 0);
+  // A table with zero rows has no row stride; skip the division instead of
+  // dividing by zero. Any lookup into such a table fails the bounds check
+  // below. size_t keeps the byte offsets from overflowing int on big tables.
+  const size_t row_bytes = row_size > 0 ? value->bytes / row_size : 0;
+
+  for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
+    const int idx = lookup->data.i32[i];
+    if (idx >= row_size || idx < 0) {
+      context->ReportError(context, "Embedding Lookup: index out of bounds.");
+      return kTfLiteError;
+    }
+    memcpy(output->data.raw + i * row_bytes, value->data.raw + idx * row_bytes,
+           row_bytes);
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace embedding_lookup
+
+// Returns the registration for the EMBEDDING_LOOKUP op. The op keeps no
+// per-node state, so the first two (init/free) slots are null.
+TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
+  static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare,
+                                 embedding_lookup::Eval};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6c770e7f71efe83eace3640c47e03e0c7ab19e20
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
@@ -0,0 +1,248 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Op that looks up items from a sparse tensor in an embedding matrix.
+// The sparse lookup tensor is represented by three individual tensors: lookup,
+// indices, and dense_shape. The representation assumes that the corresponding
+// dense tensor would satisfy:
+//   * dense.shape = dense_shape
+//   * dense[tuple(indices[i])] = lookup[i]
+//
+// By convention, indices should be sorted.
+//
+// Options:
+//   combiner: The reduction op (SUM, MEAN, SQRTN).
+//     * SUM computes the weighted sum of the embedding results.
+//     * MEAN is the weighted sum divided by the total weight.
+//     * SQRTN is the weighted sum divided by the square root of the sum of the
+//       squares of the weights.
+//
+// Input:
+//   Tensor[0]: Ids to lookup, dim.size == 1, int32.
+// Tensor[1]: Indices, int32. +// Tensor[2]: Dense shape, int32. +// Tensor[3]: Weights to use for aggregation, float. +// Tensor[4]: Params, a matrix of multi-dimensional items, +// dim.size >= 2, float. +// +// Output: +// A (dense) tensor representing the combined embeddings for the sparse ids. +// For each row in the sparse tensor represented by (lookup, indices, shape) +// the op looks up the embeddings for all ids in that row, multiplies them by +// the corresponding weight, and combines these embeddings as specified in the +// last dimension. +// +// Output.dim = [l0, ... , ln-1, e1, ..., em] +// Where dense_shape == [l0, ..., ln] and Tensor[4].dim == [e0, e1, ..., em] +// +// For instance, if params is a 10x20 matrix and ids, weights are: +// +// [0, 0]: id 1, weight 2.0 +// [0, 1]: id 3, weight 0.5 +// [1, 0]: id 0, weight 1.0 +// [2, 3]: id 1, weight 3.0 +// +// with combiner=MEAN, then the output will be a (3, 20) tensor where: +// +// output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) +// output[1, :] = (params[0, :] * 1.0) / 1.0 +// output[2, :] = (params[1, :] * 3.0) / 3.0 +// +// When indices are out of bound, the op will not succeed. 
+
+#include
+#include
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+
+namespace {
+
+// Validates the ranks and types of the five inputs and marks the output as a
+// dynamic tensor. The output shape is not computed here; Eval() resizes the
+// output because its shape depends on the contents of the dense_shape tensor.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 5);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  TfLiteTensor* ids = GetInput(context, node, 0);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1);
+  TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32);
+
+  TfLiteTensor* indices = GetInput(context, node, 1);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2);
+  TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32);
+
+  TfLiteTensor* shape = GetInput(context, node, 2);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1);
+  TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32);
+
+  TfLiteTensor* weights = GetInput(context, node, 3);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1);
+  TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32);
+
+  // One (id, weight) pair per row of `indices`.
+  TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
+                    SizeOfDimension(ids, 0));
+  TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
+                    SizeOfDimension(weights, 0));
+
+  TfLiteTensor* value = GetInput(context, node, 4);
+  TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
+  // Eval() reads the embedding table through data.f, so it must be float.
+  TF_LITE_ENSURE_EQ(context, value->type, kTfLiteFloat32);
+
+  // Mark the output as a dynamic tensor.
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+  output->allocation_type = kTfLiteDynamic;
+
+  return kTfLiteOk;
+}
+
+// Normalises one finished aggregation bucket in place.
+//
+// `output` points at the `embedding_size` floats of the bucket. Nothing is
+// done for the SUM combiner or for a bucket that received no elements.
+// Otherwise every element is divided by the total weight (MEAN) or by the
+// square root of the summed squared weights (SQRTN). The divisor is not
+// checked for zero.
+void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements,
+                         float current_total_weight,
+                         float current_squares_weight, int embedding_size,
+                         float* output) {
+  if (combiner != kTfLiteCombinerTypeSum && num_elements > 0) {
+    float multiplier = 1.0;
+    switch (combiner) {
+      case kTfLiteCombinerTypeMean:
+        multiplier = current_total_weight;
+        break;
+      case kTfLiteCombinerTypeSqrtn:
+        multiplier = std::sqrt(current_squares_weight);
+        break;
+      default:
+        break;
+    }
+    for (int k = 0; k < embedding_size; k++) {
+      output[k] /= multiplier;
+    }
+  }
+}
+
+// Resizes the output, then accumulates the weighted embedding of every sparse
+// id into the bucket selected by all but the last coordinate of its index.
+//
+// Returns kTfLiteError when an id is not a valid row of the embedding table
+// or when a sparse index lies outside dense_shape.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params =
+      reinterpret_cast(node->builtin_data);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* ids = GetInput(context, node, 0);
+  TfLiteTensor* indices = GetInput(context, node, 1);
+  TfLiteTensor* dense_shape = GetInput(context, node, 2);
+  TfLiteTensor* weights = GetInput(context, node, 3);
+  TfLiteTensor* value = GetInput(context, node, 4);
+
+  const int lookup_rank = SizeOfDimension(indices, 1);
+  const int embedding_rank = NumDimensions(value);
+  const int num_lookups = SizeOfDimension(ids, 0);
+  const int num_rows = SizeOfDimension(value, 0);
+
+  // A rank-0 sparse tensor has no last dimension to replace and would produce
+  // a negative output rank below.
+  TF_LITE_ENSURE(context, lookup_rank >= 1);
+
+  // The last dimension gets replaced by the embedding.
+  const int output_rank = (lookup_rank - 1) + (embedding_rank - 1);
+
+  // Make sure that the actual dense shape of the sparse tensor represented by
+  // (lookup, indices, dense_shape) is consistent.
+  TF_LITE_ENSURE_EQ(context, SizeOfDimension(dense_shape, 0), lookup_rank);
+
+  // Resize output tensor.
+  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
+  int k = 0;
+  int embedding_size = 1;
+  int lookup_size = 1;
+  for (int i = 0; i < lookup_rank - 1; i++, k++) {
+    const int dim = dense_shape->data.i32[i];
+    lookup_size *= dim;
+    output_shape->data[k] = dim;
+  }
+  for (int i = 1; i < embedding_rank; i++, k++) {
+    const int dim = SizeOfDimension(value, i);
+    embedding_size *= dim;
+    output_shape->data[k] = dim;
+  }
+  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
+  // Prepare() marked the output as dynamic, so its buffer is sized here.
+  const int output_size = lookup_size * embedding_size;
+  TfLiteTensorRealloc(output_size * sizeof(float), output);
+
+  tensor_utils::ZeroVector(output->data.f, output_size);
+
+  // Keep track of the current bucket for aggregation/combination.
+  int current_output_offset = 0;
+  float current_total_weight = 0.0;
+  float current_squares_weight = 0.0;
+  int num_elements = 0;
+
+  for (int i = 0; i < num_lookups; i++) {
+    int idx = ids->data.i32[i];
+    if (idx >= num_rows || idx < 0) {
+      context->ReportError(context,
+                           "Embedding Lookup Sparse: index out of bounds.");
+      return kTfLiteError;
+    }
+
+    // Check where we need to aggregate. The bucket is the row-major position
+    // of the index with its last coordinate dropped.
+    const int example_indices_offset = i * lookup_rank;
+    int output_bucket = 0;
+    int stride = 1;
+    for (int d = (lookup_rank - 1) - 1; d >= 0; d--) {
+      const int coordinate = indices->data.i32[example_indices_offset + d];
+      const int dim = dense_shape->data.i32[d];
+      // The indices come from an input tensor; an unchecked coordinate would
+      // turn into a write outside the output buffer.
+      if (coordinate < 0 || coordinate >= dim) {
+        context->ReportError(
+            context, "Embedding Lookup Sparse: sparse index out of bounds.");
+        return kTfLiteError;
+      }
+      output_bucket += coordinate * stride;
+      stride *= dim;
+    }
+    const int output_offset = output_bucket * embedding_size;
+
+    // If we are in a new aggregation bucket and the combiner is not the sum,
+    // go back and finalize the result of the previous bucket.
+    if (output_offset != current_output_offset) {
+      FinalizeAggregation(params->combiner, num_elements, current_total_weight,
+                          current_squares_weight, embedding_size,
+                          &output->data.f[current_output_offset]);
+
+      // Track next bucket.
+      num_elements = 0;
+      current_total_weight = 0.0;
+      current_squares_weight = 0.0;
+      current_output_offset = output_offset;
+    }
+
+    // Add element to aggregation.
+    ++num_elements;
+    const int example_embedding_offset = idx * embedding_size;
+    const float w = weights->data.f[i];
+    current_squares_weight += w * w;
+    current_total_weight += w;
+    for (int e = 0; e < embedding_size; e++) {
+      output->data.f[current_output_offset + e] +=
+          (value->data.f[example_embedding_offset + e] * w);
+    }
+  }
+
+  // Finalize last bucket.
+  FinalizeAggregation(params->combiner, num_elements, current_total_weight,
+                      current_squares_weight, embedding_size,
+                      &output->data.f[current_output_offset]);
+
+  return kTfLiteOk;
+}
+
+}  // namespace
+
+// Returns the registration for the EMBEDDING_LOOKUP_SPARSE op. The op keeps no
+// per-node state, so the first two (init/free) slots are null.
+TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE() {
+  static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..dcdc5fffad9ceac1a9d23a4e91637a9ff92a8dda
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite sparse lookup op.
+
+#include
+#include
+
+#include
+#include
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+// Single-op model wrapping EMBEDDING_LOOKUP_SPARSE with the given combiner.
+// Inputs are added in the order lookup, indices, dense_shape, weights, value.
+class EmbeddingLookupSparseOpModel : public SingleOpModel {
+ public:
+  // The weights tensor is built with `lookup_shape`: there is one weight per
+  // looked-up id.
+  EmbeddingLookupSparseOpModel(CombinerType type,
+                               std::initializer_list lookup_shape,
+                               std::initializer_list indices_shape,
+                               std::initializer_list dense_shape_shape,
+                               std::initializer_list value_shape) {
+    lookup_ = AddInput(TensorType_INT32);
+    indices_ = AddInput(TensorType_INT32);
+    dense_shape_ = AddInput(TensorType_INT32);
+    weights_ = AddInput(TensorType_FLOAT32);
+    value_ = AddInput(TensorType_FLOAT32);
+    output_ = AddOutput(TensorType_FLOAT32);
+    SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
+                 BuiltinOptions_EmbeddingLookupSparseOptions,
+                 CreateEmbeddingLookupSparseOptions(builder_, type).Union());
+    BuildInterpreter({lookup_shape, indices_shape, dense_shape_shape,
+                      lookup_shape, value_shape});
+  }
+
+  // Fills the four tensors that describe the weighted sparse lookup.
+  void SetInput(std::initializer_list lookup_data,
+                std::initializer_list indices_data,
+                std::initializer_list dense_shape_data,
+                std::initializer_list weights_data) {
+    PopulateTensor(lookup_, lookup_data);
+    PopulateTensor(indices_, indices_data);
+    PopulateTensor(dense_shape_, dense_shape_data);
+    PopulateTensor(weights_, weights_data);
+  }
+
+  // Fills the 3-D value tensor element by element with function(i, j, k),
+  // where (i, j, k) is the row-major position of the element.
+  void Set3DWeightMatrix(const std::function& function) {
+    TfLiteTensor* tensor = interpreter_->tensor(value_);
+    int rows = tensor->dims->data[0];
+    int columns = tensor->dims->data[1];
+    int features = tensor->dims->data[2];
+    for (int i = 0; i < rows; i++) {
+      for (int j = 0; j < columns; j++) {
+        for (int k = 0; k < features; k++) {
+          tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
+        }
+      }
+    }
+  }
+
+  std::vector GetOutput() { return ExtractVector(output_); }
+
+ private:
+  // Tensor indices returned by AddInput()/AddOutput().
+  int lookup_;
+  int weights_;
+  int indices_;
+  int dense_shape_;
+  int value_;
+  int output_;
+};
+
+// In all tests below the embedding of row i is i + j / 10 + k / 100, so the
+// source row of every output value can be read off its digits.
+
+TEST(EmbeddingLookupSparseOpTest, SimpleTest) {
+  EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 2}, {2}, {4, 3, 2});
+  m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
+  m.Set3DWeightMatrix(
+      [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear({
+                  1.00, 1.01, 1.10, 1.11, 1.20, 1.21,  // Row 1
+                  0.00, 0.00, 0.00, 0.00, 0.00, 0.00,  // -
+                  6.00, 6.06, 6.60, 6.66, 7.20, 7.26,  // 2 * Row 3 + 4 * Row 0
+              })));
+}
+
+TEST(EmbeddingLookupSparseOpTest, SimpleTestMean) {
+  EmbeddingLookupSparseOpModel m(CombinerType_MEAN, {3}, {3, 2}, {2},
+                                 {4, 3, 2});
+  m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
+  m.Set3DWeightMatrix(
+      [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear({
+                  1.00, 1.01, 1.10, 1.11, 1.20, 1.21,  // Row 1
+                  0.00, 0.00, 0.00, 0.00, 0.00, 0.00,  // -
+                  // (2 * Row 3 + 4 * Row 0) / (2 + 4)
+                  1.00, 1.01, 1.10, 1.11, 1.20, 1.21,
+              })));
+}
+
+TEST(EmbeddingLookupSparseOpTest, SimpleTestSqrtn) {
+  EmbeddingLookupSparseOpModel m(CombinerType_SQRTN, {3}, {3, 2}, {2},
+                                 {4, 3, 2});
+  m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0});
+  m.Set3DWeightMatrix(
+      [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+  m.Invoke();
+
+  EXPECT_THAT(
+      m.GetOutput(),
+      ElementsAreArray(ArrayFloatNear({
+          1.00, 1.01, 1.10, 1.11, 1.20, 1.21,  // Row 1
+          0.00, 0.00, 0.00, 0.00, 0.00, 0.00,  // -
+          // (2 * Row 3 + 4 * Row 0) / sqrt(2^2 + 4^2)
+          6.00f / std::sqrt(20.0f), 6.06f / std::sqrt(20.0f),
+          6.60f / std::sqrt(20.0f), 6.66f / std::sqrt(20.0f),
+          7.20f / std::sqrt(20.0f), 7.26f / std::sqrt(20.0f),
+      })));
+}
+
+// Same lookups as the 2-D tests, but with a rank-3 sparse tensor: the first
+// two index coordinates select one of 3 * 2 buckets.
+TEST(EmbeddingLookupSparseOpTest, Indices3DTest) {
+  EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 3}, {3}, {4, 3, 2});
+  m.SetInput({1, 3, 0}, {0, 0, 0, 2, 0, 0, 2, 0, 1}, {3, 2, 2},
+             {1.0, 2.0, 4.0});
+  m.Set3DWeightMatrix(
+      [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear({
+                  1.00, 1.01, 1.10, 1.11, 1.20, 1.21,  // [0, 0]: Row 1
+                  0.00, 0.00, 0.00, 0.00, 0.00, 0.00,  // [0, 1]: -
+                  0.00, 0.00, 0.00, 0.00, 0.00, 0.00,  // [1, 0]: -
+                  0.00, 0.00, 0.00, 0.00, 0.00, 0.00,  // [1, 1]: -
+                  6.00, 6.06, 6.60, 6.66, 7.20, 7.26,  // [2, 0]: 2 * Row 3 +
+                                                       //         4 * Row 0
+                  0.00, 0.00, 0.00, 0.00, 0.00, 0.00,  // [2, 1]: -
+              })));
+}
+
+}  // namespace
+}  // namespace tflite
+
+// Test entry point: routes TFLite logging to stderr, then runs every
+// registered test.
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9b501878f196216a61568bfa36e6615f4dd07478
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
@@ -0,0 +1,94 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite Lookup op.
+
+#include
+#include
+
+#include
+#include
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+// Single-op model wrapping EMBEDDING_LOOKUP: an int32 index input, a float
+// weight (embedding table) input and a float output.
+class EmbeddingLookupOpModel : public SingleOpModel {
+ public:
+  // Builds the interpreter with the given shapes for the index tensor and
+  // the weight tensor, in that order.
+  EmbeddingLookupOpModel(std::initializer_list index_shape,
+                         std::initializer_list weight_shape) {
+    input_ = AddInput(TensorType_INT32);
+    weight_ = AddInput(TensorType_FLOAT32);
+    output_ = AddOutput(TensorType_FLOAT32);
+    SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
+    BuildInterpreter({index_shape, weight_shape});
+  }
+
+  // Writes the lookup indices into the index tensor.
+  void SetInput(std::initializer_list data) {
+    PopulateTensor(input_, data);
+  }
+
+  // Fills the weight tensor element by element with function(i, j, k), where
+  // (i, j, k) is the row-major position of the element. Reads the first
+  // three dimensions of the tensor, so the weight shape must be 3-D.
+  void Set3DWeightMatrix(const std::function& function) {
+    TfLiteTensor* tensor = interpreter_->tensor(weight_);
+    int rows = tensor->dims->data[0];
+    int columns = tensor->dims->data[1];
+    int features = tensor->dims->data[2];
+    for (int i = 0; i < rows; i++) {
+      for (int j = 0; j < columns; j++) {
+        for (int k = 0; k < features; k++) {
+          tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
+        }
+      }
+    }
+  }
+
+  // Returns the contents of the output tensor.
+  std::vector GetOutput() { return ExtractVector(output_); }
+
+ private:
+  // Tensor indices returned by AddInput()/AddOutput().
+  int input_;
+  int weight_;
+  int output_;
+};
+
+// TODO(ahentz): write more tests that exercise the details of the op, such as
+// lookup errors and variable input shapes.
+// Looks up rows 1, 0 and 2 of a 3x2x4 table whose element (i, j, k) holds
+// i + j / 10 + k / 100, so each output row can be traced back to its source.
+TEST(EmbeddingLookupOpTest, SimpleTest) {
+  EmbeddingLookupOpModel m({3}, {3, 2, 4});
+  // Go through the model's own setter rather than a hard-coded tensor index.
+  m.SetInput({1, 0, 2});
+  m.Set3DWeightMatrix(
+      [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear({
+                  1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
+                  0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
+                  2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
+              })));
+}
+
+}  // namespace
+}  // namespace tflite
+
+// Test entry point: routes TFLite logging to stderr, then runs every
+// registered test.
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a77fe94e499078bc2f0660e8e49fd557ed0f625d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -0,0 +1,307 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace fully_connected {
+
+// This file has four implementations of FullyConnected
+enum KernelType {
+  kReference,
+  kGenericOptimized,  // Neon-free
+  kNeonOptimized,
+  kPie,  // Used by the PIE team
+};
+
+// Per-node state computed in Prepare() and consumed by the quantized Eval().
+struct OpData {
+  // The scaling factor from input to output (aka the 'real multiplier') can
+  // be represented as a fixed point multiplier plus a left shift.
+  int32_t output_multiplier;
+  int output_shift;
+  // The range of the fused activation layer. For example for kNone and
+  // uint8_t these would be 0 and 255.
+  int32_t output_activation_min;
+  int32_t output_activation_max;
+};
+
+// Positions of the tensors in the node's input and output lists.
+constexpr int kInputTensor = 0;
+constexpr int kWeightsTensor = 1;
+constexpr int kBiasTensor = 2;
+constexpr int kOutputTensor = 0;
+
+// Allocates the per-node OpData and registers one more user of the shared
+// gemm support context. Paired with Free().
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  // This is a builtin op, so we don't use the contents in 'buffer', if any.
+  // Instead, we allocate a new object to carry information from Prepare() to
+  // Eval().
+  gemm_support::IncrementUsageCounter(context);
+  return new OpData;
+}
+
+// Releases what Init() acquired: the gemm support usage count and the OpData.
+void Free(TfLiteContext* context, void* buffer) {
+  gemm_support::DecrementUsageCounter(context);
+  delete reinterpret_cast(buffer);
+}
+
+// Validates the node, precomputes the quantization parameters for non-float
+// inputs and resizes the output to [batch_size, num_units].
+//
+// The filter must be 2-D, [num_units, input_size]. The total number of input
+// elements must be a multiple of input_size; the quotient is the batch size.
+// The bias is optional; when present it must be 1-D with num_units elements.
+// Every violated requirement is reported as an error status.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  auto* params =
+      reinterpret_cast(node->builtin_data);
+  OpData* data = reinterpret_cast(node->user_data);
+
+  // Check we have all the inputs and outputs we need.
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, 3);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+  TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+  TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  // Validate the ranks before any dims->data[] entry is read. The bias is
+  // optional, so it may only be inspected when it is present.
+  TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);
+  if (bias) {
+    TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1);
+  }
+
+  // Check all the parameters of tensor match within themselves and match the
+  // input configuration.
+  int input_size = 1;
+  for (int i = 0; i < input->dims->size; i++) {
+    input_size *= input->dims->data[i];
+  }
+
+  const int num_units = filter->dims->data[0];
+  const int filter_input_size = filter->dims->data[1];
+  // Guards the division below.
+  TF_LITE_ENSURE(context, filter_input_size > 0);
+  const int batch_size = input_size / filter_input_size;
+
+  // A malformed model is reported as an error rather than aborting the
+  // process.
+  TF_LITE_ENSURE_EQ(context, input_size, batch_size * filter_input_size);
+  if (bias) {
+    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
+  }
+
+  // Note that quantized inference requires that all tensors have their
+  // parameters set. This is usually done during quantized training.
+  TfLiteType data_type = input->type;
+  if (data_type != kTfLiteFloat32) {
+    double real_multiplier = 0.0;
+    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+        context, input, filter, bias, output, &real_multiplier));
+    QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
+                                     &data->output_shift);
+    CalculateActivationRangeUint8(params->activation, output,
+                                  &data->output_activation_min,
+                                  &data->output_activation_max);
+  }
+
+  // Resize output.
+  TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
+  output_size_array->data[0] = batch_size;
+  output_size_array->data[1] = num_units;
+  TF_LITE_ENSURE_OK(context,
+                    context->ResizeTensor(context, output, output_size_array));
+  return kTfLiteOk;
+}
+
+// Float implementation built on tensor_utils: initialises the output with the
+// bias (or zeros), accumulates weight * input into it and applies the fused
+// activation in place.
+TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
+                     TfLiteFullyConnectedParams* params, OpData* data,
+                     TfLiteTensor* input, TfLiteTensor* filter,
+                     TfLiteTensor* bias, TfLiteTensor* output) {
+  int total_input_size = 1;
+  for (int i = 0; i < input->dims->size; i++) {
+    total_input_size *= input->dims->data[i];
+  }
+
+  int input_size = filter->dims->data[1];
+  const int batch_size = total_input_size / filter->dims->data[1];
+  const int num_units = filter->dims->data[0];
+
+  // Output = bias if bias tensor exists.
+  if (bias) {
+    tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
+                                          output->data.f);
+  } else {
+    tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
+  }
+
+  // Compute output += weight * input
+  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      filter->data.f, num_units, input_size, input->data.f, batch_size,
+      output->data.f, /*result_stride=*/1);
+
+  // Apply activation function
+  tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units,
+                                        params->activation, output->data.f);
+
+  return kTfLiteOk;
+}
+
+// Quantized evaluation. The input and filter offsets are the negated zero
+// points; the multiplier, shift and activation range come from Prepare().
+template
+TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+                           TfLiteFullyConnectedParams* params, OpData* data,
+                           TfLiteTensor* input, TfLiteTensor* filter,
+                           TfLiteTensor* bias, TfLiteTensor* output) {
+  gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+
+  int32_t input_offset = -input->params.zero_point;
+  int32_t filter_offset = -filter->params.zero_point;
+  int32_t output_offset = output->params.zero_point;
+#define TF_LITE_FULLY_CONNECTED(type)                                       \
+  type::FullyConnected(                                                     \
+      GetTensorData(input), GetTensorDims(input), input_offset,             \
+      GetTensorData(filter), GetTensorDims(filter), filter_offset,          \
+      GetTensorData(bias), GetTensorDims(bias), output_offset,              \
+      data->output_multiplier, data->output_shift,                          \
+      data->output_activation_min, data->output_activation_max,             \
+      GetTensorData(output), GetTensorDims(output), gemm_context)
+  if (kernel_type == kReference) {
+    TF_LITE_FULLY_CONNECTED(reference_ops);
+  } else if (kernel_type == kPie) {
+    // TODO(ahentz): we don't have a quantized version of the PIE kernels, so
+    // we just defer to the MINI ones.
+    TF_LITE_FULLY_CONNECTED(optimized_ops);
+  } else {
+    TF_LITE_FULLY_CONNECTED(optimized_ops);
+  }
+#undef TF_LITE_FULLY_CONNECTED
+
+  return kTfLiteOk;
+}
+
+// Float evaluation. kReference uses reference_ops, kPie uses EvalPie() and
+// every other kernel type uses optimized_ops.
+template
+TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
+                       TfLiteFullyConnectedParams* params, OpData* data,
+                       TfLiteTensor* input, TfLiteTensor* filter,
+                       TfLiteTensor* bias, TfLiteTensor* output) {
+  float output_activation_min, output_activation_max;
+  CalculateActivationRangeFloat(params->activation, &output_activation_min,
+                                &output_activation_max);
+#define TF_LITE_FULLY_CONNECTED(type)                                       \
+  type::FullyConnected(GetTensorData(input), GetTensorDims(input),          \
+                       GetTensorData(filter), GetTensorDims(filter),        \
+                       GetTensorData(bias), GetTensorDims(bias),            \
+                       output_activation_min, output_activation_max,        \
+                       GetTensorData(output), GetTensorDims(output))
+  if (kernel_type == kReference) {
+    TF_LITE_FULLY_CONNECTED(reference_ops);
+  } else if (kernel_type == kPie) {
+    return EvalPie(context, node, params, data, input, filter, bias, output);
+  } else {
+    TF_LITE_FULLY_CONNECTED(optimized_ops);
+  }
+#undef TF_LITE_FULLY_CONNECTED
+
+  return kTfLiteOk;
+}
+
+// Entry point registered for the op: dispatches on the input type to the
+// float or the quantized evaluation. Any other type is reported as an error.
+template
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params =
+      reinterpret_cast(node->builtin_data);
+  OpData* data = reinterpret_cast(node->user_data);
+
+  TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+  TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  switch (input->type) {  // Already know in/out types are same.
+    case kTfLiteFloat32:
+      return EvalFloat(context, node, params, data, input, filter,
+                       bias, output);
+    case kTfLiteUInt8:
+      return EvalQuantized(context, node, params, data, input,
+                           filter, bias, output);
+    default:
+      context->ReportError(context, "Type not currently supported.");
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace fully_connected
+
+// One registration per kernel variant; all of them share Init, Free and
+// Prepare.
+TfLiteRegistration* Register_FULLY_CONNECTED_REF() {
+  static TfLiteRegistration r = {
+      fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
+      fully_connected::Eval};
+  return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED_NEON_OPT() {
+  static TfLiteRegistration r = {
+      fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
+      fully_connected::Eval};
+  return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() {
+  static TfLiteRegistration r = {
+      fully_connected::Init, fully_connected::Free, fully_connected::Prepare,
+      fully_connected::Eval};
+  return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED_PIE() {
+  static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free,
+                                 fully_connected::Prepare,
+                                 fully_connected::Eval};
+  return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED() {
+  // TODO(ahentz): We don't have a dedicated quantized version of the PIE
+  // kernel. For now, the quantized version just defers to the corresponding
+  // optimized MINI kernel. At some point we will allow different libraries to
+  // be built with different kernels, but for now we have to pick one here.
+  //
+  // NOTE: the unconditional return below makes the USE_NEON selection that
+  // follows it unreachable.
+  return Register_FULLY_CONNECTED_PIE();
+#ifdef USE_NEON
+  return Register_FULLY_CONNECTED_NEON_OPT();
+#else
+  return Register_FULLY_CONNECTED_GENERIC_OPT();
+#endif
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a0f766c4f4580d7679275c0b63aa200410fcb5ad
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
@@ -0,0 +1,376 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite FULLY_CONNECTED op.
+ +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +static float fully_connected_input[] = { + 0.503691, 0.196961, 0.521017, 0.554248, 0.288678, 0.792476, 0.561653, + 0.462230, 0.650736, 0.163132, 0.029658, 0.411544, 0.470539, 0.572390, + 0.538755, 0.212030, 0.264309, 0.193908, 0.777480, 0.745661, 0.423314, + 0.470804, 0.175501, 0.492225, 0.192743, 0.540183, 0.372514, 0.446550, + 0.498173, 0.126472, 0.132706, 0.001864, 0.323433, 0.653723, 0.556112, + 0.612111, 0.446199, 0.117765, 0.074341, 0.096935, 0.280897, 0.103999, + 0.508479, 0.751437, 0.676389, 0.047234, 0.963467, 0.940698, 0.241142, + 0.740947, 0.686359, 0.664456, 0.211751, 0.861860, 0.156681, 0.404494, + 0.402043, 0.529195, 0.851044, 0.900216, 0.655667, 0.983750, 0.902081, + 0.979100, 0.637473, 0.458193, 0.591211, 0.083671, 0.575958, 0.665552, + 0.180606, 0.856856, 0.769551, 0.689086, 0.608293, 0.445940, 0.736320, + 0.571760, 0.386637, 0.977461, 0.312707, 0.072996, 0.641918, 0.524458, + 0.934856, 0.798598, 0.928951, 0.336899, 0.327793, 0.779995, 0.237115, + 0.983460, 0.763746, 0.139196, 0.962560, 0.401218, 0.597389, 0.553771, + 0.484890, 0.173347, 0.219322, 0.665496, 0.030203, 0.988873, 0.354582, + 0.638496, 0.434813, 0.090902, 0.210256, 0.821450, 0.068363, 0.522962, + 0.894446, 0.710280, 0.047420, 0.829302, 0.508879, 0.976371, 0.166202, + 0.836672, 0.756367, 0.403317, 0.820132, 0.520112, 0.542513, 0.782691, + 0.921330, 0.139902}; + +static float fully_connected_golden_output[] = { + 0, 0.0732134, 0, 0, 0, 0.280859, + 0, 0.128927, 0, 0.0777251, 0, 0.270268, + 0.271435, 0.0173503, 0.335465, 0.235562, + + 0, 0.0745866, 0, 0.051611, 0, 0.253876, + 0, 0.0814873, 0, 0.104104, 0, 0.248529, + 0.264194, 0, 0.302973, 
0.166252, + + 0, 0.0170409, 0, 0.0509851, 0, 0.212834, + 0, 0.0208326, 0, 0.129932, 0.203978, 0.103428, + 0.298051, 0, 0.332233, 0.00445903, + + 0, 0.125246, 0, 0.0735336, 0, 0.0910256, + 0, 0, 0, 0.18933, 0.378111, 0.0712443, + 0.277298, 0.0123414, 0.267454, 0, + + 0, 0.14687, 0, 0.155495, 0.0300215, 0.147256, + 0, 0, 0, 0.156412, 0.434914, 0.0461529, + 0.246508, 0, 0.363138, 0, + + 0, 0, 0, 0.0212949, 0, 0.301708, + 0, 0.35497, 0, 0.406223, 0.0260211, 0.049195, + 0.197161, 0, 0.37316, 0, + + 0, 0.221783, 0, 0, 0.0116515, 0.281945, + 0, 0, 0, 0, 0.285626, 0.181773, + 0.296401, 0.170452, 0.367135, 0.142597, + + 0, 0, 0, 0, 0, 0.418886, + 0, 0.291063, 0, 0.227541, 0.0424759, 0.27589, + 0.398286, 0.177146, 0.40359, 0.121452, + + 0, 0.0834884, 0, 0, 0, 0.287441, + 0, 0.0046838, 0, 0.0122087, 0, 0.217376, + 0.140183, 0.0948412, 0.436677, 0.0589876, + + 0, 0.0289969, 0, 0.0921397, 0, 0.396802, + 0, 0.0126157, 0, 0.0968433, 0, 0.172271, + 0.173295, 0.0664741, 0.53645, 0.00915603, + + 0, 0, 0, 0, 0, 0.147942, + 0, 0.263795, 0, 0.39782, 0, 0.382435, + 0.561072, 0.0579847, 0.145712, 0.13508, + + 0, 0, 0, 0.16382, 0, 0.322294, + 0, 0.163798, 0, 0.405211, 0.367953, 0.076852, + 0.342473, 0.0834118, 0.377537, 0, + + 0, 0.206, 0, 0, 0, 0.375769, + 0, 0, 0, 0, 0, 0.125165, + 0, 0.105591, 0.52055, 0.0536445, + + 0, 0.259261, 0, 0, 0, 0.247707, + 0, 0, 0, 0, 0, 0.215862, + 0.149153, 0.224678, 0.359519, 0.129419, + + 0, 0.17611, 0, 0.280895, 0, 0.576484, + 0, 0.000418848, 0, 0, 0, 0.151112, + 0.211902, 0, 0.566341, 0.106305, + + 0, 0.0246284, 0, 0, 0, 0.196267, + 0, 0.0248624, 0, 0.265635, 0, 0.436199, + 0.408079, 0.134514, 0.328489, 0.411368}; + +class BaseFullyConnectedOpModel : public SingleOpModel { + public: + // TODO(ahentz): test different activation types too. 
+ BaseFullyConnectedOpModel(int units, int batches, const TensorData& input, + const TensorData& output = {TensorType_FLOAT32}) + : batches_(batches), units_(units) { + int total_input_size = 1; + for (int i = 0; i < input.shape.size(); ++i) { + total_input_size *= input.shape[i]; + } + input_size_ = total_input_size / batches_; + + input_ = AddInput(input); + weights_ = + AddInput({input.type, {units_, input_size_}, input.min, input.max}); + + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {units_}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(weights_); + TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + SetBuiltinOp( + BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, + CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) + .Union()); + BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); + } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + protected: + int input_; + int weights_; + int bias_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel { + public: + using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class 
QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel { + public: + using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + void SetWeights(std::initializer_list data) { + QuantizeAndPopulate(weights_, data); + } + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// TODO(ahentz): add more small tests like this one, focused on making sure the +// calculations are correct. +TEST(FullyConnectedOpTest, SimpleTest) { + FloatFullyConnectedOpModel m(3, 2, {TensorType_FLOAT32, {2, 10}}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); +} + +TEST(FullyConnectedOpTest, SimpleTestQuantized) { + QuantizedFullyConnectedOpModel m( + 3, 2, + /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64}, + /*output=*/{TensorType_UINT8, {}, -127, 128}); + + // input_product_scale < output_scale was not true. 
+ m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); +} + +TEST(FullyConnectedOpTest, SimpleTest4DInput) { + // Note that it is not required that the first dimension be the number of + // batches. All we care is that the input can be evenly distributed in + // batches. In this case, we need the input to have multiples of '2'. + FloatFullyConnectedOpModel m(/*units=*/3, + /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // first batch + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // second batch + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 24, 25, 26, // first batch + 58, 59, 60, // second batch + })); +} + +TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) { + QuantizedFullyConnectedOpModel m( + 3, 2, + /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64}, + /*output=*/{TensorType_UINT8, {}, -127, 128}); + + // input_product_scale < output_scale was not true. 
+ m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); +} + +// TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard +// to debug errors and doesn't necessarily test all the important details. +TEST(FullyConnectedOpTest, BlackBoxTest) { + FloatFullyConnectedOpModel m(16, 2, {TensorType_FLOAT32, {2, 8}}); + m.SetWeights( + {0.091327, 0.103366, -0.316505, -0.083120, 0.149366, -0.196636, + -0.123672, 0.062800, 0.063031, 0.191670, -0.062001, -0.061504, + -0.275581, 0.059388, -0.118497, -0.079224, 0.109758, 0.008307, + -0.062657, -0.060962, -0.049782, -0.106719, -0.319482, -0.103650, + 0.266455, 0.051517, -0.123448, 0.322464, 0.043282, -0.173782, + -0.190381, 0.002013, 0.096086, 0.131157, 0.031164, 0.100638, + -0.312191, -0.080923, -0.101318, -0.116614, 0.142238, 0.086540, + -0.139154, 0.174268, -0.073161, 0.080072, 0.006874, 0.229382, + -0.104321, -0.176035, -0.208587, -0.001019, -0.162032, 0.080824, + -0.025021, 0.074460, -0.252595, -0.161750, -0.136403, 0.008308, + 0.005710, 0.096600, 0.289839, 0.218816, -0.304651, -0.070958, + 0.054598, 0.147113, -0.139112, -0.072798, -0.163335, -0.167863, + -0.128762, -0.035780, 0.117262, 0.017177, 0.263335, -0.176612, + 0.262961, -0.093654, -0.339283, 0.333071, 0.180827, 0.287583, + 0.066350, -0.197947, -0.114449, -0.236035, 0.103532, -0.034284, + 0.093299, -0.145361, 0.054001, 0.250570, 0.157010, -0.143480, + -0.139061, -0.048873, 0.067557, 0.139038, 0.324106, 0.227041, + 0.037793, -0.225747, -0.241619, 0.357835, 0.135762, -0.306764, + -0.125982, 0.091916, 0.266587, 0.030135, 
0.265148, 0.141627, + 0.020120, 0.083815, -0.124556, -0.100124, -0.048159, 0.181172, + 0.302309, -0.041084, 0.146334, -0.061511, -0.232605, 0.281324, + 0.145408, -0.221897}); + m.SetBias({-0.160594, 0.205770, -0.078307, -0.077984, 0.001937, 0.015860, + 0.036810, 0.012346, 0.001028, 0.038551, 0.075415, 0.020804, + 0.048478, -0.032270, 0.175688, -0.085662}); + + const int input_sequence_size = sizeof(fully_connected_input) / + sizeof(float) / + (m.input_size() * m.num_batches()); + for (int i = 0; i < input_sequence_size; i++) { + // TODO(ahentz): This is what the original test was doing: two equal + // batches per invocation. We could instead use two different batches. + float* batch_start = fully_connected_input + i * m.input_size(); + float* batch_end = batch_start + m.input_size(); + m.SetInput(0, batch_start, batch_end); + m.SetInput(m.input_size(), batch_start, batch_end); + + m.Invoke(); + + float* golden_start = fully_connected_golden_output + i * m.num_units(); + float* golden_end = golden_start + m.num_units(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb2b0aacf7ecc3ed5dbde5ccce7a46dcda0a93b3 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gemm_support.cc @@ -0,0 +1,68 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/gemm_support.h" + +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace gemm_support { + +struct RefCountedGemmContext { + gemmlowp::GemmContext* gemm_context_ = nullptr; + int num_references_ = 0; +}; + +void IncrementUsageCounter(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->gemm_context); + if (ptr == nullptr) { + ptr = new RefCountedGemmContext; + ptr->gemm_context_ = new gemmlowp::GemmContext(); + ptr->num_references_ = 0; + context->gemm_context = ptr; + } + ptr->num_references_++; +} + +void DecrementUsageCounter(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->gemm_context); + if (ptr == nullptr) { + TF_LITE_FATAL( + "Call to DecrementUsageCounter() not preceded by " + "IncrementUsageCounter()"); + } + if (--ptr->num_references_ == 0) { + delete ptr->gemm_context_; + delete ptr; + context->gemm_context = nullptr; + } +} + +gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->gemm_context); + if (ptr == nullptr) { + TF_LITE_FATAL( + "Call to GetFromContext() not preceded by IncrementUsageCounter()"); + } + return ptr->gemm_context_; +} + +void SetMaxNumThreads(TfLiteContext* context, int num_threads) { + IncrementUsageCounter(context); + GetFromContext(context)->set_max_num_threads(num_threads); + DecrementUsageCounter(context); +} + +} // namespace gemm_support +} // namespace tflite diff --git 
a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h
new file mode 100644
index 0000000000000000000000000000000000000000..b531959ffb143c774ee715743480b03ebfbdc114
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/gemm_support.h
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
+
+#include "public/gemmlowp.h"
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+namespace gemm_support {
+
+// Returns the GemmContext stored in 'context', allowing multiple ops to
+// share a single object, as long as they share a TfLiteContext. The caller
+// must ensure that this is called between IncrementUsageCounter() and
+// DecrementUsageCounter(); calling it when no GemmContext is stored in
+// 'context' is reported as a fatal error. For example, in the implementation
+// of an op:
+//   void* Init(TfLiteContext* context, const char*, size_t) {
+//     gemm_support::IncrementUsageCounter(context);
+//     return nullptr;
+//   }
+//   void Free(TfLiteContext* context, void*) {
+//     gemm_support::DecrementUsageCounter(context);
+//   }
+//   TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+//     auto* gemm_context = gemm_support::GetFromContext(context);
+//   }
+gemmlowp::GemmContext* GetFromContext(TfLiteContext* context);
+
+// Let the framework know that the GemmContext stored in 'context' will be used
+// by an op. If necessary a new GemmContext is created and placed in 'context'.
+void IncrementUsageCounter(TfLiteContext* context);
+
+// Let the framework know that the op stopped using the GemmContext stored in
+// 'context'. If there are no more usages the GemmContext will be deleted.
+// Must be preceded by a matching IncrementUsageCounter(); otherwise a fatal
+// error is reported.
+void DecrementUsageCounter(TfLiteContext* context);
+
+// Set the maximum number of threads available for gemmlowp operations. The
+// call takes and releases its own usage reference, so it does not need to be
+// bracketed by IncrementUsageCounter()/DecrementUsageCounter().
+void SetMaxNumThreads(TfLiteContext* context, int num_threads);
+
+}  // namespace gemm_support
+}  // namespace tflite
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3b82601d119b2e4946db6e3577300168c7e710b6
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
@@ -0,0 +1,155 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Op that looks up items from hashtable. +// +// Input: +// Tensor[0]: Hash key to lookup, dim.size == 1, int32 +// Tensor[1]: Key of hashtable, dim.size == 1, int32 +// *MUST* be sorted in ascending order. +// Tensor[2]: Value of hashtable, dim.size >= 1 +// Tensor[1].Dim[0] == Tensor[2].Dim[0] +// +// Output: +// Output[0].dim[0] == Tensor[0].dim[0], num of lookups +// Each item in output is a raw bytes copy of corresponding item in input. +// When key does not exist in hashtable, the returned bytes are all 0s. +// +// Output[1].dim = { Tensor[0].dim[0] }, num of lookups +// Each item indicates whether the corresponding lookup has a returned value. +// 0 for missing key, 1 for found key. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { + +namespace { + +int greater(const void* a, const void* b) { + return *static_cast(a) - *static_cast(b); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); + + TfLiteTensor* lookup = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); + TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); + + TfLiteTensor* key = GetInput(context, node, 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1); + TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32); + + TfLiteTensor* value = GetInput(context, node, 2); + TF_LITE_ENSURE(context, 
NumDimensions(value) >= 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0), + SizeOfDimension(value, 0)); + if (value->type == kTfLiteString) { + TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1); + } + + TfLiteTensor* hits = GetOutput(context, node, 1); + TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8); + TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1); + hitSize->data[0] = SizeOfDimension(lookup, 0); + + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, value->type, output->type); + + TfLiteStatus status = kTfLiteOk; + if (output->type != kTfLiteString) { + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); + outputSize->data[0] = SizeOfDimension(lookup, 0); + for (int i = 1; i < NumDimensions(value); i++) { + outputSize->data[i] = SizeOfDimension(value, i); + } + status = context->ResizeTensor(context, output, outputSize); + } + if (context->ResizeTensor(context, hits, hitSize) == kTfLiteError) { + status = kTfLiteError; + } + return status; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* hits = GetOutput(context, node, 1); + TfLiteTensor* lookup = GetInput(context, node, 0); + TfLiteTensor* key = GetInput(context, node, 1); + TfLiteTensor* value = GetInput(context, node, 2); + + const int num_rows = SizeOfDimension(value, 0); + const int row_bytes = value->bytes / num_rows; + void* pointer = nullptr; + DynamicBuffer buf; + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = -1; + pointer = bsearch(&(lookup->data.i32[i]), key->data.i32, num_rows, + sizeof(int32_t), greater); + if (pointer != nullptr) { + idx = (reinterpret_cast(pointer) - (key->data.raw)) / + sizeof(int32_t); + } + + if (idx >= num_rows || idx < 0) { + if (output->type == kTfLiteString) { + buf.AddString(nullptr, 0); + } else { + memset(output->data.raw + i * row_bytes, 0, row_bytes); + } + hits->data.uint8[i] = 0; + } else 
{ + if (output->type == kTfLiteString) { + buf.AddString(GetString(value, idx)); + } else { + memcpy(output->data.raw + i * row_bytes, + value->data.raw + idx * row_bytes, row_bytes); + } + hits->data.uint8[i] = 1; + } + } + if (output->type == kTfLiteString) { + buf.WriteToTensor(output); + } + + return kTfLiteOk; +} +} // namespace + +TfLiteRegistration* Register_HASHTABLE_LOOKUP() { + static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb6038f9009a3865661e7b4f075c3033166d0f91 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc @@ -0,0 +1,176 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Lookup op. 
+ +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class HashtableLookupOpModel : public SingleOpModel { + public: + HashtableLookupOpModel(std::initializer_list lookup_shape, + std::initializer_list key_shape, + std::initializer_list value_shape, + TensorType type) { + lookup_ = AddInput(TensorType_INT32); + key_ = AddInput(TensorType_INT32); + value_ = AddInput(type); + output_ = AddOutput(type); + hit_ = AddOutput(TensorType_UINT8); + SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0); + BuildInterpreter({lookup_shape, key_shape, value_shape}); + } + + void SetLookup(std::initializer_list data) { + PopulateTensor(lookup_, data); + } + + void SetHashtableKey(std::initializer_list data) { + PopulateTensor(key_, data); + } + + void SetHashtableValue(const std::vector& content) { + PopulateStringTensor(value_, content); + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + for (int i = 0; i < rows; i++) { + tensor->data.f[i] = function(i); + } + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + int features = tensor->dims->data[1]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < features; j++) { + tensor->data.f[i * features + j] = function(i, j); + } + } + } + + std::vector GetStringOutput() { + TfLiteTensor* output = interpreter_->tensor(output_); + int num = GetStringCount(output); + std::vector result(num); + for (int i = 0; i < num; i++) { + auto ref = GetString(output, i); + result[i] = string(ref.str, 
ref.len); + } + return result; + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetHit() { return ExtractVector(hit_); } + + private: + int lookup_; + int key_; + int value_; + int output_; + int hit_; +}; + +// TODO(yichengfan): write more tests that exercise the details of the op, +// such as lookup errors and variable input shapes. +TEST(HashtableLookupOpTest, Test2DInput) { + HashtableLookupOpModel m({4}, {3}, {3, 2}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 2.0, 2.1, // 2-nd item + 0, 0, // Not found + 0.0, 0.1, // 0-th item + 1.0, 1.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, 0, 1, 1, + })); +} + +TEST(HashtableLookupOpTest, Test1DInput) { + HashtableLookupOpModel m({4}, {3}, {3}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i) { return i * i / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.4, // 2-nd item + 0, // Not found + 0.0, // 0-th item + 0.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} + +TEST(HashtableLookupOpTest, TestString) { + HashtableLookupOpModel m({4}, {3}, {3}, TensorType_STRING); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue({"Hello", "", "Hi"}); + + m.Invoke(); + + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({ + "Hi", // 2-nd item + "", // Not found + "Hello", // 0-th item + "", // 1-st item + })); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); 
+} diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..288534099b9e090ce0c223a401b4152ca6ffb61f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -0,0 +1,359 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +tflite_deps_intel = [ + "@arm_neon_2_x86_sse", +] + +NEON_FLAGS_IF_APPLICABLE = select({ + ":arm": [ + "-O3", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + ":armeabi-v7a": [ + "-O3", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + ":armv7a": [ + "-O3", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + "//conditions:default": [ + "-O3", + ], +}) + +cc_library( + name = "types", + srcs = [], + hdrs = [ + "compatibility.h", + "types.h", + ], +) + +config_setting( + name = "arm", + values = { + "cpu": "arm", + }, +) + +config_setting( + name = "arm64-v8a", + values = { + "cpu": "arm64-v8a", + }, +) + +config_setting( + name = "armv7a", + values = { + "cpu": "armv7a", + }, +) + +config_setting( + name = "armeabi-v7a", + values = { + "cpu": "armeabi-v7a", + }, +) + +config_setting( + name = "haswell", + values = { + "cpu": "haswell", + }, +) + +config_setting( + name = "ios_x86_64", + values = { + "cpu": "ios_x86_64", + }, +) + +config_setting( + name = "ios_armv7", + values = { + "cpu": "ios_armv7", + }, +) + +config_setting( + name = "ios_arm64", + values = { + "cpu": "ios_arm64", + }, +) + +config_setting( + name = "k8", + values = { + "cpu": "k8", + }, +) + +config_setting( + name = "x86", + values = { + "cpu": "x86", + }, +) + +config_setting( + name = "x86_64", + values = { + "cpu": "x86_64", + }, +) + +config_setting( + name = "darwin", + values = { + "cpu": "darwin", + }, +) + +cc_library( + name = "optimized_base", + srcs = [], + hdrs = [ + "common.h", + "optimized/depthwiseconv_float.h", + 
"optimized/depthwiseconv_uint8.h", + "optimized/optimized_ops.h", + ], + copts = tflite_copts(), + deps = [ + ":types", + ":round", + "//third_party/eigen3", + "@gemmlowp//:gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + "//conditions:default": [], + }), +) + +cc_library( + name = "optimized", + hdrs = [ + "optimized/eigen_spatial_convolutions.h", + "optimized/eigen_tensor_reduced_instantiations_oss.h", + "optimized/multithreaded_conv.h", + "tensor.h", + ], + deps = [ + ":optimized_base", + ":types", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:context", + "//third_party/eigen3", + ], +) + +cc_test( + name = "tensor_test", + srcs = ["tensor_test.cc"], + deps = [ + ":reference", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "round", + srcs = [], + hdrs = ["round.h"], +) + +cc_library( + name = "quantization_util", + srcs = ["quantization_util.cc"], + hdrs = [ + "compatibility.h", + "quantization_util.h", + ], + deps = [":round"], +) + +cc_test( + name = "quantization_util_test", + srcs = ["quantization_util_test.cc"], + deps = [ + ":quantization_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "reference_base", + srcs = [], + hdrs = [ + "common.h", + "reference/depthwiseconv_float.h", + "reference/depthwiseconv_uint8.h", + "reference/reference_ops.h", + ], + deps = [ + ":round", + ":types", + "//third_party/eigen3", + "@gemmlowp//:gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + "//conditions:default": [], + }), +) + +cc_library( + name = "reference", + hdrs = 
["tensor.h"], + deps = [ + ":types", + "//tensorflow/contrib/lite:context", + ], +) + +cc_library( + name = "portable_tensor_utils", + srcs = [ + "reference/portable_tensor_utils.cc", + ], + hdrs = [ + "reference/portable_tensor_utils.h", + ], + deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/kernels:activation_functor", + "//tensorflow/contrib/lite/kernels:op_macros", + ], +) + +cc_library( + name = "neon_tensor_utils", + srcs = [ + "optimized/neon_tensor_utils.cc", + ], + hdrs = [ + "optimized/neon_tensor_utils.h", + "optimized/tensor_utils_impl.h", + ], + copts = NEON_FLAGS_IF_APPLICABLE, + deps = [ + ":cpu_check", + ":portable_tensor_utils", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/kernels:activation_functor", + ], +) + +cc_library( + name = "tensor_utils", + srcs = [ + "tensor_utils.cc", + ], + hdrs = [ + "optimized/tensor_utils_impl.h", + "reference/portable_tensor_utils.h", + "tensor_utils.h", + ], + copts = NEON_FLAGS_IF_APPLICABLE, + deps = [ + "//tensorflow/contrib/lite/kernels:activation_functor", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":arm": [ + ":neon_tensor_utils", + ], + ":arm64-v8a": [ + ":neon_tensor_utils", + ], + ":armeabi-v7a": [ + ":neon_tensor_utils", + ], + ":armv7a": [ + ":neon_tensor_utils", + ], + ":ios_armv7": [ + ":neon_tensor_utils", + ], + ":ios_arm64": [ + ":neon_tensor_utils", + ], + "//conditions:default": [ + ":portable_tensor_utils", + ], + }), +) + +cc_test( + name = "tensor_utils_test", + srcs = ["tensor_utils_test.cc"], + copts = NEON_FLAGS_IF_APPLICABLE, + linkopts = select({ + "//tensorflow:android": [ + "-fPIE -pie", + ], + "//conditions:default": [], + }), + linkstatic = 1, + deps = [ + ":tensor_utils", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "cpu_check", + hdrs = [ + "optimized/cpu_check.h", + ], + 
deps = [ + ] + select( + { + "//tensorflow:android": [ + "@androidndk//:cpufeatures", + ], + "//conditions:default": [], + }, + ), +) + +exports_files(["optimized/eigen_tensor_reduced_instantiations_oss.h"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h new file mode 100644 index 0000000000000000000000000000000000000000..28f19a250629aec4d03aa71df57d31d8a5014e9f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ + +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif +#endif + +#ifndef USE_NEON +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#include <arm_neon.h> +#endif + +#if defined __GNUC__ && defined __SSE4_1__ +#define USE_NEON + +#define OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#pragma GCC diagnostic ignored "-Wattributes" + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnarrowing" +#pragma GCC diagnostic ignored "-Wsequence-point" + +#include "NEON_2_SSE.h" + +#pragma GCC diagnostic pop +#endif +#endif + +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +// Writes the float clamp range [min, max] for fused activation `ac`; any other enum value leaves both outputs untouched. +inline void GetActivationMinMax(FusedActivationFunctionType ac, + float* output_activation_min, + float* output_activation_max) { + switch (ac) { + case FusedActivationFunctionType::kNone: + *output_activation_min = std::numeric_limits<float>::lowest(); + *output_activation_max = std::numeric_limits<float>::max(); + break; + case FusedActivationFunctionType::kRelu: + *output_activation_min = 0.f; + *output_activation_max = std::numeric_limits<float>::max(); + break; + case FusedActivationFunctionType::kRelu1: + *output_activation_min = -1.f; + *output_activation_max = 1.f; + break; + case FusedActivationFunctionType::kRelu6: + *output_activation_min = 0.f; + *output_activation_max = 6.f; + break; + } +} +// Clamps x to the range [output_activation_min, output_activation_max]. +inline float ActivationFunctionWithMinMax(float x, float output_activation_min, + float output_activation_max) { + return std::min(std::max(x, output_activation_min), output_activation_max); +} + +// Legacy function, left for 
compatibility only. +template <FusedActivationFunctionType Ac> +float ActivationFunction(float x) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + return ActivationFunctionWithMinMax(x, output_activation_min, + output_activation_max); +} +// Applies gemmlowp's SaturatingRoundingDoublingHighMul to (x, quantized_multiplier), then RoundingDivideByPOT by right_shift. +inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( + int32 x, int32 quantized_multiplier, int right_shift) { + using gemmlowp::RoundingDivideByPOT; + using gemmlowp::SaturatingRoundingDoublingHighMul; + return RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); +} +// Pre-multiplies x by 2^left_shift, then applies gemmlowp's SaturatingRoundingDoublingHighMul with quantized_multiplier. +inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( + int32 x, int32 quantized_multiplier, int left_shift) { + using gemmlowp::SaturatingRoundingDoublingHighMul; + return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), + quantized_multiplier); +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h new file mode 100644 index 0000000000000000000000000000000000000000..796a03566a4bf971294dd2375f590dfd20d600f7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ + +#include <cassert> +#include <cstdint> +#include <cstdlib> +// TFLITE_DCHECK*: call assert(false) when the condition fails (assert is a no-op under NDEBUG). Expansions are parenthesized. +#ifndef TFLITE_DCHECK +#define TFLITE_DCHECK(condition) ((condition) ? (void)0 : assert(false)) +#endif + +#ifndef TFLITE_DCHECK_EQ +#define TFLITE_DCHECK_EQ(x, y) (((x) == (y)) ? (void)0 : assert(false)) +#endif + +#ifndef TFLITE_DCHECK_GE +#define TFLITE_DCHECK_GE(x, y) (((x) >= (y)) ? (void)0 : assert(false)) +#endif + +#ifndef TFLITE_DCHECK_GT +#define TFLITE_DCHECK_GT(x, y) (((x) > (y)) ? (void)0 : assert(false)) +#endif + +#ifndef TFLITE_DCHECK_LE +#define TFLITE_DCHECK_LE(x, y) (((x) <= (y)) ? (void)0 : assert(false)) +#endif + +#ifndef TFLITE_DCHECK_LT +#define TFLITE_DCHECK_LT(x, y) (((x) < (y)) ? (void)0 : assert(false)) +#endif + +// TODO(ahentz): Clean up: We should stick to the DCHECK versions. The CHECK variants below call abort() on failure. +#ifndef TFLITE_CHECK +#define TFLITE_CHECK(condition) ((condition) ? (void)0 : abort()) +#endif + +#ifndef TFLITE_CHECK_EQ +#define TFLITE_CHECK_EQ(x, y) (((x) == (y)) ? (void)0 : abort()) +#endif + +#ifndef TFLITE_CHECK_GE +#define TFLITE_CHECK_GE(x, y) (((x) >= (y)) ? (void)0 : abort()) +#endif + +#ifndef TFLITE_CHECK_GT +#define TFLITE_CHECK_GT(x, y) (((x) > (y)) ? (void)0 : abort()) +#endif + +#ifndef TFLITE_CHECK_LE +#define TFLITE_CHECK_LE(x, y) (((x) <= (y)) ? (void)0 : abort()) +#endif + +#ifndef TFLITE_CHECK_LT +#define TFLITE_CHECK_LT(x, y) (((x) < (y)) ? (void)0 : abort()) +#endif + +// TODO(ahentz): Clean up. 
+using uint8 = std::uint8_t; +using int16 = std::int16_t; +using uint16 = std::uint16_t; +using int32 = std::int32_t; +using uint32 = std::uint32_t; + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h new file mode 100644 index 0000000000000000000000000000000000000000..dea46cc12065ed34cf681916a46a55bd7a86f463 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_ + +namespace tflite { + +#ifdef __ANDROID__ +#include "ndk/sources/android/cpufeatures/cpu-features.h" + +// Runtime check for Neon support on Android. +inline bool TestCPUFeatureNeon() { +#ifdef __aarch64__ + // ARM-64 always has NEON support. 
+ return true; +#else + static bool kUseAndroidNeon = + (android_getCpuFamily() == ANDROID_CPU_FAMILY_ARM && + android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_ARMv7 && + android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_NEON); + return kUseAndroidNeon; +#endif // __aarch64__ +} + +#elif __ARM_NEON +// Non-Android build with __ARM_NEON set: Neon is reported as available without a runtime probe. +inline bool TestCPUFeatureNeon() { + return true; +} + +#else +// Neither Android nor __ARM_NEON: Neon is reported as unavailable. +inline bool TestCPUFeatureNeon() { + return false; +} + +#endif + +} // namespace tflite + +// NEON_OR_PORTABLE(SomeFunc, args) calls NeonSomeFunc(args) if Neon is both +// enabled at build time and detected at runtime, or PortableSomeFunc(args) +// otherwise. The expansion is parenthesized so it composes inside expressions. +#ifdef __ARM_ARCH_5TE__ +// Neon isn't available at all on ARMv5. +#define NEON_OR_PORTABLE(funcname, ...) Portable##funcname(__VA_ARGS__) +#else +#define NEON_OR_PORTABLE(funcname, ...) \ + (TestCPUFeatureNeon() ? Neon##funcname(__VA_ARGS__) \ + : Portable##funcname(__VA_ARGS__)) +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h new file mode 100644 index 0000000000000000000000000000000000000000..da34c8aef94b1c69e661bd33fcb518e73034c4bd --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -0,0 +1,1060 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ + +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Implementation of float DepthwiseConv + +template +struct FloatDepthwiseConvKernel {}; + +#ifdef USE_NEON + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the inputs + float32x4_t input[4]; + for (int i = 0; i < 4; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 16; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlaq_f32(acc[0], input[0], filter[0]); + acc[1] = vmlaq_f32(acc[1], input[1], filter[1]); + acc[2] = vmlaq_f32(acc[2], input[2], filter[0]); + acc[3] = vmlaq_f32(acc[3], input[3], filter[1]); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. 
+ for (; outp < num_output_pixels; outp++) { + // Load the inputs + float32x4_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 8; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filter[i]); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + const float32x2_t filters = vld1_f32(filter_ptr); + const float32x4_t filters_dup2 = vcombine_f32(filters, filters); + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the inputs + float32x4_t input[4]; + for (int i = 0; i < 4; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 16; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 4; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 4 output pixels at a time. 
+ for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the inputs + float32x4_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 8; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the inputs + const float32x4_t input = vld1q_f32(input_ptr); + input_ptr += 4; + // Load the accumulators from acc_buffer + float32x4_t acc = vld1q_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filters_dup2); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle 1 output pixel at a time + for (; outp < num_output_pixels; outp++) { + // Load the inputs + const float32x2_t input = vld1_f32(input_ptr); + input_ptr += 2; + // Load the accumulators from acc_buffer + float32x2_t acc = vld1_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmla_f32(acc, input, filters); + // Store the accumulators back to acc_buffer + vst1_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + int ic = 0; + // Handle 16 input channels at a time. 
+ for (; ic <= input_depth - 16; ic += 16) { + // Load the filters + float32x4_t filter_0 = vld1q_f32(local_filter_ptr + 4 * 0); + float32x4_t filter_1 = vld1q_f32(local_filter_ptr + 4 * 1); + float32x4_t filter_2 = vld1q_f32(local_filter_ptr + 4 * 2); + float32x4_t filter_3 = vld1q_f32(local_filter_ptr + 4 * 3); + local_filter_ptr += 16; + // Load the inputs + float32x4_t input_0 = vld1q_f32(local_input_ptr + 4 * 0); + float32x4_t input_1 = vld1q_f32(local_input_ptr + 4 * 1); + float32x4_t input_2 = vld1q_f32(local_input_ptr + 4 * 2); + float32x4_t input_3 = vld1q_f32(local_input_ptr + 4 * 3); + local_input_ptr += 16; + // Load the accumulators from acc_buffer + float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0); + float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1); + float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2); + float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3); + // Multiply-accumulate + acc_0 = vmlaq_f32(acc_0, input_0, filter_0); + acc_1 = vmlaq_f32(acc_1, input_1, filter_1); + acc_2 = vmlaq_f32(acc_2, input_2, filter_2); + acc_3 = vmlaq_f32(acc_3, input_3, filter_3); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3); + acc_buffer_ptr += 16; + } + // Handle 4 input channels at a time. + for (; ic <= input_depth - 4; ic += 4) { + // Load the filters + float32x4_t filter; + filter = vld1q_f32(local_filter_ptr); + local_filter_ptr += 4; + // Load the inputs + float32x4_t input; + input = vld1q_f32(local_input_ptr); + local_input_ptr += 4; + // Load the accumulators from acc_buffer + float32x4_t acc; + acc = vld1q_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filter); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle one input channel at a time. 
+ for (; ic < input_depth; ic++) { + const float input_val = *local_input_ptr++; + const float filter_val = *local_filter_ptr++; + *acc_buffer_ptr++ += filter_val * input_val; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + int ic = 0; + // Handle 2 input channels at a time. + for (; ic <= input_depth - 2; ic += 2) { + // Load the filters + float32x4_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 16; + // Load the inputs + const float32x2_t input = vld1_f32(local_input_ptr); + local_input_ptr += 2; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlaq_lane_f32(acc[0], filter[0], input, 0); + acc[1] = vmlaq_lane_f32(acc[1], filter[1], input, 0); + acc[2] = vmlaq_lane_f32(acc[2], filter[2], input, 1); + acc[3] = vmlaq_lane_f32(acc[3], filter[3], input, 1); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one input channel at a time. 
+ for (; ic < input_depth; ic++) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 8; + // Load the inputs + const float input_val = *local_input_ptr++; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + input_ptr += input_ptr_increment; + } + } +}; + +// Note: this implementation is very slow for input_depth < 8 +// (e.g. comparable to the reference implementation); see the specializations +// for input_depth=3 below. +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + int ic = 0; + // Handle 8 input channels at a time. 
+ for (; ic <= input_depth - 8; ic += 8) { + // Load the filters + float32x4_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 16; + // Load the inputs + float32x4x2_t input_dup2[2]; + for (int i = 0; i < 2; i++) { + const float32x4_t input = vld1q_f32(local_input_ptr + 4 * i); + input_dup2[i] = vzipq_f32(input, input); + } + local_input_ptr += 8; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlaq_f32(acc[0], filter[0], input_dup2[0].val[0]); + acc[1] = vmlaq_f32(acc[1], filter[1], input_dup2[0].val[1]); + acc[2] = vmlaq_f32(acc[2], filter[2], input_dup2[1].val[0]); + acc[3] = vmlaq_f32(acc[3], filter[3], input_dup2[1].val[1]); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 4 input channels at a time. + for (; ic <= input_depth - 4; ic += 4) { + // Load the filters + float32x2_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1_f32(local_filter_ptr + 2 * i); + } + local_filter_ptr += 8; + // Load the inputs + const float32x4_t input = vld1q_f32(local_input_ptr); + local_input_ptr += 4; + // Load the accumulators from acc_buffer + float32x2_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); + } + // Multiply-accumulate + acc[0] = vmla_lane_f32(acc[0], filter[0], vget_low_f32(input), 0); + acc[1] = vmla_lane_f32(acc[1], filter[1], vget_low_f32(input), 1); + acc[2] = vmla_lane_f32(acc[2], filter[2], vget_high_f32(input), 0); + acc[3] = vmla_lane_f32(acc[3], filter[3], vget_high_f32(input), 1); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle 2 input channels at a time. 
+ for (; ic <= input_depth - 2; ic += 2) { + // Load the filters + const float32x4_t filter = vld1q_f32(local_filter_ptr); + local_filter_ptr += 4; + // Load the inputs + const float32x2_t input = vld1_f32(local_input_ptr); + local_input_ptr += 2; + // Load the accumulators from acc_buffer + float32x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); + } + // Multiply-accumulate + acc[0] = vmla_lane_f32(acc[0], vget_low_f32(filter), input, 0); + acc[1] = vmla_lane_f32(acc[1], vget_high_f32(filter), input, 1); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); + } + acc_buffer_ptr += 4; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + // Load the inputs + const float input_val = *local_input_ptr++; + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc_buffer_ptr[i] += local_filter_ptr[i] * input_val; + } + local_filter_ptr += 2; + acc_buffer_ptr += 2; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x2_t filter[3]; + for (int i = 0; i < 3; i++) { + filter[i] = vld1_f32(filter_ptr + 2 * i); + } + // Handle one output pixel at a time. 
+ for (int outp = 0; outp < num_output_pixels; outp++) { + const float32x2_t input01 = vld1_f32(input_ptr); + const float32x2_t input2 = vld1_dup_f32(input_ptr + 2); + // Load the accumulators from acc_buffer + float32x2_t acc[3]; + for (int i = 0; i < 3; i++) { + acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); + } + // Multiply-accumulate: for each input channel there are 2 outputs + acc[0] = vmla_lane_f32(acc[0], filter[0], input01, 0); + acc[1] = vmla_lane_f32(acc[1], filter[1], input01, 1); + acc[2] = vmla_lane_f32(acc[2], filter[2], input2, 0); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 3; i++) { + vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); + } + acc_buffer_ptr += 6; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[3]; + for (int i = 0; i < 3; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // NOTE: we only want 3 values, so we read it as two ops where + // the second op just duplicates the lane + const float32x2_t input01 = vld1_f32(input_ptr); + const float32x2_t input2 = vld1_dup_f32(input_ptr + 2); + // Load the accumulators from acc_buffer + float32x4_t acc[3]; + for (int i = 0; i < 3; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate all outputs. 
+ acc[0] = vmlaq_lane_f32(acc[0], filter[0], input01, 0); + acc[1] = vmlaq_lane_f32(acc[1], filter[1], input01, 1); + acc[2] = vmlaq_lane_f32(acc[2], filter[2], input2, 0); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 3; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 12; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + const float input_val = *input_ptr; + input_ptr += input_ptr_increment; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter_0 = vld1q_f32(filter_ptr + 4 * 0); + float32x4_t filter_1 = vld1q_f32(filter_ptr + 4 * 1); + float32x4_t filter_2 = vld1q_f32(filter_ptr + 4 * 2); + float32x4_t filter_3 = vld1q_f32(filter_ptr + 4 * 3); + float32x4_t filter_4 = vld1q_f32(filter_ptr + 4 * 4); + float32x4_t filter_5 = vld1q_f32(filter_ptr + 4 * 5); + float32x4_t filter_6 = vld1q_f32(filter_ptr + 4 * 6); + 
float32x4_t filter_7 = vld1q_f32(filter_ptr + 4 * 7); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + const float input_val = *input_ptr; + input_ptr += input_ptr_increment; + // Load the accumulators from acc_buffer + float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0); + float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1); + float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2); + float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3); + float32x4_t acc_4 = vld1q_f32(acc_buffer_ptr + 4 * 4); + float32x4_t acc_5 = vld1q_f32(acc_buffer_ptr + 4 * 5); + float32x4_t acc_6 = vld1q_f32(acc_buffer_ptr + 4 * 6); + float32x4_t acc_7 = vld1q_f32(acc_buffer_ptr + 4 * 7); + // Multiply-accumulate + acc_0 = vmlaq_n_f32(acc_0, filter_0, input_val); + acc_1 = vmlaq_n_f32(acc_1, filter_1, input_val); + acc_2 = vmlaq_n_f32(acc_2, filter_2, input_val); + acc_3 = vmlaq_n_f32(acc_3, filter_3, input_val); + acc_4 = vmlaq_n_f32(acc_4, filter_4, input_val); + acc_5 = vmlaq_n_f32(acc_5, filter_5, input_val); + acc_6 = vmlaq_n_f32(acc_6, filter_6, input_val); + acc_7 = vmlaq_n_f32(acc_7, filter_7, input_val); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_f32(acc_buffer_ptr + 4 * 4, acc_4); + vst1q_f32(acc_buffer_ptr + 4 * 5, acc_5); + vst1q_f32(acc_buffer_ptr + 4 * 6, acc_6); + vst1q_f32(acc_buffer_ptr + 4 * 7, acc_7); + acc_buffer_ptr += 32; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. 
+ for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + for (int ic = 0; ic < input_depth; ic++) { + // Load the filters + float32x4_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 16; + // Load the inputs + const float input_val = *local_input_ptr++; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 4; i++) { + acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + // Handle one output pixel at a time. 
+ for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + float32x4_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filter[i]); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + float32x2_t filter = vld1_f32(filter_ptr); + float32x4_t filter_x4 = vcombine_f32(filter, filter); + int outp = 0; + + // Handle two output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the inputs + float32x2_t input_1 = vld1_f32(input_ptr); + input_ptr += input_ptr_increment; + float32x2_t input_2 = vld1_f32(input_ptr); + input_ptr += input_ptr_increment; + float32x4_t input = vcombine_f32(input_1, input_2); + + // Load the accumulators from acc_buffer + float32x4_t acc = vld1q_f32(acc_buffer_ptr); + + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filter_x4); + + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle one output pixel at a time. 
+ for (; outp < num_output_pixels; outp++) { + // Load the inputs + float32x2_t input = vld1_f32(input_ptr); + input_ptr += input_ptr_increment; + + // Load the accumulators from acc_buffer + float32x2_t acc = vld1_f32(acc_buffer_ptr); + + // Multiply-accumulate + acc = vmla_f32(acc, input, filter); + + // Store the accumulators back to acc_buffer + vst1_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + float32x4_t filter = vld1q_f32(filter_ptr); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + float32x4_t input = vld1q_f32(input_ptr); + // Load the accumulators from acc_buffer + float32x4_t acc = vld1q_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filter); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + input_ptr += input_ptr_increment; + } + } +}; +#endif + +// Accumulates the effect of one row of the filter, on a segment of one row +// of the output, accessing the corresponding one row of the input. +template +void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width, + const float* input_data, int pad_width, + int depth_multiplier, int filter_width, + const float* filter_data, + int out_x_buffer_start, int out_x_buffer_end, + int output_depth, float* acc_buffer) { +#ifdef GEMMLOWP_PROFILING + gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__); +#endif + // Sanity check parameters. This is important in particular to ensure + // that we keep the number of template instantiations minimal, so we don't + // increase binary size unnecessarily. 
+ static_assert(kFixedDepthMultiplier || !kFixedInputDepth, ""); + static_assert(kFixedInputDepth || kAllowStrided, ""); + TFLITE_DCHECK(stride == 1 || kAllowStrided); + if (kFixedInputDepth) { + TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth); + } + if (kFixedDepthMultiplier) { + TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier); + } + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + const int input_ptr_increment = stride * input_depth; + const float* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + // For the current (filter_x, filter_y) point in the filter, + // compute the boundaries of the corresponding output row segment. + int out_x_loop_start_unclampled = 0; + int out_x_loop_end_unclampled = 0; + if (kAllowStrided) { + if (stride == 2) { + out_x_loop_start_unclampled = (pad_width - filter_x + 1) / 2; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 1) / 2; + } else if (stride == 4) { + out_x_loop_start_unclampled = (pad_width - filter_x + 3) / 4; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 3) / 4; + } else { + out_x_loop_start_unclampled = + (pad_width - filter_x + stride - 1) / stride; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + stride - 1) / stride; + } + } else { + out_x_loop_start_unclampled = pad_width - filter_x; + out_x_loop_end_unclampled = pad_width + input_width - filter_x; + } + // The kernel will have to iterate on the segment of the + // output row that starts at out_x_loop_start and out_x_loop_end. 
+ const int out_x_loop_start = + std::max(out_x_buffer_start, out_x_loop_start_unclampled); + const int out_x_loop_end = + std::min(out_x_buffer_end, out_x_loop_end_unclampled); + + float* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const float* input_ptr = input_data + in_x_origin * input_depth; + const int num_output_pixels = out_x_loop_end - out_x_loop_start; + FloatDepthwiseConvKernel::Run(num_output_pixels, + input_depth, + depth_multiplier, + input_ptr, + input_ptr_increment, + filter_base_ptr, + acc_buffer_ptr); + filter_base_ptr += output_depth; + } +} + +// generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized. +inline void FloatDepthwiseConvAccumRowGeneric( + int stride, int input_depth, int input_width, const float* input_data, + int pad_width, int depth_multiplier, int filter_width, + const float* filter_data, int out_x_buffer_start, int out_x_buffer_end, + int output_depth, float* acc_buffer) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)"); +#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + LOG(FATAL) + << "\n\n" + << "*****************************************************************\n" + << "* This tfmini inference code was about to use the slow generic\n" + << "* fallback implementation for a DepthwiseConv op, and we want you\n" + << "* to be aware of that so that you will know why you get terrible\n" + << "* performance.\n" + << "*\n" + << "* If you would like to carry on with the slow code, compile\n" + << "* with this preprocessor token defined:\n" + << "* ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n" + << "*\n" + << "* The right thing to do, if you care about performance, is to add\n" + << "* a new DepthwiseConv kernel to tfmini to cover your case.\n" + << "* The relevant parameters defining your case are:\n" + << 
"* stride = " << stride << "\n" + << "* input_depth = " << input_depth << "\n" + << "* depth_multiplier = " << depth_multiplier << "\n" + << "*\n" + << "* Please do not hesitate to contact benoitjacob@ with this\n" + << "* information.\n" + << "*****************************************************************\n"; +#endif // ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif // TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + const float* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int out_x_loop_start = std::max( + out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride); + const int out_x_loop_end = + std::min(out_x_buffer_end, + (pad_width + input_width - filter_x + stride - 1) / stride); + + float* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const float* input_ptr = input_data + in_x_origin * input_depth; + const int input_ptr_increment = (stride - 1) * input_depth; + for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) { + const float* filter_ptr = filter_base_ptr; + for (int ic = 0; ic < input_depth; ++ic) { + const float input_val = *input_ptr++; + for (int m = 0; m < depth_multiplier; m++) { + const float filter_val = *filter_ptr++; + *acc_buffer_ptr++ += filter_val * input_val; + } + } + input_ptr += input_ptr_increment; + } + filter_base_ptr += output_depth; + } +} + +// Initializes the accumulator buffer with bias values. +inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, + const float* bias_data, + float* acc_buffer) { + // TODO(benoitjacob): This might need optimized specializations + // for small output_depth values, if that ever becomes an important + // case (like it was for some quantized DepthwiseConv cases). 
+ for (int i = 0; i < num_output_pixels; i++) { + memcpy(acc_buffer + i * output_depth, bias_data, + sizeof(acc_buffer[0]) * output_depth); + } +} + +inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConv"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + static const int kAccBufferMaxSize = 2048; + float acc_buffer[kAccBufferMaxSize]; + TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth); + const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth; + const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth; + TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth, + kAccBufferActualSize); + TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize); + TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1); + + // row_accum_func will point to the core accumulation function to be used + // for this DepthwiseConv op. 
+ using row_accum_func_t = decltype(&FloatDepthwiseConvAccumRowGeneric); + row_accum_func_t row_accum_func = nullptr; + +#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \ + FIXED_DEPTH_MULTIPLIER) \ + if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \ + (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \ + depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \ + row_accum_func = \ + FloatDepthwiseConvAccumRow; \ + } + +#ifdef USE_NEON + // We go over our list of kernels by decreasing order of preference + // for the cases where multiple kernels could apply. + + // Start with the fastest kernels: AllowStrided=false, fixed input depth. + + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1) + + // Next come the strided kernels: AllowStrided=true, fixed input depth. + // They are a bit less efficient, but allow stride!=1. + + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 4) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1) + + // Finally, the kernels allowing a variable input depth, + // these are the least efficient but most general kernels. + + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 16) + +#endif // USE_NEON + +#undef TFMINI_USE_DEPTHWISECONV_KERNEL + + // No matching fast kernel found, use slow fallback. + if (!row_accum_func) { + row_accum_func = FloatDepthwiseConvAccumRowGeneric; + } + + // Now that we have determined row_accum_func, we can start work. 
+ float* output_ptr = output_data; + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + for (int out_x_buffer_start = 0; out_x_buffer_start < output_width; + out_x_buffer_start += kOutputPixelsInAccBuffer) { + const int out_x_buffer_end = std::min( + output_width, out_x_buffer_start + kOutputPixelsInAccBuffer); + // We call a 'pixel' a group of activation that share all but the + // 'depth'/'channel' coordinate. num_output_pixels is the number of + // output pixels that we will accumulate in this loop iteration. + const int num_output_pixels = out_x_buffer_end - out_x_buffer_start; + // Initialize our local accumulator with the bias values, so we don't + // have to add them later. + DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data, + acc_buffer); + // Accumulation loop. Most of the time should be spent in here. + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + const int in_y = in_y_origin + filter_y; + row_accum_func(stride_width, input_depth, input_width, + input_data + in_y * input_dims.strides[2] + + b * input_dims.strides[3], + pad_width, depth_multiplier, filter_width, + filter_data + filter_y * filter_dims.strides[2], + out_x_buffer_start, out_x_buffer_end, output_depth, + acc_buffer); + } + // Finished accumulating. Now store to destination. 
+ const int num_output_values = output_depth * num_output_pixels; + int i = 0; +// TODO(benoitjacob) optimized code goes here +#ifdef USE_NEON + // Handle 16 values at a time + for (; i <= num_output_values - 16; i += 16) { + float32x4_t acc[4]; + for (int k = 0; k < 4; k++) { + acc[k] = vld1q_f32(acc_buffer + i + 4 * k); + } + for (int k = 0; k < 4; k++) { + acc[k] = vmaxq_f32( + vdupq_n_f32(output_activation_min), + vminq_f32(vdupq_n_f32(output_activation_max), acc[k])); + } + for (int k = 0; k < 4; k++) { + vst1q_f32(output_ptr + 4 * k, acc[k]); + } + output_ptr += 16; + } + // Handle 4 values at a time + for (; i <= num_output_values - 4; i += 4) { + float32x4_t acc = vld1q_f32(acc_buffer + i); + + acc = vmaxq_f32(vdupq_n_f32(output_activation_min), + vminq_f32(vdupq_n_f32(output_activation_max), acc)); + + vst1q_f32(output_ptr, acc); + output_ptr += 4; + } +#endif + // Handle leftover values, one by one. This is very slow. + for (; i < num_output_values; i++) { + float acc = acc_buffer[i]; + acc = std::max(output_activation_min, + std::min(output_activation_max, acc)); + + *output_ptr++ = acc; + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride_width, stride_height, pad_width, pad_height, + depth_multiplier, output_activation_min, output_activation_max, + output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void DepthwiseConv(const float* 
input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + float* output_data, const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, + depth_multiplier, output_data, output_dims); +} + +} // namespace optimized_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h new file mode 100644 index 0000000000000000000000000000000000000000..051ed2a2c44a04f0473dfd26637e53865a5a51ac --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -0,0 +1,1916 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ + +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Implementation of quantized DepthwiseConv + +template +struct QuantizedDepthwiseConvKernel {}; + +#ifdef USE_NEON +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8x2_t filter_u8; + filter_u8.val[0] = vld1_u8(filter_ptr); + filter_u8.val[1] = vld1_u8(filter_ptr + 8); + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])), + vdupq_n_s16(filter_offset)); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i); + acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8); + } + // Load the inputs, add input_offset. 
+ const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += input_ptr_increment; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[0].val[i] = vmlal_s16(acc[0].val[i], vget_low_s16(filter[i]), + vget_low_s16(input_dup2.val[i])); + acc[1].val[i] = vmlal_s16(acc[1].val[i], vget_high_s16(filter[i]), + vget_high_s16(input_dup2.val[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]); + vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. 
+ uint8x8_t input_u8[2]; + for (int i = 0; i < 2; i++) { + input_u8[i] = vld1_u8(input_ptr + 8 * i); + } + input_ptr += 16; + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i])); + } + for (int i = 0; i < 2; i++) { + input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset)); + } + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input[0])); + acc[1] = + vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input[0])); + acc[2] = vmlal_s16(acc[2], vget_low_s16(filter), vget_low_s16(input[1])); + acc[3] = + vmlal_s16(acc[3], vget_high_s16(filter), vget_high_s16(input[1])); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 1 output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x4_t acc[2]; + acc[0] = vld1q_s32(acc_buffer_ptr); + acc[1] = vld1q_s32(acc_buffer_ptr + 4); + + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input)); + acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input)); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc[0]); + vst1q_s32(acc_buffer_ptr + 4, acc[1]); + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. 
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(filter), + vget_low_s16(input_dup2.val[i])); + acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(filter), + vget_high_s16(input_dup2.val[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. 
+ uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x4x2_t input_dup2 = vzip_s16(input, input); + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), input_dup2.val[0]); + acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), input_dup2.val[1]); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + int outp = 0; + // Handle two output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc[8]; + for (int i = 0; i < 8; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. 
+ uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate. + acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1); + acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), input, 2); + acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), input, 2); + acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), input, 3); + acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), input, 3); + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 8; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 32; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += 2; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. 
+ acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1); + + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. 
+ const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0])); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0])); + acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1])); + acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1])); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += 2; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x4_t input_dup2 = vzip_s16(input, input).val[0]; + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input_dup2); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. 
+ uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the accumulators from acc_buffer. + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8[2]; + for (int i = 0; i < 2; i++) { + input_u8[i] = vld1_u8(input_ptr + 8 * i); + } + input_ptr += 16; + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i])); + } + for (int i = 0; i < 2; i++) { + input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset)); + } + + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input[0])); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input[0])); + acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input[1])); + acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input[1])); + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer. + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. 
+ const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input)); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input)); + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer. + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle 1 output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x2_t acc = vld1_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += 2; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. 
+ acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input)); + // Store the accumulators back to acc_buffer. + vst1_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. 
+ const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0])); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0])); + acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1])); + acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1])); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x2_t acc = vld1_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + const uint32 input = *input_ptr++ + input_offset; + + // Multiply-accumulate + acc = vget_low_s32(vmlal_n_s16(vcombine_s32(acc, acc), filter, input)); + // Store the accumulators back to acc_buffer + vst1_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. 
+ uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the accumulators from acc_buffer + int32x4_t acc[8]; + for (int i = 0; i < 8; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], filter, vget_low_s16(input), 0); + acc[1] = vmlal_lane_s16(acc[1], filter, vget_low_s16(input), 1); + acc[2] = vmlal_lane_s16(acc[2], filter, vget_low_s16(input), 2); + acc[3] = vmlal_lane_s16(acc[3], filter, vget_low_s16(input), 3); + acc[4] = vmlal_lane_s16(acc[4], filter, vget_high_s16(input), 0); + acc[5] = vmlal_lane_s16(acc[5], filter, vget_high_s16(input), 1); + acc[6] = vmlal_lane_s16(acc[6], filter, vget_high_s16(input), 2); + acc[7] = vmlal_lane_s16(acc[7], filter, vget_high_s16(input), 3); + + // Store the accumulators back to acc_buffer + for (int i = 0; i < 8; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 32; + } + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. 
+ uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], filter, input, 0); + acc[1] = vmlal_lane_s16(acc[1], filter, input, 1); + acc[2] = vmlal_lane_s16(acc[2], filter, input, 2); + acc[3] = vmlal_lane_s16(acc[3], filter, input, 3); + + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + const uint32 input = *input_ptr++ + input_offset; + + // Multiply-accumulate + acc = vmlal_n_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. 
+ uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + const uint8x8_t input_u8 = vld1_u8(input_ptr + 8 * i); + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + input[i] = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + } + input_ptr += 16; + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = + vmlal_s16(acc[2 * i + 0], filter, vget_low_s16(input[i])); + acc[2 * i + 1] = + vmlal_s16(acc[2 * i + 1], filter, vget_high_s16(input[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc; + acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. 
+ uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer + int32x4_t acc[8]; + for (int i = 0; i < 8; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. 
+ uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), + vget_low_s16(input), 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), + vget_low_s16(input), 1); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), + vget_low_s16(input), 2); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), + vget_low_s16(input), 3); + acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), + vget_high_s16(input), 0); + acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), + vget_high_s16(input), 1); + acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), + vget_high_s16(input), 2); + acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), + vget_high_s16(input), 3); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 8; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 32; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. 
+      uint8x8_t input_u8 = vdup_n_u8(0);
+      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
+      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
+      input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
+      input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
+      input_ptr += 4;
+      const int16x4_t input_s16 =
+          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
+      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
+
+      // Multiply-accumulate
+      acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
+      acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 1);
+      acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 2);
+      acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 3);
+      // Store the accumulators back to acc_buffer
+      for (int i = 0; i < 4; i++) {
+        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+      }
+      acc_buffer_ptr += 16;
+    }
+  }
+};
+
+// Kernel that produces three accumulator updates per input channel: for every
+// output pixel it walks input_depth channels, multiplying each input byte with
+// three consecutive filter bytes. Channels are processed eight at a time with
+// NEON, then one at a time in scalar code. The input pointer advances by
+// input_ptr_increment per output pixel; depth_multiplier is not read.
+// NOTE(review): the template arguments of this specialization are not visible
+// in this view of the file - confirm them against the primary template.
+template <>
+struct QuantizedDepthwiseConvKernel {
+  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+                  const uint8* input_ptr, int16 input_offset,
+                  int input_ptr_increment, const uint8* filter_ptr,
+                  int16 filter_offset, int32* acc_buffer_ptr) {
+    // We will have to duplicate bytes in a NEON register, 3-fold.
+    // We will do that by register-level table-look-up using VTBL instructions.
+    // Here we prepare the registers containing the table-lookup indices.
+    // Read row after row, the 24 indices are 0,0,0,1,1,1,...,7,7,7: every one
+    // of the eight source bytes is selected three times in a row.
+    static const uint8 dup3_indices_array[3][8] = {{0, 0, 0, 1, 1, 1, 2, 2},
+                                                   {2, 3, 3, 3, 4, 4, 4, 5},
+                                                   {5, 5, 6, 6, 6, 7, 7, 7}};
+    uint8x8_t dup3_indices[3];
+    for (int i = 0; i < 3; i++) {
+      dup3_indices[i] = vld1_u8(dup3_indices_array[i]);
+    }
+
+    // Handle one output pixel at a time.
+    for (int outp = 0; outp < num_output_pixels; outp++) {
+      // The filter is re-read from its start for every output pixel.
+      const uint8* local_filter_ptr = filter_ptr;
+      const uint8* local_input_ptr = input_ptr;
+      int ic = 0;
+      // Handle 8 input channels at a time.
+      for (; ic <= input_depth - 8; ic += 8) {
+        // Load the filters, add filter_offset.
+        // 8 channels x 3 values = 24 filter bytes, widened to 16 bits.
+        int16x8_t filter[3];
+        uint8x8x3_t filter_u8;
+        filter_u8.val[0] = vld1_u8(local_filter_ptr);
+        filter_u8.val[1] = vld1_u8(local_filter_ptr + 8);
+        filter_u8.val[2] = vld1_u8(local_filter_ptr + 16);
+        local_filter_ptr += 24;
+        for (int i = 0; i < 3; i++) {
+          const int16x8_t filter_s16 =
+              vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i]));
+          filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+        }
+        // Load the inputs, duplicate 3-fold, add input_offset.
+        const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
+        local_input_ptr += 8;
+
+        // Expand the 8 input bytes into 24 via the lookup indices, so that
+        // element k of the expansion lines up with filter byte k.
+        uint8x8_t input_u8_dup3[3];
+        for (int i = 0; i < 3; i++) {
+          input_u8_dup3[i] = vtbl1_u8(input_u8, dup3_indices[i]);
+        }
+        int16x8_t input_dup3[3];
+        for (int i = 0; i < 3; i++) {
+          const int16x8_t input_s16_dup3 =
+              vreinterpretq_s16_u16(vmovl_u8(input_u8_dup3[i]));
+          input_dup3[i] = vaddq_s16(input_s16_dup3, vdupq_n_s16(input_offset));
+        }
+        // Load the accumulators from acc_buffer
+        // acc[i].val[j] holds the four accumulators at offset 4 * i + 8 * j,
+        // i.e. the low (i == 0) or high (i == 1) half of the j-th group of 8.
+        int32x4x3_t acc[2];
+        for (int i = 0; i < 2; i++) {
+          acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
+          acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
+          acc[i].val[2] = vld1q_s32(acc_buffer_ptr + 4 * i + 16);
+        }
+        // Multiply-accumulate
+        for (int j = 0; j < 3; j++) {
+          acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(input_dup3[j]),
+                                    vget_low_s16(filter[j]));
+          acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(input_dup3[j]),
+                                    vget_high_s16(filter[j]));
+        }
+        // Store the accumulators back to acc_buffer
+        for (int i = 0; i < 2; i++) {
+          vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
+          vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
+          vst1q_s32(acc_buffer_ptr + 4 * i + 16, acc[i].val[2]);
+        }
+        acc_buffer_ptr += 24;
+      }
+      // Handle one input channel at a time.
+      // Scalar tail for the channels left over by the 8-wide loop.
+      for (; ic < input_depth; ic++) {
+        const int16 input_val = *local_input_ptr++ + input_offset;
+        for (int i = 0; i < 3; i++) {
+          const int16 filter_val = local_filter_ptr[i] + filter_offset;
+          // NOTE(review): the cast's target type is missing in this view of
+          // the file; presumably it widens to the accumulator type - confirm.
+          *acc_buffer_ptr++ += static_cast(filter_val) * input_val;
+        }
+        local_filter_ptr += 3;
+      }
+      input_ptr += input_ptr_increment;
+    }
+  }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel {
+  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+                  const uint8* input_ptr, int16 input_offset,
+                  int input_ptr_increment, const uint8* filter_ptr,
+                  int16 filter_offset, int32* acc_buffer_ptr) {
+    // Handle one output pixel at a time.
+    for (int outp = 0; outp < num_output_pixels; outp++) {
+      const uint8* local_filter_ptr = filter_ptr;
+      const uint8* local_input_ptr = input_ptr;
+      int ic = 0;
+      // Handle 8 input channels at a time.
+      for (; ic <= input_depth - 8; ic += 8) {
+        // Load the filters, add filter_offset.
+        int16x8_t filter[2];
+        uint8x8x2_t filter_u8;
+        filter_u8.val[0] = vld1_u8(local_filter_ptr);
+        filter_u8.val[1] = vld1_u8(local_filter_ptr + 8);
+        local_filter_ptr += 16;
+        for (int i = 0; i < 2; i++) {
+          const int16x8_t filter_s16 =
+              vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i]));
+          filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+        }
+        // Load the inputs, add input_offset, duplicate 2-fold.
+        const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
+        local_input_ptr += 8;
+        const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+        const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+        const int16x8x2_t input_dup2 = vzipq_s16(input, input);
+        // Load the accumulators from acc_buffer.
+        int32x4x2_t acc[2];
+        for (int i = 0; i < 2; i++) {
+          acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
+          acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
+        }
+        // Multiply-accumulate.
+        for (int j = 0; j < 2; j++) {
+          acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(filter[j]),
+                                    vget_low_s16(input_dup2.val[j]));
+          acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(filter[j]),
+                                    vget_high_s16(input_dup2.val[j]));
+        }
+        // Store the accumulators back to acc_buffer.
+        for (int i = 0; i < 2; i++) {
+          vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
+          vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
+        }
+        acc_buffer_ptr += 16;
+      }
+      // Handle one input channel at a time.
+      for (; ic < input_depth; ic++) {
+        // Load the inputs.
+        const int16 input_val = *local_input_ptr++ + input_offset;
+        for (int i = 0; i < 2; i++) {
+          const int16 filter_val = local_filter_ptr[i] + filter_offset;
+          *acc_buffer_ptr++ += static_cast(filter_val) * input_val;
+        }
+        local_filter_ptr += 2;
+      }
+      input_ptr += input_ptr_increment;
+    }
+  }
+};
+
+// Kernel that produces one accumulator update per input channel: for every
+// output pixel it walks input_depth channels and multiplies input byte k with
+// filter byte k. Channels are processed 16 at a time, then 8 at a time, then
+// one at a time in scalar code. The input pointer advances by
+// input_ptr_increment per output pixel; depth_multiplier is not read.
+// NOTE(review): the template arguments of this specialization are not visible
+// in this view of the file - confirm them against the primary template.
+template <>
+struct QuantizedDepthwiseConvKernel {
+  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+                  const uint8* input_ptr, int16 input_offset,
+                  int input_ptr_increment, const uint8* filter_ptr,
+                  int16 filter_offset, int32* acc_buffer_ptr) {
+    // Handle one output pixel at a time.
+    for (int outp = 0; outp < num_output_pixels; outp++) {
+      // The filter is re-read from its start for every output pixel.
+      const uint8* local_filter_ptr = filter_ptr;
+      const uint8* local_input_ptr = input_ptr;
+      int ic = 0;
+      // Handle 16 input channels at a time.
+      for (; ic <= input_depth - 16; ic += 16) {
+        // Load the filters, add filter_offset.
+        uint8x8_t filter_u8_0 = vld1_u8(local_filter_ptr + 8 * 0);
+        uint8x8_t filter_u8_1 = vld1_u8(local_filter_ptr + 8 * 1);
+        local_filter_ptr += 16;
+        int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
+        int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
+        filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
+        filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
+        // Load the inputs, add input_offset.
+        uint8x8_t input_u8_0 = vld1_u8(local_input_ptr + 8 * 0);
+        uint8x8_t input_u8_1 = vld1_u8(local_input_ptr + 8 * 1);
+        local_input_ptr += 16;
+        int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0));
+        int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1));
+        input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
+        input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
+        // Load the accumulators from acc_buffer
+        int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
+        int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
+        int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
+        int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
+        // Multiply-accumulate: four lanes per register, low half of each
+        // 8-lane input/filter pair first, then the high half.
+        acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), vget_low_s16(filter_0));
+        acc_1 =
+            vmlal_s16(acc_1, vget_high_s16(input_0), vget_high_s16(filter_0));
+        acc_2 = vmlal_s16(acc_2, vget_low_s16(input_1), vget_low_s16(filter_1));
+        acc_3 =
+            vmlal_s16(acc_3, vget_high_s16(input_1), vget_high_s16(filter_1));
+        // Store the accumulators back to acc_buffer
+        vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
+        vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
+        vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
+        vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
+        acc_buffer_ptr += 16;
+      }
+      // Handle 8 input channels at a time.
+      for (; ic <= input_depth - 8; ic += 8) {
+        // Load the filters, add filter_offset.
+        const uint8x8_t filter_u8 = vld1_u8(local_filter_ptr);
+        local_filter_ptr += 8;
+        const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
+        const int16x8_t filter =
+            vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
+        // Load the inputs, add input_offset.
+        const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
+        local_input_ptr += 8;
+        const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
+        const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
+        // Load the accumulators from acc_buffer
+        int32x4_t acc[2];
+        for (int i = 0; i < 2; i++) {
+          acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
+        }
+        // Multiply-accumulate
+        acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter));
+        acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter));
+        // Store the accumulators back to acc_buffer
+        for (int i = 0; i < 2; i++) {
+          vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
+        }
+        acc_buffer_ptr += 8;
+      }
+      // Handle one input channel at a time.
+      // Scalar tail for the channels left over by the vector loops.
+      for (; ic < input_depth; ic++) {
+        const int16 input_val = *local_input_ptr++ + input_offset;
+        const int16 filter_val = *local_filter_ptr++ + filter_offset;
+        // NOTE(review): the cast's target type is missing in this view of the
+        // file; presumably it widens to the accumulator type - confirm.
+        *acc_buffer_ptr++ += static_cast(filter_val) * input_val;
+      }
+      input_ptr += input_ptr_increment;
+    }
+  }
+};
+
+template <>
+struct QuantizedDepthwiseConvKernel {
+  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
+                  const uint8* input_ptr, int16 input_offset,
+                  int input_ptr_increment, const uint8* filter_ptr,
+                  int16 filter_offset, int32* acc_buffer_ptr) {
+    // Load the filters, add filter_offset.
+    uint8x8_t filter_u8[2];
+    for (int i = 0; i < 2; i++) {
+      filter_u8[i] = vld1_u8(filter_ptr + 8 * i);
+    }
+    int16x8_t filter[2];
+    for (int i = 0; i < 2; i++) {
+      filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i]));
+    }
+    for (int i = 0; i < 2; i++) {
+      filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset));
+    }
+    // Handle one output pixel at a time.
+    for (int outp = 0; outp < num_output_pixels; outp++) {
+      // Load the inputs, add input_offset.
+ uint8x8_t input_u8[2]; + for (int i = 0; i < 2; i++) { + input_u8[i] = vld1_u8(input_ptr + 8 * i); + } + input_ptr += input_ptr_increment; + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i])); + } + for (int i = 0; i < 2; i++) { + input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset)); + } + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(input[i]), + vget_low_s16(filter[i])); + acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(input[i]), + vget_high_s16(filter[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs, add input_offset. 
+ const uint8x8_t input_u8 = vld1_u8(input_ptr); + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter)); + acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter)); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8[2]; + for (int i = 0; i < 2; i++) { + filter_u8[i] = vld1_u8(filter_ptr + 8 * i); + } + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i])); + } + for (int i = 0; i < 2; i++) { + filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset)); + } + // Handle one output pixel at a time. 
+ for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = + vmlal_n_s16(acc[2 * i + 0], vget_low_s16(filter[i]), input); + acc[2 * i + 1] = + vmlal_n_s16(acc[2 * i + 1], vget_high_s16(filter[i]), input); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0); + uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1); + uint8x8_t filter_u8_2 = vld1_u8(filter_ptr + 8 * 2); + uint8x8_t filter_u8_3 = vld1_u8(filter_ptr + 8 * 3); + int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + int16x8_t filter_2 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_2)); + int16x8_t filter_3 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_3)); + filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset)); + filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset)); + filter_2 = vaddq_s16(filter_2, vdupq_n_s16(filter_offset)); + filter_3 = vaddq_s16(filter_3, vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. 
+ for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); + int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); + int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2); + int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3); + int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4); + int32x4_t acc_5 = vld1q_s32(acc_buffer_ptr + 4 * 5); + int32x4_t acc_6 = vld1q_s32(acc_buffer_ptr + 4 * 6); + int32x4_t acc_7 = vld1q_s32(acc_buffer_ptr + 4 * 7); + // Multiply-accumulate + acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input); + acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input); + acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input); + acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input); + acc_4 = vmlal_n_s16(acc_4, vget_low_s16(filter_2), input); + acc_5 = vmlal_n_s16(acc_5, vget_high_s16(filter_2), input); + acc_6 = vmlal_n_s16(acc_6, vget_low_s16(filter_3), input); + acc_7 = vmlal_n_s16(acc_7, vget_high_s16(filter_3), input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4); + vst1q_s32(acc_buffer_ptr + 4 * 5, acc_5); + vst1q_s32(acc_buffer_ptr + 4 * 6, acc_6); + vst1q_s32(acc_buffer_ptr + 4 * 7, acc_7); + acc_buffer_ptr += 32; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. 
+ const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(filter_u8)), vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlal_n_s16(acc[0], vget_low_s16(filter), input); + acc[1] = vmlal_n_s16(acc[1], vget_high_s16(filter), input); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. 
+ uint16x4_t input_u16 = vdup_n_u16(0); + input_u16 = vset_lane_u16((reinterpret_cast(input_ptr))[0], + input_u16, 0); + input_ptr += input_ptr_increment; + input_u16 = vset_lane_u16((reinterpret_cast(input_ptr))[0], + input_u16, 1); + input_ptr += input_ptr_increment; + const int16x4_t input_s16 = vreinterpret_s16_u16( + vget_low_u16(vmovl_u8(vreinterpret_u8_u16(input_u16)))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer. + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + + // Handle 1 output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x2_t acc = vld1_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += input_ptr_increment; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input)); + // Store the accumulators back to acc_buffer. + vst1_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + if (num_output_pixels <= 0) { + return; + } + + // Load the filters, add filter_offset. 
+ uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + + // Handle one output pixel at a time until second to the last pixel. Second + // to the last because we read eight input pixels while only processing + // four. + for (; outp < num_output_pixels - 1; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc; + acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += input_ptr_increment; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + + // Handle the last output pixel. + // Load the accumulators from acc_buffer + int32x4_t acc; + acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. 
+ uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8_0 = vld1_u8(filter_ptr); + uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 4); + int16x8_t filter_s16_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_s16_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + filter_s16_0 = vaddq_s16(filter_s16_0, vdupq_n_s16(filter_offset)); + filter_s16_1 = vaddq_s16(filter_s16_1, vdupq_n_s16(filter_offset)); + int16x4_t filter_0 = vget_low_s16(filter_s16_0); + int16x4_t filter_1 = vget_high_s16(filter_s16_0); + int16x4_t filter_2 = vget_high_s16(filter_s16_1); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs, add input_offset. 
+ uint8x8_t input_u8_0 = vld1_u8(input_ptr); + uint8x8_t input_u8_1 = vld1_u8(input_ptr + 4); + input_ptr += input_ptr_increment; + int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0)); + int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1)); + input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset)); + input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset)); + + // Load the accumulators from acc_buffer + int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); + int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); + int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2); + + // Multiply-accumulate + acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), filter_0); + acc_1 = vmlal_s16(acc_1, vget_high_s16(input_0), filter_1); + acc_2 = vmlal_s16(acc_2, vget_high_s16(input_1), filter_2); + + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2); + + acc_buffer_ptr += 12; + } + } +}; +#endif + +// Accumulates the effect of one row of the filter, on a segment of one row +// of the output, accessing the corresponding one row of the input. +template +void QuantizedDepthwiseConvAccumRow( + int stride, int input_depth, int input_width, const uint8* input_data, + int16 input_offset, int pad_width, int depth_multiplier, int filter_width, + const uint8* filter_data, int16 filter_offset, int out_x_buffer_start, + int out_x_buffer_end, int output_depth, int32* acc_buffer) { +#ifdef GEMMLOWP_PROFILING + gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__); +#endif + // Sanity check parameters. This is important in particular to ensure + // that we keep the number of template instantiations minimal, so we don't + // increase binary size unnecessarily. 
+ static_assert(kFixedDepthMultiplier || !kFixedInputDepth, ""); + static_assert(kFixedInputDepth || kAllowStrided, ""); + TFLITE_DCHECK(stride == 1 || kAllowStrided); + if (kFixedInputDepth) { + TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth); + } + if (kFixedDepthMultiplier) { + TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier); + } + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + const int input_ptr_increment = stride * input_depth; + const uint8* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + // For the current (filter_x, filter_y) point in the filter, + // compute the boundaries of the corresponding output row segment. + int out_x_loop_start_unclampled = 0; + int out_x_loop_end_unclampled = 0; + if (kAllowStrided) { + if (stride == 2) { + out_x_loop_start_unclampled = (pad_width - filter_x + 1) / 2; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 1) / 2; + } else if (stride == 4) { + out_x_loop_start_unclampled = (pad_width - filter_x + 3) / 4; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 3) / 4; + } else { + out_x_loop_start_unclampled = + (pad_width - filter_x + stride - 1) / stride; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + stride - 1) / stride; + } + } else { + out_x_loop_start_unclampled = pad_width - filter_x; + out_x_loop_end_unclampled = pad_width + input_width - filter_x; + } + // The kernel will have to iterate on the segment of the + // output row that starts at out_x_loop_start and out_x_loop_end. 
+ const int out_x_loop_start = + std::max(out_x_buffer_start, out_x_loop_start_unclampled); + const int out_x_loop_end = + std::min(out_x_buffer_end, out_x_loop_end_unclampled); + + int32* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const uint8* input_ptr = input_data + in_x_origin * input_depth; + const int num_output_pixels = out_x_loop_end - out_x_loop_start; + QuantizedDepthwiseConvKernel< + kAllowStrided, kFixedInputDepth, + kFixedDepthMultiplier>::Run(num_output_pixels, input_depth, + depth_multiplier, input_ptr, input_offset, + input_ptr_increment, filter_base_ptr, + filter_offset, acc_buffer_ptr); + filter_base_ptr += output_depth; + } +} + +// generic fallback of DepthwiseConvAccumRow, portable, non-templatized. +inline void QuantizedDepthwiseConvAccumRowGeneric( + int stride, int input_depth, int input_width, const uint8* input_data, + int16 input_offset, int pad_width, int depth_multiplier, int filter_width, + const uint8* filter_data, int16 filter_offset, int out_x_buffer_start, + int out_x_buffer_end, int output_depth, int32* acc_buffer) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)"); +#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + LOG(FATAL) + << "\n\n" + << "*****************************************************************\n" + << "* This tfmini inference code was about to use the slow generic\n" + << "* fallback implementation for a DepthwiseConv op, and we want you\n" + << "* to be aware of that so that you will know why you get terrible\n" + << "* performance.\n" + << "*\n" + << "* If you would like to carry on with the slow code, compile\n" + << "* with this preprocessor token defined:\n" + << "* TFLITE_ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n" + << "*\n" + << "* The right thing to do, if you care about performance, is to 
add\n" + << "* a new DepthwiseConv kernel to tfmini to cover your case.\n" + << "* The relevant parameters defining your case are:\n" + << "* stride = " << stride << "\n" + << "* input_depth = " << input_depth << "\n" + << "* depth_multiplier = " << depth_multiplier << "\n" + << "*\n" + << "* Please do not hesitate to contact benoitjacob@ with this\n" + << "* information.\n" + << "*****************************************************************\n"; +#endif // ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif // TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + const uint8* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int out_x_loop_start = std::max( + out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride); + const int out_x_loop_end = + std::min(out_x_buffer_end, + (pad_width + input_width - filter_x + stride - 1) / stride); + + int32* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const uint8* input_ptr = input_data + in_x_origin * input_depth; + const int input_ptr_increment = (stride - 1) * input_depth; + for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) { + const uint8* filter_ptr = filter_base_ptr; + for (int ic = 0; ic < input_depth; ++ic) { + const int16 input_val = *input_ptr++ + input_offset; + for (int m = 0; m < depth_multiplier; m++) { + const int16 filter_val = *filter_ptr++ + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + } + } + input_ptr += input_ptr_increment; + } + filter_base_ptr += output_depth; + } +} + +// Initializes the accumulator buffer with bias values. 
+inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, + const int32* bias_data, + int32* acc_buffer) { + int i = 0; +#ifdef USE_NEON + if (output_depth == 1) { + const int32x4_t b = vdupq_n_s32(bias_data[0]); + for (; i <= num_output_pixels - 16; i += 16) { + vst1q_s32(acc_buffer + i + 0, b); + vst1q_s32(acc_buffer + i + 4, b); + vst1q_s32(acc_buffer + i + 8, b); + vst1q_s32(acc_buffer + i + 12, b); + } + for (; i <= num_output_pixels - 4; i += 4) { + vst1q_s32(acc_buffer + i, b); + } + } else if (output_depth == 2) { + int32x4_t b = vdupq_n_s32(bias_data[0]); + b = vsetq_lane_s32(bias_data[1], b, 1); + b = vsetq_lane_s32(bias_data[1], b, 3); + for (; i <= num_output_pixels - 8; i += 8) { + vst1q_s32(acc_buffer + 2 * i + 0, b); + vst1q_s32(acc_buffer + 2 * i + 4, b); + vst1q_s32(acc_buffer + 2 * i + 8, b); + vst1q_s32(acc_buffer + 2 * i + 12, b); + } + for (; i <= num_output_pixels - 2; i += 2) { + vst1q_s32(acc_buffer + 2 * i, b); + } + } else if (output_depth == 4) { + const int32x4_t b = vld1q_s32(bias_data); + for (; i <= num_output_pixels - 4; i += 4) { + vst1q_s32(acc_buffer + 4 * i + 0, b); + vst1q_s32(acc_buffer + 4 * i + 4, b); + vst1q_s32(acc_buffer + 4 * i + 8, b); + vst1q_s32(acc_buffer + 4 * i + 12, b); + } + for (; i < num_output_pixels; i++) { + vst1q_s32(acc_buffer + 4 * i, b); + } + } else if (output_depth == 8) { + const int32x4_t b0 = vld1q_s32(bias_data); + const int32x4_t b1 = vld1q_s32(bias_data + 4); + for (; i <= num_output_pixels - 2; i += 2) { + vst1q_s32(acc_buffer + 8 * i + 0, b0); + vst1q_s32(acc_buffer + 8 * i + 4, b1); + vst1q_s32(acc_buffer + 8 * i + 8, b0); + vst1q_s32(acc_buffer + 8 * i + 12, b1); + } + for (; i < num_output_pixels; i++) { + vst1q_s32(acc_buffer + 8 * i + 0, b0); + vst1q_s32(acc_buffer + 8 * i + 4, b1); + } + } else if (output_depth == 16) { + const int32x4_t b0 = vld1q_s32(bias_data); + const int32x4_t b1 = vld1q_s32(bias_data + 4); + const int32x4_t b2 = vld1q_s32(bias_data + 8); + 
const int32x4_t b3 = vld1q_s32(bias_data + 12); + for (; i < num_output_pixels; i++) { + vst1q_s32(acc_buffer + 16 * i + 0, b0); + vst1q_s32(acc_buffer + 16 * i + 4, b1); + vst1q_s32(acc_buffer + 16 * i + 8, b2); + vst1q_s32(acc_buffer + 16 * i + 12, b3); + } + } +#endif + for (; i < num_output_pixels; i++) { + memcpy(acc_buffer + i * output_depth, bias_data, + sizeof(acc_buffer[0]) * output_depth); + } +} + +inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit"); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + static const int kAccBufferMaxSize = 2048; + int32 acc_buffer[kAccBufferMaxSize]; + TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth); + const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth; + const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth; + TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth, + 
kAccBufferActualSize); + TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize); + TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1); + + // row_accum_func will point to the core accumulation function to be used + // for this DepthwiseConv op. + using row_accum_func_t = decltype(&QuantizedDepthwiseConvAccumRowGeneric); + row_accum_func_t row_accum_func = nullptr; + +#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \ + FIXED_DEPTH_MULTIPLIER) \ + if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \ + (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \ + depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \ + row_accum_func = \ + QuantizedDepthwiseConvAccumRow; \ + } + +#ifdef USE_NEON + // We go over our list of kernels by decreasing order of preference + // for the cases where multiple kernels could apply. + + // Start with the fastest kernels: AllowStrided=false, fixed input depth. + + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 4) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 4) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 12, 1) + + // Next come the strided kernels: AllowStrided=true, fixed input depth. + // They are a bit less efficient, but allow stride!=1. 
+ + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 16, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 16) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1) + + // Finally, the kernels allowing a variable input depth, + // these are the least efficient but most general kernels. + + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 3) +#endif // USE_NEON + + // No matching fast kernel found, use slow fallback. + if (!row_accum_func) { + row_accum_func = QuantizedDepthwiseConvAccumRowGeneric; + } + +#undef TFMINI_USE_DEPTHWISECONV_KERNEL + + // Now that we have determined row_accum_func, we can start work. + uint8* output_ptr = output_data; + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + for (int out_x_buffer_start = 0; out_x_buffer_start < output_width; + out_x_buffer_start += kOutputPixelsInAccBuffer) { + const int out_x_buffer_end = std::min( + output_width, out_x_buffer_start + kOutputPixelsInAccBuffer); + // We call a 'pixel' a group of activation that share all but the + // 'depth'/'channel' coordinate. num_output_pixels is the number of + // output pixels that we will accumulate in this loop iteration. + const int num_output_pixels = out_x_buffer_end - out_x_buffer_start; + // Initialize our local accumulator with the bias values, so we don't + // have to add them later. + DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data, + acc_buffer); + // Accumulation loop. Most of the time should be spent in here. 
+ for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + const int in_y = in_y_origin + filter_y; + row_accum_func( + stride_width, input_depth, input_width, + input_data + in_y * input_dims.strides[2] + + b * input_dims.strides[3], + input_offset, pad_width, depth_multiplier, filter_width, + filter_data + filter_y * filter_dims.strides[2], filter_offset, + out_x_buffer_start, out_x_buffer_end, output_depth, acc_buffer); + } + // Finished accumulating int32 values. Now need to convert them to + // the final 8bit form and store them. + gemmlowp::ScopedProfilingLabel label("downquantize+store"); + const int num_output_values = output_depth * num_output_pixels; + int i = 0; +#ifdef USE_NEON + using gemmlowp::RoundingDivideByPOT; + const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); + const int32x4_t output_activation_min_vec = + vdupq_n_s32(output_activation_min); + const int32x4_t output_activation_max_vec = + vdupq_n_s32(output_activation_max); + // Handle 16 values at once. + // This allows us to issue 4 mutually independent int32 + // multiplications (vqrdmulh), which should alleviate most of their + // high latency. + for (; i <= num_output_values - 16; i += 16) { + int32x4_t acc[4]; + for (int j = 0; j < 4; j++) { + acc[j] = vld1q_s32(acc_buffer + i + 4 * j); + } + + // Fixed-point multiplication. + for (int j = 0; j < 4; j++) { + acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier); + } + for (int j = 0; j < 4; j++) { + acc[j] = RoundingDivideByPOT(acc[j], output_shift); + } + // Add the output offset. + for (int j = 0; j < 4; j++) { + acc[j] = vaddq_s32(acc[j], output_offset_vec); + } + // Apply the activation function. + for (int j = 0; j < 4; j++) { + acc[j] = vmaxq_s32(acc[j], output_activation_min_vec); + } + for (int j = 0; j < 4; j++) { + acc[j] = vminq_s32(acc[j], output_activation_max_vec); + } + // Saturating cast to uint8 and store to destination. 
+ int16x4_t acc_s16[4]; + for (int j = 0; j < 4; j++) { + acc_s16[j] = vqmovn_s32(acc[j]); + } + const int16x8_t res_s16_0 = vcombine_s16(acc_s16[0], acc_s16[1]); + const int16x8_t res_s16_1 = vcombine_s16(acc_s16[2], acc_s16[3]); + const uint8x8_t res_u8_0 = vqmovun_s16(res_s16_0); + const uint8x8_t res_u8_1 = vqmovun_s16(res_s16_1); + vst1q_u8(output_ptr, vcombine_u8(res_u8_0, res_u8_1)); + output_ptr += 16; + } + // Handle 8 values at once. + // Not as good as 16 (now we're only issuing 2 mutually independent + // vqrdmulh instructions, so we're probably paying for their high + // latency). + for (; i <= num_output_values - 8; i += 8) { + int32x4_t acc0 = vld1q_s32(acc_buffer + i); + int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4); + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + // Rounding right shift. + acc0 = RoundingDivideByPOT(acc0, output_shift); + acc1 = RoundingDivideByPOT(acc1, output_shift); + // Add the output offset. + acc0 = vaddq_s32(acc0, output_offset_vec); + acc1 = vaddq_s32(acc1, output_offset_vec); + // Apply the activation function. + acc0 = vmaxq_s32(acc0, output_activation_min_vec); + acc1 = vmaxq_s32(acc1, output_activation_min_vec); + acc0 = vminq_s32(acc0, output_activation_max_vec); + acc1 = vminq_s32(acc1, output_activation_max_vec); + // Saturating cast to uint8 and store to destination. + const int16x4_t acc0_s16 = vqmovn_s32(acc0); + const int16x4_t acc1_s16 = vqmovn_s32(acc1); + const int16x8_t res_s16 = vcombine_s16(acc0_s16, acc1_s16); + const uint8x8_t res_u8 = vqmovun_s16(res_s16); + vst1_u8(output_ptr, res_u8); + output_ptr += 8; + } + // Handle 4 values at once. Now we're paying the full price of the + // high latency of vqrdmulh. Also, storing only 4 bytes at the end + // (without any alignment) can only be done 1 byte at a time. 
+ // Yet, that is still worth doing to minimize the amount of leftover + // that will have to go through the very slow scalar code. + for (; i <= num_output_values - 4; i += 4) { + int32x4_t acc = vld1q_s32(acc_buffer + i); + // Fixed-point multiplication. + acc = vqrdmulhq_n_s32(acc, output_multiplier); + // Rounding right shift. + acc = RoundingDivideByPOT(acc, output_shift); + // Add the output offset. + acc = vaddq_s32(acc, output_offset_vec); + // Apply the activation function. + acc = vmaxq_s32(acc, output_activation_min_vec); + acc = vminq_s32(acc, output_activation_max_vec); + // Saturating cast to uint8 and store to destination. + const int16x4_t acc_s16 = vqmovn_s32(acc); + const int16x8_t res_s16 = vcombine_s16(acc_s16, acc_s16); + const uint8x8_t res_u8 = vqmovun_s16(res_s16); + vst1_lane_u8(output_ptr + 0, res_u8, 0); + vst1_lane_u8(output_ptr + 1, res_u8, 1); + vst1_lane_u8(output_ptr + 2, res_u8, 2); + vst1_lane_u8(output_ptr + 3, res_u8, 3); + output_ptr += 4; + } +#endif // USE_NEON + + // Handle leftover values, one by one. This is very slow. + for (; i < num_output_values; i++) { + int32 acc = acc_buffer[i]; + acc = MultiplyByQuantizedMultiplierSmallerThanOne( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + *output_ptr++ = static_cast(acc); + } + } + } + } +} + +// Legacy, for compatibility with old checked-in code. 
+template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, + stride_height, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +// Legacy, for compatibility with old checked-in code. 
+template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, input_offset, filter_data, + filter_dims, filter_offset, bias_data, bias_dims, stride, + stride, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +} // namespace optimized_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h new file mode 100644 index 0000000000000000000000000000000000000000..8004c24a9914e216974539930853d0aadf61e324 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h @@ -0,0 +1,231 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Copied from tensorflow/core/kernels/eigen_spatial_convolutions.h. +// TODO(petewarden) - move this to a common location in Eigen itself. + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ + +#define EIGEN_USE_CUSTOM_THREAD_POOL +#define EIGEN_USE_THREADS + +// NOTE: Eigen is slightly different internally and externally. We need to +// hack the unsupported/Eigen/CXX11/Tensor header instantiation macros at +// specific places, so we need two copies of the hacked file, one for +// internal and one for external. +// If you have trouble simply undef out the reducer macro e.g. +// TFLITE_REDUCE_INSTANTIATIONS_GOOGLE, but be aware this will make +// the binary much bigger! +#define TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE +#define Eigen EigenForTFLite +#if defined(TFLITE_REDUCE_INSTANTIATIONS_GOOGLE) +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h" +#elif defined(TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE) +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h" +#else +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#endif + + +namespace Eigen { + +/** SpatialConvolution + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Applies a 2D convolution over a multichannel input image. + * + * The input parameter is expected to be a tensor with a rank of 3 or more + * (channels, height, width, and optionally others) + * The kernel parameter is expected to be a 4D tensor (filters, channels, + * kernel_height, kernel_width) + * The input and the kernel must both be in col-major layout. The result will + * also be in col-major layout. 
+ * + * If col_in_stride, row_in_stride > 1, then applies convolution with holes + * (aka atrous convolution), sampling every col_in_stride, row_in_stride input + * pixels. + * + * The result can be assigned to a tensor of rank equal to the rank of the + * input. The dimensions of the result will be filters, height, width (and + * others if applicable). + * + * It is possible to swap the order of the width and height dimensions provided + * that the same order is used in the input, the kernel, and the output. + * + */ +template +EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE static const typename internal::conditional< + internal::traits::Layout == ColMajor, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array::Index>, + 1>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const Kernel>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorImagePatchOp > > >, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array::Index>, + 1>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorImagePatchOp >, + const TensorReshapingOp< + const DSizes::Index, 2>, + const Kernel> > > >::type + SpatialConvolution(const Input& input, const Kernel& kernel, + const DenseIndex row_stride = 1, + const DenseIndex col_stride = 1, + const PaddingType padding_type = PADDING_SAME, + const DenseIndex row_in_stride = 1, + const DenseIndex col_in_stride = 1) { + typedef typename internal::traits::Index TensorIndex; + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + in(input); + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + kern(kernel); + + EIGEN_STATIC_ASSERT( + internal::traits::Layout == internal::traits::Layout, + YOU_MADE_A_PROGRAMMING_MISTAKE); + const bool isColMajor = (internal::traits::Layout == ColMajor); + + const int 
NumDims = internal::traits::NumDimensions; + + // Number of filters to apply. This is the same as the output depth of the + // result + const TensorIndex kernelFilters = + isColMajor ? kern.dimensions()[0] : kern.dimensions()[3]; + // Number of channels. This is the same as the input depth. + const TensorIndex kernelChannels = + isColMajor ? kern.dimensions()[1] : kern.dimensions()[2]; + const TensorIndex kernelRows = + isColMajor ? kern.dimensions()[2] : kern.dimensions()[1]; + const TensorIndex kernelCols = + isColMajor ? kern.dimensions()[3] : kern.dimensions()[0]; + + const DenseIndex kernelRowsEff = + kernelRows + (kernelRows - 1) * (row_in_stride - 1); + const DenseIndex kernelColsEff = + kernelCols + (kernelCols - 1) * (col_in_stride - 1); + + array, 1> contract_dims; + contract_dims[0] = IndexPair(1, 0); + + const TensorIndex InputRows = + isColMajor ? in.dimension(1) : in.dimension(NumDims - 2); + const TensorIndex InputCols = + isColMajor ? in.dimension(2) : in.dimension(NumDims - 3); + + TensorIndex out_height; + TensorIndex out_width; + switch (padding_type) { + case PADDING_VALID: + out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) / + static_cast(row_stride)); + out_width = numext::ceil((InputCols - kernelColsEff + 1.f) / + static_cast(col_stride)); + break; + case PADDING_SAME: + out_height = numext::ceil(InputRows / static_cast(row_stride)); + out_width = numext::ceil(InputCols / static_cast(col_stride)); + break; + default: + // Initialize unused variables to avoid a compiler warning + out_height = 0; + out_width = 0; + eigen_assert(false && "unexpected padding"); + } + + // Molds the output of the patch extraction code into a 2d tensor: + // - the first dimension (dims[0]): the patch values to be multiplied with the + // kernels + // - the second dimension (dims[1]): everything else + DSizes pre_contract_dims; + if (isColMajor) { + pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols; + pre_contract_dims[1] = out_height * 
out_width; + for (int i = 3; i < NumDims; ++i) { + pre_contract_dims[1] *= in.dimension(i); + } + } else { + pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols; + pre_contract_dims[0] = out_height * out_width; + for (int i = 0; i < NumDims - 3; ++i) { + pre_contract_dims[0] *= in.dimension(i); + } + } + + // Molds the output of the contraction into the shape expected by the used + // (assuming this is ColMajor): + // - 1st dim: kernel filters + // - 2nd dim: output height + // - 3rd dim: output width + // - 4th dim and beyond: everything else including batch size + DSizes post_contract_dims; + if (isColMajor) { + post_contract_dims[0] = kernelFilters; + post_contract_dims[1] = out_height; + post_contract_dims[2] = out_width; + for (int i = 3; i < NumDims; ++i) { + post_contract_dims[i] = in.dimension(i); + } + } else { + post_contract_dims[NumDims - 1] = kernelFilters; + post_contract_dims[NumDims - 2] = out_height; + post_contract_dims[NumDims - 3] = out_width; + for (int i = 0; i < NumDims - 3; ++i) { + post_contract_dims[i] = in.dimension(i); + } + } + + DSizes kernel_dims; + if (isColMajor) { + kernel_dims[0] = kernelFilters; + kernel_dims[1] = kernelChannels * kernelRows * kernelCols; + } else { + kernel_dims[0] = kernelChannels * kernelRows * kernelCols; + kernel_dims[1] = kernelFilters; + } + // TODO(yangke): choose() is defined in TensorContraction.h -- consider + // moving it to somewhere more "common". 
+ return + input + .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, + row_in_stride, col_in_stride, padding_type) + .reshape(pre_contract_dims) + .contract(kernel.reshape(kernel_dims), contract_dims) + .reshape(post_contract_dims); +} + +} // end namespace Eigen + +// clang-format on + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h new file mode 100644 index 0000000000000000000000000000000000000000..7f78f69360b1ebbfb08600c8bc427f1ba9d5244d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h @@ -0,0 +1,143 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ + +#define EIGEN_USE_CUSTOM_THREAD_POOL +#define EIGEN_USE_THREADS + +// clang-format off + +#include + +#include +#include +#include +#include +#include +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include + +#ifdef _WIN32 +#include +#elif defined(__APPLE__) +#include +#else +#include +#endif + + +// Because some programs may link Eigen in through other frameworks with +// different flags, we can run into multiple definition issues if we don't have +// a private namespace for our versions. This is a nasty hack, but a similar +// approach is used elsewhere to handle the problem, so it should be stable. +#define Eigen EigenForTFLite + +#include "Eigen/src/Core/util/StaticAssert.h" +#include "unsupported/Eigen/CXX11/Core" +#include "unsupported/Eigen/SpecialFunctions" + +#include "Eigen/src/Core/util/DisableStupidWarnings.h" + +#include "Eigen/Core" + +// Beware: the order of the include matters to some compilers. For example +// TensorIndexList.h should be included before TensorDimensions.h in order to +// use index lists to encode tensor dimensions when compiling with llvm. +// We're defining this ourselves rather than using the Eigen Tensor header file +// so that we can alter the macro definition of TENSOR_CONTRACTION_DISPATCH to +// reduce binary size. 
+#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/ThreadPoolInterface.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorNonBlockingThreadPool.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStats.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h" + +#include 
"third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMappers.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h" +#undef TENSOR_CONTRACTION_DISPATCH +#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \ + if (this->m_lhs_inner_dim_contiguous && \ + this->m_rhs_inner_dim_contiguous && \ + !this->m_rhs_inner_dim_reordered) { \ + METHOD ARGS; \ + } else { \ + eigen_assert(false && "Unsupported contraction formats"); \ + } + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h" +#include 
"third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/Tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h" + +#include "Eigen/src/Core/util/ReenableStupidWarnings.h" +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_H diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h new file mode 100644 index 0000000000000000000000000000000000000000..1d5c316194df0b87ee7eecbdd04bd5ce9e2e40b5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h @@ -0,0 +1,167 @@ 
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is essentially unsupported/CXX11/Eigen/Tensor.h +// TODO(petewarden) - move this to a common location in Eigen itself. + +// clang-format off + + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ + + +#include "Eigen/Core" + +#if defined(EIGEN_USE_SYCL) +#undef min +#undef max +#undef isnan +#undef isinf +#undef isfinite +#include +#include +#include +#include +#include +#endif +#include +#include +#include + + + + + +#ifdef _WIN32 +typedef __int16 int16_t; +typedef unsigned __int16 uint16_t; +typedef __int32 int32_t; +typedef unsigned __int32 uint32_t; +typedef __int64 int64_t; +typedef unsigned __int64 uint64_t; +#include +#else +#include +#include +#endif + +#if __cplusplus > 199711 || EIGEN_COMP_MSVC >= 1900 +#include +#endif + +#ifdef _WIN32 +#include +#elif defined(__APPLE__) +#include +#else +#include +#endif + +// #if defined(EIGEN_USE_LIBXSMM) +// #include "libxsmm.h" +// #endif + +#ifdef EIGEN_USE_THREADS +#include "unsupported/Eigen/CXX11/ThreadPool" +#endif + + +#include "Eigen/src/Core/util/DisableStupidWarnings.h" + +#include "unsupported/Eigen/SpecialFunctions" +#include 
"unsupported/Eigen/CXX11/src/util/CXX11Meta.h" +#include "unsupported/Eigen/CXX11/src/util/MaxSizeVector.h" + + +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h" + +#include "unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorBase.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h" + +#undef 
TENSOR_CONTRACTION_DISPATCH +#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \ + if (this->m_lhs_inner_dim_contiguous && \ + this->m_rhs_inner_dim_contiguous && \ + !this->m_rhs_inner_dim_reordered) { \ + METHOD ARGS; \ + } else { \ + eigen_assert(false && "Unsupported contraction formats"); \ + } + + +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorScan.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h" 
+#include "unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h" +#include "unsupported/Eigen/CXX11/src/Tensor/Tensor.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMap.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorRef.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIO.h" + +#include "Eigen/src/Core/util/ReenableStupidWarnings.h" + + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..b3615f4658a1a70284cc9d386a868a87aa09819b --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h @@ -0,0 +1,195 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace multithreaded_ops { + +class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { + public: + explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {} + ~EigenThreadPoolWrapper() override {} + + void Schedule(std::function fn) override { + pool_->Schedule(std::move(fn)); + } + int NumThreads() const override { return pool_->NumThreads(); } + int CurrentThreadId() const override { return pool_->CurrentThreadId(); } + + private: + Eigen::ThreadPool* pool_ = nullptr; +}; + +// We have a single global threadpool for all convolution operations. This means +// that inferences started from different threads may block each other, but +// since the underlying resource of CPU cores should be consumed by the +// operations anyway, it shouldn't affect overall performance. +const Eigen::ThreadPoolDevice& GetThreadPoolDevice() { + const int thread_count = 4; + static Eigen::ThreadPool* tp = new Eigen::ThreadPool(thread_count); + static EigenThreadPoolWrapper* thread_pool_wrapper = + new EigenThreadPoolWrapper(tp); + static Eigen::ThreadPoolDevice* device = + new Eigen::ThreadPoolDevice(thread_pool_wrapper, thread_count); + return *device; +} + +// Shorthands for the types we need when interfacing with the EigenTensor +// library. 
+typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + EigenMatrix; +typedef Eigen::TensorMap< + Eigen::Tensor, + Eigen::Aligned> + ConstEigenMatrix; + +typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + EigenTensor; +typedef Eigen::TensorMap< + Eigen::Tensor, + Eigen::Aligned> + ConstEigenTensor; + +// Utility functions we need for the EigenTensor API. +template +struct MatMulConvFunctor { + // Computes on device "d": out = in0 * in1, where * is matrix + // multiplication. + void operator()( + const Device& d, EigenMatrix out, ConstEigenMatrix in0, + ConstEigenMatrix in1, + const Eigen::array, 1>& dim_pair) { + out.device(d) = in0.contract(in1, dim_pair); + } +}; + +template +class EigenTensorConvFunctor { + private: + Eigen::PaddingType TfLitePadding2EigenPadding(TfLitePadding padding) { + switch (padding) { + case kTfLitePaddingValid: + return Eigen::PADDING_VALID; + case kTfLitePaddingSame: + return Eigen::PADDING_SAME; + case kTfLitePaddingUnknown: + assert(false); // should never get here. + return Eigen::PADDING_VALID; + } + return Eigen::PADDING_SAME; // Prevent compiler warning about missing + // return + } + + public: + void operator()(const T* input_data, T* im2col_buffer, int input_batches, + int input_height, int input_width, int input_depth, + const T* filter_data, int filter_height, int filter_width, + int filter_count, int stride_rows, int stride_cols, + int pad_width, int pad_height, TfLitePadding padding, + T* output_data, int output_height, int output_width) { + const Eigen::ThreadPoolDevice& device = GetThreadPoolDevice(); + + const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 && + stride_rows == 1 && stride_cols == 1); + if (is_1x1_kernel) { + // For 1x1 kernel, the 2D convolution is reduced to matrix + // multiplication. 
+ const int conv_width = output_height * output_width; + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + EigenMatrix output(output_data, conv_width, filter_count); + ConstEigenMatrix input(input_data, conv_width, input_depth); + ConstEigenMatrix filter(filter_data, input_depth, filter_count); + MatMulConvFunctor()(device, output, input, + filter, dim_pair); + } else if (filter_height == input_height && filter_width == input_width && + pad_width == 0 && pad_height == 0) { + // If the input data and filter have the same height/width, + // the 2D convolution is reduced to matrix multiplication. + const int k = // Length of reduction dimension. + filter_width * filter_height * input_depth; + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + EigenMatrix output(output_data, 1, filter_count); + ConstEigenMatrix input(input_data, 1, k); + ConstEigenMatrix filter(filter_data, k, filter_count); + MatMulConvFunctor()(device, output, input, + filter, dim_pair); + } else { + EigenTensor output(output_data, input_batches, output_height, + output_width, filter_count); + ConstEigenTensor input(input_data, input_batches, input_height, + input_width, input_depth); + ConstEigenTensor filter(filter_data, filter_height, filter_width, + input_depth, filter_count); + output.device(device) = + Eigen::SpatialConvolution(input, filter, stride_cols, stride_rows, + TfLitePadding2EigenPadding(padding)); + } + } +}; + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, TfLitePadding padding, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims, + float* im2col_data, const Dims<4>& im2col_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = 
MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + EigenTensorConvFunctor conv_functor; + conv_functor(input_data, im2col_data, batches, input_height, input_width, + input_depth, filter_data, filter_height, filter_width, + output_depth, stride_height, stride_width, pad_height, pad_width, + padding, output_data, output_height, output_width); + + optimized_ops::AddBiasAndEvalActivationFunction( + bias_data, bias_dims, output_data, output_dims, output_activation_min, + output_activation_max); +} + +} // namespace multithreaded_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf0bdfb1fb875c4b54c55e25d4a17541507ecd4c --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -0,0 +1,337 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" + +#ifdef USE_NEON + +#include +#define kFloatWeightsPerNeonLane 4 + +namespace tflite { +namespace tensor_utils { + +void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1)); + + // The arrays used to cache the vector. + float32x4_t* vector_cache_float32x4 = + new float32x4_t[(m_cols / kFloatWeightsPerNeonLane) * + sizeof(float32x4_t)]; + const int kUnrollSize = 2; + for (int b = 0; b < n_batch; b++) { + float* result_in_batch = result + b * m_rows * result_stride; + const float* vector_in_batch = vector + b * m_cols; + + const float* matrix_ptr0 = matrix; + // If there is only 1 row, we don't want to assign an illegal pointer. + const float* matrix_ptr1 = nullptr; + if (m_rows > 1) { + matrix_ptr1 = matrix + m_cols; + } + + // Cahce the vector. + for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { + vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c); + } + + // Main matrix by vector multiplication loop, which handles two rows of + // matrix by vector multiplication. 
+ for (int r = 0; r < (m_rows & ~(kUnrollSize - 1)); r += kUnrollSize) { + float32x4_t acc0_32x4 = vmovq_n_f32(0.0); + float32x4_t acc1_32x4 = vmovq_n_f32(0.0); + for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { + float32x4_t temp = vector_cache_float32x4[c >> 2]; + // Load 4 float values from vector1 and vector2 and accumulator. + float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c); + float32x4_t v1_f32x4 = vld1q_f32(matrix_ptr1 + c); + // Vector multiply-accumulate 4 float + acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp); + acc1_32x4 = vmlaq_f32(acc1_32x4, v1_f32x4, temp); + } + // Add the 4 intermediate sum values to get the final dot-prod value for + // this column. + *result_in_batch += + (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) + + vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3)); + *(result_in_batch + result_stride) += + (vgetq_lane_f32(acc1_32x4, 0) + vgetq_lane_f32(acc1_32x4, 1) + + vgetq_lane_f32(acc1_32x4, 2) + vgetq_lane_f32(acc1_32x4, 3)); + for (int c = postamble_start; c < m_cols; c++) { + *result_in_batch += matrix_ptr0[c] * vector_in_batch[c]; + *(result_in_batch + result_stride) += + matrix_ptr1[c] * vector_in_batch[c]; + } + matrix_ptr0 += kUnrollSize * m_cols; + matrix_ptr1 += kUnrollSize * m_cols; + result_in_batch += kUnrollSize * result_stride; + } + for (int r = (m_rows & ~(kUnrollSize - 1)); r < m_rows; r++) { + float32x4_t acc0_32x4 = vmovq_n_f32(0.0); + for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { + float32x4_t temp = vector_cache_float32x4[c >> 2]; + // Load 4 float values from vector1 and vector2 and accumulator. + float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c); + // Vector multiply-accumulate 4 float + acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp); + } + // Add the 4 intermediate sum values to get the final dot-prod value for + // this column. 
+ *result_in_batch += + (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) + + vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3)); + for (int c = postamble_start; c < m_cols; c++) { + *result_in_batch += matrix_ptr0[c] * vector_in_batch[c]; + } + matrix_ptr0 += m_cols; + result_in_batch += result_stride; + } + } + delete[] vector_cache_float32x4; +} + +void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from vector1 and vector2. + float32x4_t v1_f32x4 = vld1q_f32(vector1 + v); + float32x4_t v2_f32x4 = vld1q_f32(vector2 + v); + // Vector multiply 4 float + float32x4_t mul_32x4 = vmulq_f32(v1_f32x4, v2_f32x4); + // Save to result array. + vst1q_f32(&result[v], mul_32x4); + } + for (int v = postamble_start; v < v_size; v++) { + result[v] = vector1[v] * vector2[v]; + } +} + +void NeonVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from vector1 and vector2 and accumulator. 
+ float32x4_t v1_f32x4 = vld1q_f32(vector1 + v); + float32x4_t v2_f32x4 = vld1q_f32(vector2 + v); + float32x4_t acc_32x4 = vld1q_f32(result + v); + // Vector multiply-accumulate 4 float + acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4); + // Save to result array. + vst1q_f32(&result[v], acc_32x4); + } + for (int v = postamble_start; v < v_size; v++) { + result[v] += vector1[v] * vector2[v]; + } +} + +void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + // The arrays used to cache the vector. + float32x4_t* vector_cache_float32x4 = + new float32x4_t[(v_size / kFloatWeightsPerNeonLane) * + sizeof(float32x4_t)]; + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + vector_cache_float32x4[v >> 2] = vld1q_f32(vector + v); + } + + float* result_ptr = result; + const float* batch_vector_ptr = batch_vector; + for (int b = 0; b < n_batch; b++) { + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load from memory to vectors. + float32x4_t result_f32x4 = vld1q_f32(result_ptr + v); + float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector_ptr + v); + // Multiply-accumulate. + result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4, + vector_cache_float32x4[v >> 2]); + // Store. + vst1q_f32(result_ptr + v, result_f32x4); + } + // Postamble loop + for (int v = postamble_start; v < v_size; v++) { + result_ptr[v] += vector[v] * batch_vector_ptr[v]; + } + // Update the pointers. 
+ result_ptr += v_size; + batch_vector_ptr += v_size; + } + delete[] vector_cache_float32x4; +} + +void NeonSub1Vector(const float* vector, int v_size, float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + float32x4_t one_f32x4 = vmovq_n_f32(1.0); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from the current pointers of the input column and + // subtract from 1. + float32x4_t v_f32x4 = vld1q_f32(vector + v); + float32x4_t result_f32x4 = vsubq_f32(one_f32x4, v_f32x4); + // Save to output. + vst1q_f32(result + v, result_f32x4); + } + for (int v = postamble_start; v < v_size; v++) { + result[v] = 1.0f - vector[v]; + } +} + +void NeonClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + // Replicate abs_limit and -abs_limit in two vectors. + const float32x4_t abs_limit_f32x4 = vmovq_n_f32(abs_limit); + const float32x4_t neg_abs_limit_f32x4 = vmovq_n_f32(-abs_limit); + + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load from memory to vector. + float32x4_t v_f32x4 = vld1q_f32(vector + v); + // Clip between abs_limit and -abs_limit. + float32x4_t result_f32x4 = vminq_f32(abs_limit_f32x4, v_f32x4); + result_f32x4 = vmaxq_f32(neg_abs_limit_f32x4, result_f32x4); + // Save to output. + vst1q_f32(result + v, result_f32x4); + } + // Postamble loop. 
+ for (int v = postamble_start; v < v_size; v++) { + result[v] = (abs_limit < vector[v]) ? abs_limit : vector[v]; + result[v] = (-abs_limit > result[v]) ? -abs_limit : result[v]; + } +} + +float NeonVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + float32x4_t acc_32x4 = vmovq_n_f32(0.0); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from vector1 and vector2 and accumulator. + float32x4_t v1_f32x4 = vld1q_f32(vector1 + v); + float32x4_t v2_f32x4 = vld1q_f32(vector2 + v); + // Vector multiply-accumulate 4 float + acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4); + } + + float result = (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) + + vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3)); + // Postamble loop. 
+ for (int v = postamble_start; v < v_size; v++) { + result += vector1[v] * vector2[v]; + } + return result; +} + +void NeonBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + float* result_ptr = result; + const float* vector1_ptr = vector1; + const float* vector2_ptr = vector2; + for (int b = 0; b < n_batch; b++) { + *result_ptr = NeonVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size); + vector1_ptr += v_size; + vector2_ptr += v_size; + result_ptr += result_stride; + } +} + +void NeonReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + const float* input_vector_ptr = input_vector; + for (int o = 0; o < output_size; o++) { + // If reduction_size is not divisible by kWeightsPerNeonLane, we cannot use + // the main vectorized loop, and we need to process sequentially. + // postamble_start shows the start index where this should happen. + const int postamble_start = + reduction_size - (reduction_size & (kFloatWeightsPerNeonLane - 1)); + float32x4_t sum_f32x4 = vmovq_n_f32(0.0); + for (int r = 0; r < postamble_start; r += kFloatWeightsPerNeonLane) { + float32x4_t v1_f32x4 = vld1q_f32(input_vector_ptr + r); + sum_f32x4 = vaddq_f32(sum_f32x4, v1_f32x4); + } + output_vector[o] += + (vgetq_lane_f32(sum_f32x4, 0) + vgetq_lane_f32(sum_f32x4, 1) + + vgetq_lane_f32(sum_f32x4, 2) + vgetq_lane_f32(sum_f32x4, 3)); + input_vector_ptr += postamble_start; + + // Postamble loop. + for (int r = postamble_start; r < reduction_size; r++) { + output_vector[o] += *input_vector_ptr++; + } + } +} + +void NeonVectorShiftLeft(float* vector, int v_size, float shift_value) { + // This variable keeps track of the next to the last index which is being + // copied to make sure we are not out of the vector boundary. 
+ int last_index_copy = kFloatWeightsPerNeonLane; + int current_index_copy = 0; + while (last_index_copy < v_size) { + float32x4_t v_f32x4 = vld1q_f32(vector + current_index_copy + 1); + vst1q_f32(vector + current_index_copy, v_f32x4); + current_index_copy += kFloatWeightsPerNeonLane; + last_index_copy += kFloatWeightsPerNeonLane; + } + // Postamble loop. + for (int i = current_index_copy; i < v_size - 1; i++) { + vector[i] = vector[i + 1]; + } + vector[v_size - 1] = shift_value; +} + +} // namespace tensor_utils +} // namespace tflite + +#endif // USE_NEON diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3a4af87304eaf33489b38bd9b15ad9789e091d24 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -0,0 +1,113 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ + +// TODO(ghodrat): Remove this header file and the dependency to internal data +// structure. 
+#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" + +namespace tflite { +namespace tensor_utils { + +void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride) { + NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols, + vector, n_batch, result, result_stride); +} + +void VectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result) { + NEON_OR_PORTABLE(VectorVectorCwiseProduct, vector1, vector2, v_size, result); +} + +void VectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result) { + NEON_OR_PORTABLE(VectorVectorCwiseProductAccumulate, vector1, vector2, v_size, + result); +} + +void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, + const float* batch_vector, + int n_batch, float* result) { + NEON_OR_PORTABLE(VectorBatchVectorCwiseProductAccumulate, vector, v_size, + batch_vector, n_batch, result); +} + +float VectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size); +} + +void BatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct, vector1, vector2, v_size, + n_batch, result, result_stride); +} + +void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, + float* batch_vector) { + PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector); +} + +void ApplySigmoidToVector(const float* vector, int v_size, float* result) { + PortableApplySigmoidToVector(vector, v_size, result); +} + +void 
ApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, float* result) { + PortableApplyActivationToVector(vector, v_size, activation, result); +} + +void CopyVector(const float* vector, int v_size, float* result) { + PortableCopyVector(vector, v_size, result); +} + +void Sub1Vector(const float* vector, int v_size, float* result) { + NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result); +} + +void ZeroVector(float* vector, int v_size) { + PortableZeroVector(vector, v_size); +} + +float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } + +void ClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result); +} + +void VectorShiftLeft(float* vector, int v_size, float shift_value) { + NEON_OR_PORTABLE(VectorShiftLeft, vector, v_size, shift_value); +} + +void ReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size, + reduction_size); +} + +} // namespace tensor_utils +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..cd565c16a1ee7226f83c19f0020beed75e401497 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -0,0 +1,3715 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/round.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Make a local VectorMap typedef allowing to map a float array +// as a Eigen vector expression. The std::conditional here is to +// construct the suitable Eigen type for the constness of the +// data. Indeed, for const data, we need to produce +// Eigen::Map> +// and not the more straightforward +// Eigen::Map> +template +using VectorMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, 1>>, + Eigen::Map>>::type; + +template +VectorMap MapAsVector(Scalar* data, const Dims& dims) { + const int size = RequiredBufferSizeForDims(dims); + return VectorMap(data, size, 1); +} + +// Make a local VectorMap typedef allowing to map a float array +// as a Eigen matrix expression. The same explanation as for VectorMap +// above also applies here. 
+template +using MatrixMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, Eigen::Dynamic>>, + Eigen::Map>>::type; + +template +MatrixMap MapAsMatrixWithFirstDimAsRows(Scalar* data, + const Dims& dims) { + const int rows = dims.sizes[0]; + int cols = 1; + for (int d = 1; d < N; d++) { + cols *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +template +MatrixMap MapAsMatrixWithLastDimAsCols(Scalar* data, + const Dims& dims) { + const int cols = dims.sizes[N - 1]; + int rows = 1; + for (int d = 0; d < N - 1; d++) { + rows *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +template +using ArrayMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, Eigen::Dynamic>>, + Eigen::Map>>::type; + +template +ArrayMap MapAsArrayWithFirstDimAsRows(Scalar* data, + const Dims& dims) { + const int rows = dims.sizes[0]; + int cols = 1; + for (int d = 1; d < N; d++) { + cols *= dims.sizes[d]; + } + return ArrayMap(data, rows, cols); +} + +// TODO(b/62193649): this function is only needed as long +// as we have the --variable_batch hack. +template +MatrixMap MapAsMatrixWithGivenNumberOfRows(Scalar* data, + const Dims& dims, + int rows) { + int cols = 1; + bool matched_rows = false; + for (int d = 0; d < N; d++) { + cols *= dims.sizes[d]; + if (cols == rows) { + matched_rows = true; + cols = 1; + } + } + TFLITE_DCHECK(matched_rows); + return MatrixMap(data, rows, cols); +} + +// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE +// BROADCASTING. +// +// NdArrayDesc describes the shape and memory layout of an N-dimensional +// rectangular array of numbers. +// +// NdArrayDesc is basically identical to Dims defined in types.h. +// However, as Dims is to be deprecated, this class exists as an adaptor +// to enable simple unoptimized implementations of element-wise broadcasting +// operations. 
+template +struct NdArrayDesc { + // The "extent" of each dimension. Indices along dimension d must be in the + // half-open interval [0, extents[d]). + int extents[N]; + + // The number of *elements* (not bytes) between consecutive indices of each + // dimension. + int strides[N]; +}; + +// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING +// ELEMENT-WISE BROADCASTING. +// +// Same as Offset(), except takes as NdArrayDesc instead of Dims. +inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, + int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); + TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); + TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); + TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); + return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + + i3 * desc.strides[3]; +} + +// Given the dimensions of the operands for an element-wise binary broadcast, +// adjusts them so that they can be directly iterated over with simple loops. +// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and +// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. +// +// This function assumes that the two input shapes are compatible up to +// broadcasting and the shorter one has already been prepended with 1s to be the +// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), +// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that +// Dims refer to shapes in reverse order. In this case, input0_dims will be +// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). +// +// When two shapes are compatible up to broadcasting, for each dimension d, +// the input extents are either equal, or one of them is 1. +// +// This function performs the following for each dimension d: +// - If the extents are equal, then do nothing since the loop that walks over +// both of the input arrays is correct. +// - Otherwise, one (and only one) of the extents must be 1. 
Say extent0 is 1 +// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows +// array0 to be referenced *at any index* in dimension d and still access the +// same slice. +template +inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, + const Dims& input1_dims, + NdArrayDesc* desc0_out, + NdArrayDesc* desc1_out) { + TFLITE_DCHECK(desc0_out != nullptr); + TFLITE_DCHECK(desc1_out != nullptr); + + // Copy dims to desc. + for (int i = 0; i < N; ++i) { + desc0_out->extents[i] = input0_dims.sizes[i]; + desc0_out->strides[i] = input0_dims.strides[i]; + desc1_out->extents[i] = input1_dims.sizes[i]; + desc1_out->strides[i] = input1_dims.strides[i]; + } + + // Walk over each dimension. If the extents are equal do nothing. + // Otherwise, set the desc with extent 1 to have extent equal to the other and + // stride 0. + for (int i = 0; i < N; ++i) { + const int extent0 = ArraySize(input0_dims, i); + const int extent1 = ArraySize(input1_dims, i); + if (extent0 != extent1) { + if (extent0 == 1) { + desc0_out->strides[i] = 0; + desc0_out->extents[i] = extent1; + } else { + TFLITE_DCHECK_EQ(extent1, 1); + desc1_out->strides[i] = 0; + desc1_out->extents[i] = extent0; + } + } + } +} + +inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) { + for (int i = 0; i < 4; i++) { + if (dims1.sizes[i] != dims2.sizes[i]) { + return false; + } + } + return true; +} + +inline void AddBiasAndEvalActivationFunction(const float* bias_data, + const Dims<4>& bias_dims, + float* array_data, + const Dims<4>& array_dims, + float output_activation_min, + float output_activation_max) { +#ifdef USE_NEON + gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction"); + const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3]; + const int array_size = array_dims.sizes[3] * array_dims.strides[3]; + TFLITE_DCHECK_EQ((array_size % bias_size), 0); + float* array_ptr = array_data; + float* array_end_ptr = array_ptr + array_size; + const auto 
activation_min = vdupq_n_f32(output_activation_min); + const auto activation_max = vdupq_n_f32(output_activation_max); + for (; array_ptr != array_end_ptr; array_ptr += bias_size) { + int i = 0; + for (; i <= bias_size - 16; i += 16) { + auto b0 = vld1q_f32(bias_data + i); + auto b1 = vld1q_f32(bias_data + i + 4); + auto b2 = vld1q_f32(bias_data + i + 8); + auto b3 = vld1q_f32(bias_data + i + 12); + auto a0 = vld1q_f32(array_ptr + i); + auto a1 = vld1q_f32(array_ptr + i + 4); + auto a2 = vld1q_f32(array_ptr + i + 8); + auto a3 = vld1q_f32(array_ptr + i + 12); + auto x0 = vaddq_f32(a0, b0); + auto x1 = vaddq_f32(a1, b1); + auto x2 = vaddq_f32(a2, b2); + auto x3 = vaddq_f32(a3, b3); + x0 = vmaxq_f32(activation_min, x0); + x1 = vmaxq_f32(activation_min, x1); + x2 = vmaxq_f32(activation_min, x2); + x3 = vmaxq_f32(activation_min, x3); + x0 = vminq_f32(activation_max, x0); + x1 = vminq_f32(activation_max, x1); + x2 = vminq_f32(activation_max, x2); + x3 = vminq_f32(activation_max, x3); + vst1q_f32(array_ptr + i, x0); + vst1q_f32(array_ptr + i + 4, x1); + vst1q_f32(array_ptr + i + 8, x2); + vst1q_f32(array_ptr + i + 12, x3); + } + for (; i <= bias_size - 4; i += 4) { + auto b = vld1q_f32(bias_data + i); + auto a = vld1q_f32(array_ptr + i); + auto x = vaddq_f32(a, b); + x = vmaxq_f32(activation_min, x); + x = vminq_f32(activation_max, x); + vst1q_f32(array_ptr + i, x); + } + for (; i < bias_size; i++) { + array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i], + output_activation_min, + output_activation_max); + } + } +#else // not NEON + gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction"); + const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3]; + const int array_size = array_dims.sizes[3] * array_dims.strides[3]; + TFLITE_DCHECK_EQ((array_size % bias_size), 0); + for (int array_offset = 0; array_offset < array_size; + array_offset += bias_size) { + for (int i = 0; i < bias_size; i++) { + array_data[array_offset + i] = 
ActivationFunctionWithMinMax( + array_data[array_offset + i] + bias_data[i], output_activation_min, + output_activation_max); + } + } +#endif +} + +// legacy, for compatibility with old checked-in code +template +void AddBiasAndEvalActivationFunction(const float* bias_data, + const Dims<4>& bias_dims, + float* array_data, + const Dims<4>& array_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims, + output_activation_min, + output_activation_max); +} + +template +void Gemm(const Eigen::MatrixBase& lhs, const Eigen::MatrixBase& rhs, + Eigen::MatrixBase* result) { + if (rhs.cols() == 1) { + gemmlowp::ScopedProfilingLabel label("GEMV"); + result->col(0).noalias() = lhs * rhs.col(0); + } else { + gemmlowp::ScopedProfilingLabel label("GEMM"); + result->noalias() = lhs * rhs; + } +} + +inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("FullyConnected"); + // TODO(b/62193649): this convoluted shape computation (determining + // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows) + // is because the current --variable_batch hack consists in overwriting the + // 3rd dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. 
+ // When that is fixed, this should become: + // const auto input_matrix_map = + // MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + const int input_rows = ArraySize(weights_dims, 0); + const auto input_matrix_map = + MapAsMatrixWithGivenNumberOfRows(input_data, input_dims, input_rows); + const auto filter_matrix_map = + MapAsMatrixWithFirstDimAsRows(weights_data, weights_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map); + AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, + output_dims, output_activation_min, + output_activation_max); +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, const Dims<4>& weights_dims, + const float* bias_data, const Dims<4>& bias_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data, + bias_dims, output_activation_min, output_activation_max, + output_data, output_dims); +} + +inline void preload_l1_stream(const uint8* ptr) { +#ifdef GEMMLOWP_ARM_64 + asm volatile("prfm pldl1strm, [%[ptr]]\n" ::[ptr] "r"(ptr) :); +#else + gemmlowp::Prefetch(ptr); +#endif +} + +#ifdef USE_NEON +inline void FullyConnectedAsGEMV( + const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, + const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int32 output_offset, + int32 output_multiplier, int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit"); + 
TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3), + 1); + const int input_size = input_dims.strides[3]; + const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0); + static constexpr int kPeel = 4; + for (int k = 0; k < input_size; k += 64) { + preload_l1_stream(input_data + k); + } + for (int k = 0; k < kPeel * input_size; k += 64) { + preload_l1_stream(filter_data + k); + } + TFLITE_DCHECK(!(output_size % kPeel)); + const int32* bias_ptr = bias_data; + uint8* output_ptr = output_data; + for (int out = 0; out < output_size; out += kPeel) { + int32x4_t acc[kPeel]; + for (int k = 0; k < kPeel; k++) { + acc[k] = vdupq_n_s32(0); + } + const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); + const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset); + int in = 0; + for (; in <= input_size - 16; in += 16) { + const uint8x16_t input_val_u8 = vld1q_u8(input_data + in); + uint8x16_t filter_val_u8[kPeel]; + for (int k = 0; k < kPeel; k++) { + const uint8* filter_ptr = filter_data + in + (out + k) * input_size; + filter_val_u8[k] = vld1q_u8(filter_ptr); + preload_l1_stream(filter_ptr + 64); + } + int16x8_t input_val[2]; + const uint8x8_t low = vget_low_u8(input_val_u8); + const uint8x8_t high = vget_high_u8(input_val_u8); + input_val[0] = vreinterpretq_s16_u16(vmovl_u8(low)); + input_val[1] = vreinterpretq_s16_u16(vmovl_u8(high)); + input_val[0] = vaddq_s16(input_val[0], input_offset_vec); + input_val[1] = vaddq_s16(input_val[1], input_offset_vec); + int16x8_t filter_val[kPeel][2]; + for (int k = 0; k < kPeel; k++) { + const uint8x8_t low = vget_low_u8(filter_val_u8[k]); + const uint8x8_t high = vget_high_u8(filter_val_u8[k]); + filter_val[k][0] = vreinterpretq_s16_u16(vmovl_u8(low)); + 
filter_val[k][1] = vreinterpretq_s16_u16(vmovl_u8(high)); + filter_val[k][0] = vaddq_s16(filter_val[k][0], filter_offset_vec); + filter_val[k][1] = vaddq_s16(filter_val[k][1], filter_offset_vec); + } + for (int p = 0; p < 2; p++) { + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k][p]), + vget_low_s16(input_val[p])); + } + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k][p]), + vget_high_s16(input_val[p])); + } + } + } + for (; in <= input_size - 8; in += 8) { + const uint8x8_t input_val_u8 = vld1_u8(input_data + in); + uint8x8_t filter_val_u8[kPeel]; + for (int k = 0; k < kPeel; k++) { + const uint8* filter_ptr = filter_data + in + (out + k) * input_size; + filter_val_u8[k] = vld1_u8(filter_ptr); + } + int16x8_t input_val; + input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8)); + input_val = vaddq_s16(input_val, input_offset_vec); + int16x8_t filter_val[kPeel]; + for (int k = 0; k < kPeel; k++) { + filter_val[k] = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8[k])); + filter_val[k] = vaddq_s16(filter_val[k], filter_offset_vec); + } + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k]), + vget_low_s16(input_val)); + } + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k]), + vget_high_s16(input_val)); + } + } + if (in < input_size) { + int32 buf[4 * kPeel]; + for (int k = 0; k < 4; k++) { + vst1q_s32(buf + 4 * k, acc[k]); + } + for (; in < input_size; in++) { + int lane = (in + 8 - input_size) % 4; + const int32 input_val = input_data[in] + input_offset; + for (int k = 0; k < kPeel; k++) { + int32 filter_val = + filter_data[in + (out + k) * input_size] + filter_offset; + buf[lane + 4 * k] += filter_val * input_val; + } + } + for (int k = 0; k < 4; k++) { + acc[k] = vld1q_s32(buf + 4 * k); + } + } + + // Horizontally reduce accumulators + int32x2_t pairwise_reduced_acc[kPeel]; + for (int k = 0; k < kPeel; 
k++) { + pairwise_reduced_acc[k] = + vpadd_s32(vget_low_s32(acc[k]), vget_high_s32(acc[k])); + } + static_assert(kPeel == 4, "the code below currently assumes kPeel = 4"); + const int32x2_t reduced_lo = + vpadd_s32(pairwise_reduced_acc[0], pairwise_reduced_acc[1]); + const int32x2_t reduced_hi = + vpadd_s32(pairwise_reduced_acc[2], pairwise_reduced_acc[3]); + int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); + // Add bias values. + int32x4_t bias_vec = vld1q_s32(bias_ptr); + bias_ptr += 4; + reduced = vaddq_s32(reduced, bias_vec); + // Multiply by the fixed-point multiplier. + reduced = vqrdmulhq_n_s32(reduced, output_multiplier); + // Rounding-shift-right. + using gemmlowp::RoundingDivideByPOT; + reduced = RoundingDivideByPOT(reduced, output_shift); + // Add the output offset. + const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); + reduced = vaddq_s32(reduced, output_offset_vec); + // Narrow values down to 16 bit signed. + const int16x4_t res16 = vqmovn_s32(reduced); + // Narrow values down to 8 bit unsigned, saturating. + uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16)); + // Apply the clamping from the activation function + res8 = vmax_u8(res8, vdup_n_u8(output_activation_min)); + res8 = vmin_u8(res8, vdup_n_u8(output_activation_max)); + // Store results to destination. Assumes 32bit alignment. 
+ vst1_lane_u32(reinterpret_cast(output_ptr), + vreinterpret_u32_u8(res8), 0); + output_ptr += kPeel; + } +} +#endif // USE_NEON + +struct GemmlowpOutputPipeline { + typedef gemmlowp::VectorMap + ColVectorMap; + typedef std::tuple< + gemmlowp::OutputStageBiasAddition, + gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, + gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8> + Pipeline; + static Pipeline Make(const int32* bias_data, int output_rows, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max) { + ColVectorMap bias_vector(bias_data, output_rows); + gemmlowp::OutputStageBiasAddition bias_addition_stage; + bias_addition_stage.bias_vector = bias_vector; + gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint + quantize_down_stage; + quantize_down_stage.result_offset_after_shift = output_offset; + quantize_down_stage.result_fixedpoint_multiplier = output_multiplier; + quantize_down_stage.result_shift = output_shift; + gemmlowp::OutputStageClamp clamp_stage; + clamp_stage.min = output_activation_min; + clamp_stage.max = output_activation_max; + gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage; + return std::make_tuple(bias_addition_stage, quantize_down_stage, + clamp_stage, saturating_cast_stage); + } +}; + +inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit"); + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current 
--variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); +#ifdef USE_NEON + const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0); + if (batches == 1 && !(output_size % 4)) { + return FullyConnectedAsGEMV( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, output_data, + output_dims); + } +#endif // USE_NEON + const int filter_rows = filter_dims.sizes[1]; + const int filter_cols = filter_dims.sizes[0]; + TFLITE_DCHECK_EQ(filter_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(filter_dims.sizes[3], 1); + const int output_rows = output_dims.sizes[0]; + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + + gemmlowp::MatrixMap filter_matrix( + filter_data, output_rows, filter_cols, filter_cols); + gemmlowp::MatrixMap input_matrix( + input_data, filter_cols, batches, filter_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, batches, output_rows); + const auto& output_pipeline = GemmlowpOutputPipeline::Make( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, + input_offset, output_pipeline); +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& 
filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims, gemm_context); +} + +template +inline void ExtractPatchIntoBufferColumn( + const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth, + int stride_width, int stride_height, int pad_width, int pad_height, + int in_width, int in_height, int in_depth, int single_buffer_length, + int buffer_id, const T* in_data, T* conv_buffer_data, uint8 byte_zero) { + gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn"); + // This chunk of code reshapes all the inputs corresponding to + // output (b, h, w) to a column vector in conv_buffer(:, buffer_id). + const int kwidth_times_indepth = kwidth * in_depth; + const int inwidth_times_indepth = in_width * in_depth; + const int ih_ungated_start = h * stride_height - pad_height; + const int ih_ungated_end = (ih_ungated_start + kheight); + const int ih_end = std::min(ih_ungated_end, in_height); + const int iw_ungated_start = w * stride_width - pad_width; + const int iw_ungated_end = (iw_ungated_start + kwidth); + const int iw_end = std::min(iw_ungated_end, in_width); + // If the patch is off the edge of the input image, skip writing those rows + // and columns from the patch into the output array. 
+ const int h_offset = std::max(0, -ih_ungated_start); + const int w_offset = std::max(0, -iw_ungated_start); + const int ih_start = std::max(0, ih_ungated_start); + const int iw_start = std::max(0, iw_ungated_start); + const int single_row_num = + std::min(kwidth - w_offset, in_width - iw_start) * in_depth; + const int output_row_offset = (buffer_id * single_buffer_length); + int out_offset = + output_row_offset + (h_offset * kwidth + w_offset) * in_depth; + int in_offset = Offset(input_dims, 0, iw_start, ih_start, b); + + // Express all of the calculations as padding around the input patch. + const int top_padding = h_offset; + const int bottom_padding = (ih_ungated_end - ih_end); + const int left_padding = w_offset; + const int right_padding = (iw_ungated_end - iw_end); + assert(single_row_num == + ((kwidth - (left_padding + right_padding)) * in_depth)); + + // Write out zeroes to the elements representing the top rows of the input + // patch that are off the edge of the input image. + if (top_padding > 0) { + const int top_row_elements = (top_padding * kwidth * in_depth); + memset(conv_buffer_data + output_row_offset, byte_zero, + (top_row_elements * sizeof(T))); + } + + // If the patch is on the interior of the input image horizontally, just copy + // over the rows sequentially, otherwise add zero padding at the start or end. 
+ if ((left_padding == 0) && (right_padding == 0)) { + for (int ih = ih_start; ih < ih_end; ++ih) { + memcpy(conv_buffer_data + out_offset, in_data + in_offset, + single_row_num * sizeof(T)); + out_offset += kwidth_times_indepth; + in_offset += inwidth_times_indepth; + } + } else { + for (int ih = ih_start; ih < ih_end; ++ih) { + if (left_padding > 0) { + const int left_start = (out_offset - (left_padding * in_depth)); + memset(conv_buffer_data + left_start, byte_zero, + (left_padding * in_depth * sizeof(T))); + } + memcpy(conv_buffer_data + out_offset, in_data + in_offset, + single_row_num * sizeof(T)); + if (right_padding > 0) { + const int right_start = (out_offset + single_row_num); + memset(conv_buffer_data + right_start, byte_zero, + (right_padding * in_depth * sizeof(T))); + } + out_offset += kwidth_times_indepth; + in_offset += inwidth_times_indepth; + } + } + + // If the bottom of the patch falls off the input image, pad the values + // representing those input rows with zeroes. 
+ if (bottom_padding > 0) { + const int bottom_row_elements = (bottom_padding * kwidth * in_depth); + const int bottom_start = + output_row_offset + + ((top_padding + (ih_end - ih_start)) * kwidth * in_depth); + memset(conv_buffer_data + bottom_start, byte_zero, + (bottom_row_elements * sizeof(T))); + } +} + +template +void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width, + int stride_height, int pad_width, int pad_height, int kheight, + int kwidth, uint8 byte_zero, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Im2col"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + + int buffer_id = 0; + // Loop over the output nodes. 
+ for (int b = 0; b < batches; ++b) { + for (int h = 0; h < output_height; ++h) { + for (int w = 0; w < output_width; ++w) { + ExtractPatchIntoBufferColumn( + input_dims, w, h, b, kheight, kwidth, stride_width, stride_height, + pad_width, pad_height, input_width, input_height, input_depth, + output_depth, buffer_id, input_data, output_data, byte_zero); + ++buffer_id; + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int kheight, int kwidth, + uint8 byte_zero, T* output_data, const Dims<4>& output_dims) { + Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight, + kwidth, byte_zero, output_data, output_dims); +} + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + (void)im2col_data; + (void)im2col_dims; + gemmlowp::ScopedProfilingLabel label("Conv"); + + const float* gemm_input_data = nullptr; + const Dims<4>* gemm_input_dims = nullptr; + const int filter_width = ArraySize(filter_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const bool need_im2col = stride_width != 1 || stride_height != 1 || + filter_width != 1 || filter_height != 1; + if (need_im2col) { + TFLITE_DCHECK(im2col_data); + Im2col(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_height, filter_width, 0, im2col_data, + im2col_dims); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else { + // TODO(aselle): We need to make sure to not send im2col if it is not + // needed. 
+ TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + gemm_input_dims = &input_dims; + } + + const auto im2col_matrix_map = + MapAsMatrixWithFirstDimAsRows(gemm_input_data, *gemm_input_dims); + const auto filter_matrix_map = + MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map); + + AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, + output_dims, output_activation_min, + output_activation_max); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride_width, + int stride_height, int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, + stride_width, stride_height, pad_width, pad_height, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, output_data, + output_dims, im2col_data, im2col_dims); +} + +inline void Conv(const uint8* input_data, const Dims<4>& 
input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("Conv/8bit"); + + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + const uint8* gemm_input_data = nullptr; + const Dims<4>* gemm_input_dims = nullptr; + const int filter_width = ArraySize(filter_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const bool need_im2col = stride_width != 1 || stride_height != 1 || + filter_width != 1 || filter_height != 1; + if (need_im2col) { + TFLITE_DCHECK(im2col_data); + const int input_zero_point = -input_offset; + TFLITE_DCHECK_GE(input_zero_point, 0); + TFLITE_DCHECK_LE(input_zero_point, 255); + Im2col(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_height, filter_width, input_zero_point, + im2col_data, im2col_dims); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else { + TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + gemm_input_dims = &input_dims; + } + + const int gemm_input_rows = gemm_input_dims->sizes[0]; + const int gemm_input_cols = gemm_input_dims->sizes[1] * + gemm_input_dims->sizes[2] * + gemm_input_dims->sizes[3]; + const int filter_rows = filter_dims.sizes[3]; + const int filter_cols = + filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; + const int output_rows = output_dims.sizes[0]; + const int output_cols = + output_dims.sizes[1] * output_dims.sizes[2] * 
output_dims.sizes[3]; + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(output_cols, gemm_input_cols); + TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + gemmlowp::MatrixMap filter_matrix( + filter_data, filter_rows, filter_cols); + gemmlowp::MatrixMap input_matrix( + gemm_input_data, gemm_input_rows, gemm_input_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, output_cols); + const auto& output_pipeline = GemmlowpOutputPipeline::Make( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, + input_offset, output_pipeline); +} + +// legacy, for compatibility with old checked-in code +template +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, 
bias_data, bias_dims, stride_width, stride_height, + pad_width, pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, stride, pad_width, + pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + +template +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("DepthToSpace"); + + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + + const int output_depth = ArraySize(output_dims, 0); + const int batch_size = ArraySize(output_dims, 3); + + // Number of continuous values that we can copy in one interation. 
+ const int stride = block_size * output_depth; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch); + for (int offset_h = 0; offset_h < block_size; ++offset_h) { + const T* src = input_ptr; + for (int in_w = 0; in_w < input_width; ++in_w) { + memcpy(output_data, src, stride * sizeof(T)); + output_data += stride; + src += input_depth; + } + input_ptr += stride; + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int kheight, int kwidth, + uint8 byte_zero, T* output_data, const Dims<4>& output_dims) { + Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight, + kwidth, byte_zero, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void ConvAsGemm(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("ConvAsGemm"); + + const auto input_matrix_map = + MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + const auto filter_matrix_map = + MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map); + + AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, + output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 
output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit"); + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + const int input_rows = input_dims.sizes[0]; + const int input_cols = + input_dims.sizes[1] * input_dims.sizes[2] * input_dims.sizes[3]; + const int filter_rows = filter_dims.sizes[3]; + const int filter_cols = + filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; + const int output_rows = output_dims.sizes[0]; + const int output_cols = + output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(output_cols, input_cols); + TFLITE_DCHECK_EQ(filter_cols, input_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + gemmlowp::MatrixMap filter_matrix( + filter_data, output_rows, filter_cols, filter_cols); + gemmlowp::MatrixMap input_matrix( + input_data, filter_cols, output_cols, filter_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, output_cols, output_rows); + const auto& output_pipeline = GemmlowpOutputPipeline::Make( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, + input_offset, output_pipeline); +} + +template +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + 
gemmlowp::ScopedProfilingLabel label("SpaceToDepth"); + + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + + const int input_depth = ArraySize(input_dims, 0); + const int batch_size = ArraySize(input_dims, 3); + + // Number of continuous values that we can copy in one interation. + const int stride = block_size * input_depth; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int out_h = 0; out_h < output_height; ++out_h) { + T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch); + for (int offset_h = 0; offset_h < block_size; ++offset_h) { + T* dst = output_ptr; + for (int out_w = 0; out_w < output_width; ++out_w) { + memcpy(dst, input_data, stride * sizeof(T)); + input_data += stride; + dst += output_depth; + } + output_ptr += stride; + } + } + } +} + +template +void NonGlobalBatchNormalization( + const float* input_data, const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, const float* multiplier_data, + const Dims<4>& multiplier_dims, const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("NonGlobalBatchNormalization"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2, + offset_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1, + offset_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, 
x, y, b)] - + mean_data[Offset(mean_dims, c, x, y, 0)]) * + multiplier_data[Offset(multiplier_dims, c, x, y, 0)] + + offset_data[Offset(offset_dims, c, x, y, 0)]); + } + } + } + } +} + +template +void GlobalBatchNormalization(const float* input_data, + const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, + const float* multiplier_data, + const Dims<4>& multiplier_dims, + const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("GlobalBatchNormalization"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, 0, 0, 0)]) * + multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] + + offset_data[Offset(offset_dims, c, 0, 0, 0)]); + } + } + } + } +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Relu (not fused)"); + + const auto input = MapAsVector(input_data, input_dims); + auto output = MapAsVector(output_data, output_dims); + output = input.cwiseMax(0.0f); +} + +inline void Relu1(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, 
output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 1; + const float lower = -1; + float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +inline void Relu6(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 6; + const float lower = 0; + float clamped = val > upper ? upper : val < lower ? 
lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("L2Normalization"); + static_assert(Ac == FusedActivationFunctionType::kNone, ""); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + float squared_l2_norm = 0; + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + squared_l2_norm += val * val; + } + float inverse_l2_norm = 1.0f / std::sqrt(squared_l2_norm); + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input_data[Offset(input_dims, c, x, y, b)] * inverse_l2_norm; + } + } + } + } +} + +inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, + int* output_shift) { + *output_shift = 11; + while (input >= (1 << 29)) { + input /= 4; + ++*output_shift; + } + TFLITE_DCHECK_GT(input, 0); + const unsigned max_left_shift_bits = __builtin_clz(input) - 1; + const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; + const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; + *output_shift -= left_shift_bit_pairs; + input <<= 2 * left_shift_bit_pairs; + TFLITE_DCHECK_GE(input, (1 << 27)); + TFLITE_DCHECK_LT(input, (1 << 29)); + using gemmlowp::FixedPoint; + using gemmlowp::Rescale; + using gemmlowp::SaturatingRoundingMultiplyByPOT; + // Using 3 integer bits gives us enough room for the internal arithmetic in + // this Newton-Raphson iteration. 
+ using F3 = FixedPoint; + using F0 = FixedPoint; + const F3 fixedpoint_input = F3::FromRaw(input >> 1); + const F3 fixedpoint_half_input = + SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input); + const F3 fixedpoint_half_three = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5); + // Newton-Raphson iteration + // Naive unoptimized starting guess: x = 1 + F3 x = F3::One(); + // Naive unoptimized number of iterations: 5 + for (int i = 0; i < 5; i++) { + const F3 x3 = Rescale<3>(x * x * x); + x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3); + } + const F0 fixedpoint_half_sqrt_2 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.); + x = x * fixedpoint_half_sqrt_2; + *output_inv_sqrt = x.raw(); + if (*output_shift < 0) { + *output_inv_sqrt <<= -*output_shift; + *output_shift = 0; + } +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK_EQ(batches, 1); + TFLITE_DCHECK_EQ(height, 1); + TFLITE_DCHECK_EQ(width, 1); + int32 square_l2_norm = 0; + for (int i = 0; i < depth; i++) { + int32 diff = input_data[i] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int i = 0; i < depth; i++) { + int32 diff = input_data[i] - input_zero_point; + int32 rescaled_diff = 
MultiplyByQuantizedMultiplierSmallerThanOne( + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 unclamped_output_val = 128 + rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + output_data[i] = static_cast(output_val); + } +} + +inline void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add"); + /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, + output_dims, 3); + /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, + output_dims, 2); + /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, + output_dims, 1); + /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, + output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + int i = 0; + const int size = input1_dims.sizes[3] * input1_dims.strides[3]; +#ifdef USE_NEON + const auto activation_min = vdupq_n_f32(output_activation_min); + const auto activation_max = vdupq_n_f32(output_activation_max); + for (; i <= size - 16; i += 16) { + auto a10 = vld1q_f32(input1_data + i); + auto a11 = vld1q_f32(input1_data + i + 4); + auto a12 = vld1q_f32(input1_data + i + 8); + auto a13 = vld1q_f32(input1_data + i + 12); + auto a20 = vld1q_f32(input2_data + i); + auto a21 = vld1q_f32(input2_data + i + 4); + auto a22 = vld1q_f32(input2_data + i + 8); + auto a23 = vld1q_f32(input2_data + i + 12); + auto x0 = vaddq_f32(a10, a20); + auto x1 = vaddq_f32(a11, a21); + auto x2 = vaddq_f32(a12, a22); + auto x3 = vaddq_f32(a13, a23); + x0 = vmaxq_f32(activation_min, x0); + x1 = vmaxq_f32(activation_min, x1); + x2 = vmaxq_f32(activation_min, x2); + x3 = 
vmaxq_f32(activation_min, x3); + x0 = vminq_f32(activation_max, x0); + x1 = vminq_f32(activation_max, x1); + x2 = vminq_f32(activation_max, x2); + x3 = vminq_f32(activation_max, x3); + vst1q_f32(output_data + i, x0); + vst1q_f32(output_data + i + 4, x1); + vst1q_f32(output_data + i + 8, x2); + vst1q_f32(output_data + i + 12, x3); + } + for (; i <= size - 4; i += 4) { + auto a1 = vld1q_f32(input1_data + i); + auto a2 = vld1q_f32(input2_data + i); + auto x = vaddq_f32(a1, a2); + x = vmaxq_f32(activation_min, x); + x = vminq_f32(activation_max, x); + vst1q_f32(output_data + i, x); + } +#endif // NEON + + for (; i < size; i++) { + auto x = input1_data[i] + input2_data[i]; + output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min, + output_activation_max); + } +} + +// legacy, for compatibility with old checked-in code +template +void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +template +inline void Add(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, int input2_shift, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + 
TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + gemmlowp::ScopedProfilingLabel label("Add/8bit"); + /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, + output_dims, 3); + /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, + output_dims, 2); + /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, + output_dims, 1); + /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, + output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + int i = 0; + const int size = input1_dims.sizes[3] * input1_dims.strides[3]; + TFLITE_DCHECK_GT(input1_offset, -256); + TFLITE_DCHECK_GT(input2_offset, -256); + TFLITE_DCHECK_LT(input1_offset, 256); + TFLITE_DCHECK_LT(input2_offset, 256); +#ifdef USE_NEON + for (; i <= size - 8; i += 8) { + const auto input1_val_original = vld1_u8(input1_data + i); + const auto input2_val_original = vld1_u8(input2_data + i); + const auto input1_val_s16 = + vreinterpretq_s16_u16(vmovl_u8(input1_val_original)); + const auto input2_val_s16 = + vreinterpretq_s16_u16(vmovl_u8(input2_val_original)); + const auto input1_val = + vaddq_s16(input1_val_s16, vdupq_n_s16(input1_offset)); + const auto input2_val = + vaddq_s16(input2_val_s16, vdupq_n_s16(input2_offset)); + const auto input1_val_high = vget_high_s16(input1_val); + const auto input1_val_low = vget_low_s16(input1_val); + const auto input2_val_high = vget_high_s16(input2_val); + const auto input2_val_low = vget_low_s16(input2_val); + auto x11 = vmovl_s16(input1_val_low); + auto x12 = vmovl_s16(input1_val_high); + auto x21 = vmovl_s16(input2_val_low); + auto x22 = vmovl_s16(input2_val_high); + const auto left_shift_dup = 
vdupq_n_s32(left_shift); + x11 = vshlq_s32(x11, left_shift_dup); + x12 = vshlq_s32(x12, left_shift_dup); + x21 = vshlq_s32(x21, left_shift_dup); + x22 = vshlq_s32(x22, left_shift_dup); + x11 = vqrdmulhq_n_s32(x11, input1_multiplier); + x12 = vqrdmulhq_n_s32(x12, input1_multiplier); + x21 = vqrdmulhq_n_s32(x21, input2_multiplier); + x22 = vqrdmulhq_n_s32(x22, input2_multiplier); + const auto input1_shift_dup = vdupq_n_s32(-input1_shift); + const auto input2_shift_dup = vdupq_n_s32(-input2_shift); + x11 = vshlq_s32(x11, input1_shift_dup); + x12 = vshlq_s32(x12, input1_shift_dup); + x21 = vshlq_s32(x21, input2_shift_dup); + x22 = vshlq_s32(x22, input2_shift_dup); + auto s1 = vaddq_s32(x11, x21); + auto s2 = vaddq_s32(x12, x22); + s1 = vqrdmulhq_n_s32(s1, output_multiplier); + s2 = vqrdmulhq_n_s32(s2, output_multiplier); + using gemmlowp::RoundingDivideByPOT; + s1 = RoundingDivideByPOT(s1, output_shift); + s2 = RoundingDivideByPOT(s2, output_shift); + const auto s1_narrowed = vmovn_s32(s1); + const auto s2_narrowed = vmovn_s32(s2); + const auto s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), + vdupq_n_s16(output_offset)); + vst1_u8(output_data + i, vqmovun_s16(s)); + } +#endif // NEON + + for (; i < size; i++) { + const int32 input1_val = input1_offset + input1_data[i]; + const int32 input2_val = input2_offset + input2_data[i]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = std::min( + 
output_activation_max, std::max(output_activation_min, raw_output)); + output_data[i] = static_cast(clamped_output); + } +} + +template +void Add(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add/int32"); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + + auto input1_map = MapAsVector(input1_data, input1_dims); + auto input2_map = MapAsVector(input2_data, input2_dims); + auto output_map = MapAsVector(output_data, output_dims); + if (AreSameDims(input1_dims, input2_dims)) { + output_map.array() = input1_map.array() + input2_map.array(); + } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() + scalar; + } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar + input2_map.array(); + } else { + // Should not come here. + TFLITE_DCHECK(false); + } +} + +// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from +// reference_ops.h. Once an optimized version is implemented and NdArrayDesc +// is no longer referenced in this file, move NdArrayDesc from types.h to +// reference_ops.h. 
+template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] + + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, 
input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +template +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, 
int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset, + input1_multiplier, input1_shift, input2_data, input2_dims, + input2_offset, input2_multiplier, input2_shift, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul"); + /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, + output_dims, 3); + /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, + output_dims, 2); + /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, + output_dims, 1); + /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, + output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + int i = 0; + const int size = input1_dims.sizes[3] * input1_dims.strides[3]; +#ifdef USE_NEON + const auto 
activation_min = vdupq_n_f32(output_activation_min); + const auto activation_max = vdupq_n_f32(output_activation_max); + for (; i <= size - 16; i += 16) { + auto a10 = vld1q_f32(input1_data + i); + auto a11 = vld1q_f32(input1_data + i + 4); + auto a12 = vld1q_f32(input1_data + i + 8); + auto a13 = vld1q_f32(input1_data + i + 12); + auto a20 = vld1q_f32(input2_data + i); + auto a21 = vld1q_f32(input2_data + i + 4); + auto a22 = vld1q_f32(input2_data + i + 8); + auto a23 = vld1q_f32(input2_data + i + 12); + auto x0 = vmulq_f32(a10, a20); + auto x1 = vmulq_f32(a11, a21); + auto x2 = vmulq_f32(a12, a22); + auto x3 = vmulq_f32(a13, a23); + + x0 = vmaxq_f32(activation_min, x0); + x1 = vmaxq_f32(activation_min, x1); + x2 = vmaxq_f32(activation_min, x2); + x3 = vmaxq_f32(activation_min, x3); + x0 = vminq_f32(activation_max, x0); + x1 = vminq_f32(activation_max, x1); + x2 = vminq_f32(activation_max, x2); + x3 = vminq_f32(activation_max, x3); + + vst1q_f32(output_data + i, x0); + vst1q_f32(output_data + i + 4, x1); + vst1q_f32(output_data + i + 8, x2); + vst1q_f32(output_data + i + 12, x3); + } + for (; i <= size - 4; i += 4) { + auto a1 = vld1q_f32(input1_data + i); + auto a2 = vld1q_f32(input2_data + i); + auto x = vmulq_f32(a1, a2); + + x = vmaxq_f32(activation_min, x); + x = vminq_f32(activation_max, x); + + vst1q_f32(output_data + i, x); + } +#endif // NEON + + for (; i < size; i++) { + auto x = input1_data[i] * input2_data[i]; + output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min, + output_activation_max); + } +} + +// legacy, for compatibility with old checked-in code +template +void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Mul(input1_data, input1_dims, input2_data, input2_dims, 
output_activation_min, + output_activation_max, output_data, output_dims); +} + +template +void Mul(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul/int32"); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + + auto input1_map = MapAsVector(input1_data, input1_dims); + auto input2_map = MapAsVector(input2_data, input2_dims); + auto output_map = MapAsVector(output_data, output_dims); + if (AreSameDims(input1_dims, input2_dims)) { + output_map.array() = input1_map.array() * input2_map.array(); + } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() * scalar; + } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar * input2_map.array(); + } else { + // Should not come here. + TFLITE_DCHECK(false); + } +} + +// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +// TODO(benoitjacob): BroadcastMul is intentionally duplicated from +// reference_ops.h. Once an optimized version is implemented and NdArrayDesc +// is no longer referenced in this file, move NdArrayDesc from types.h to +// reference_ops.h. 
+template +void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] * + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // 
col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 unclamped_result = + output_offset + + MultiplyByQuantizedMultiplierSmallerThanOne( + input1_val * input2_val, output_multiplier, output_shift); + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, unclamped_result)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, + input2_dims, input2_offset, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_data, output_dims); +} + +template +void Concatenation(int concat_dim, const Scalar* 
const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Concatenation"); + int concat_size = 0; + for (int i = 0; i < inputs_count; i++) { + for (int j = 0; j < 4; j++) { + if (j != concat_dim) { + MatchingArraySize(*input_dims[i], j, output_dims, j); + } + } + concat_size += ArraySize(*input_dims[i], concat_dim); + } + TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + // for now we dont have a model with a Concatenation + // with fused activation function. + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + int outer_size = 1; + for (int i = concat_dim + 1; i < 4; i++) { + outer_size *= output_dims.sizes[i]; + } + Scalar* output_ptr = output_data; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < inputs_count; ++i) { + const int copy_size = + input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; + memcpy(output_ptr, input_data[i] + k * copy_size, + copy_size * sizeof(Scalar)); + output_ptr += copy_size; + } + } +} + +template +void DepthConcatenation(const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + Concatenation(0, input_data, input_dims, inputs_count, + output_data, output_dims); +} + +inline void LstmCell(const float* input_data, const Dims<4>& input_dims, + const float* prev_activ_data, + const Dims<4>& prev_activ_dims, const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, const float* prev_state_data, + const Dims<4>& prev_state_dims, float* output_state_data, + const Dims<4>& output_state_dims, float* output_activ_data, + const Dims<4>& output_activ_dims, float* concat_temp_data, + const Dims<4>& concat_temp_dims, float* activ_temp_data, + const Dims<4>& activ_temp_dims) { + 
gemmlowp::ScopedProfilingLabel label("LstmCell"); + MatchingArraySize( // batches + input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, output_state_dims, + 3, output_activ_dims, 3); + MatchingArraySize( // height + input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, output_state_dims, + 2, output_activ_dims, 2); + MatchingArraySize( // width + input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, output_state_dims, + 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + + // Concatenate prev_activ and input data together + std::vector concat_input_arrays_data; + std::vector const*> concat_input_arrays_dims; + concat_input_arrays_data.push_back(input_data); + concat_input_arrays_data.push_back(prev_activ_data); + concat_input_arrays_dims.push_back(&input_dims); + concat_input_arrays_dims.push_back(&prev_activ_dims); + Concatenation( + 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]), + concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims); + + // Fully connected + FullyConnected( + concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data, + bias_dims, activ_temp_data, activ_temp_dims); + + // Map raw arrays to Eigen arrays so we can use Eigen's optimized array + // operations. 
+ ArrayMap activ_temp_map = + MapAsArrayWithFirstDimAsRows(activ_temp_data, activ_temp_dims); + auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth, + activ_temp_map.cols()); + ArrayMap prev_state_map = + MapAsArrayWithFirstDimAsRows(prev_state_data, prev_state_dims); + ArrayMap output_state_map = + MapAsArrayWithFirstDimAsRows(output_state_data, output_state_dims); + ArrayMap output_activ_map = + MapAsArrayWithFirstDimAsRows(output_activ_data, output_activ_dims); + + // Combined memory state and final output calculation + gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput"); + output_state_map = + input_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + new_input_sm.tanh() + + forget_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + prev_state_map; + output_activ_map = + output_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + output_state_map.tanh(); +} + +template +void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, + int outputs_count, Scalar* const* output_data, + const Dims<4>* const* output_dims) { + gemmlowp::ScopedProfilingLabel label("TensorFlowSplit"); + TFLITE_DCHECK_GE(outputs_count, 1); + for (int i = 0; i < outputs_count; i++) { + /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3); + /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2); + /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1); + } + const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3); + const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2); + const int width = MatchingArraySize(*output_dims[0], 1, 
input_dims, 1); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + // for now we dont have a model with a TensorFlowSplit + // with fused activation function. + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + const int whb = width * height * batches; + const Scalar* input_ptr = input_data; + for (int k = 0; k < whb; k++) { + for (int i = 0; i < outputs_count; ++i) { + memcpy(output_data[i] + k * output_dims[i]->sizes[0], input_ptr, + output_dims[i]->sizes[0] * sizeof(Scalar)); + input_ptr += output_dims[i]->sizes[0]; + } + } +} + +inline int NodeOffset(int b, int h, int w, int height, int width) { + return (b * height + h) * width + w; +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("AveragePool"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + // TODO(benoitjacob) make this a proper reference impl without Eigen! + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + // TODO(benoitjacob) get rid of the dynamic memory allocation here! + Eigen::VectorXf out_count(out_mat.cols()); + out_count.setZero(); + // Prefill the output to 0. + out_mat.setZero(); + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < input_height; ++h) { + for (int w = 0; w < input_width; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. 
+ int hpad = h + pad_height; + int wpad = w + pad_width; + int h_start = + (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1; + int h_end = std::min(hpad / stride_height + 1, output_height); + int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1; + int w_end = std::min(wpad / stride_width + 1, output_width); + // compute elementwise sum + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + int out_offset = NodeOffset(b, ph, pw, output_height, output_width); + out_mat.col(out_offset) += + in_mat.col(NodeOffset(b, h, w, input_height, input_width)); + out_count(out_offset)++; + } + } + } + } + } + // Divide the output by the actual number of elements being averaged over + TFLITE_DCHECK_GT(out_count.minCoeff(), 0); + out_mat.array().rowwise() /= out_count.transpose().array(); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + for (int x = 0; x < output_width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + output_data[Offset(output_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int 
filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("AveragePool/8bit"); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + const int filter_count = + (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start); + // 1280 required by Inception v3 + static constexpr int kAccBufferMaxSize = 2048; + TFLITE_DCHECK_LE(depth, kAccBufferMaxSize); + uint16 acc[kAccBufferMaxSize]; + memset(acc, 0, depth * sizeof(acc[0])); + const uint8* input_ptr = + input_data + input_dims.strides[1] * in_x_origin + + input_dims.strides[2] * in_y_origin + 
input_dims.strides[3] * batch; + for (int fy = filter_y_start; fy < filter_y_end; fy++) { + const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + + filter_x_start * input_dims.strides[1]; + for (int fx = filter_x_start; fx < filter_x_end; fx++) { + int channel = 0; +#ifdef USE_NEON + for (; channel <= depth - 16; channel += 16) { + uint16x8_t acc_reg[2]; + for (int i = 0; i < 2; i++) { + acc_reg[i] = vld1q_u16(acc + channel + 8 * i); + } + uint8x16_t input_reg = vld1q_u8(input_row_ptr); + input_row_ptr += 16; + acc_reg[0] = vaddw_u8(acc_reg[0], vget_low_u8(input_reg)); + acc_reg[1] = vaddw_u8(acc_reg[1], vget_high_u8(input_reg)); + for (int i = 0; i < 2; i++) { + vst1q_u16(acc + channel + 8 * i, acc_reg[i]); + } + } + for (; channel <= depth - 8; channel += 8) { + uint16x8_t acc_reg = vld1q_u16(acc + channel); + uint8x8_t input_reg = vld1_u8(input_row_ptr); + input_row_ptr += 8; + acc_reg = vaddw_u8(acc_reg, input_reg); + vst1q_u16(acc + channel, acc_reg); + } +#endif + for (; channel < depth; ++channel) { + acc[channel] += *input_row_ptr++; + } + } + } + uint8* output_ptr = + output_data + Offset(output_dims, 0, out_x, out_y, batch); + int channel = 0; +#ifdef USE_NEON +#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \ + if (filter_count == FILTER_COUNT) { \ + for (; channel <= depth - 8; channel += 8) { \ + uint16 buf[8]; \ + for (int i = 0; i < 8; i++) { \ + buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \ + } \ + uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \ + buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); \ + buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); \ + vst1_u8(output_ptr + channel, buf8); \ + } \ + } + AVGPOOL_DIVIDING_BY(9) + AVGPOOL_DIVIDING_BY(15) +#undef AVGPOOL_DIVIDING_BY + for (; channel <= depth - 8; channel += 8) { + uint16 buf[8]; + for (int i = 0; i < 8; i++) { + buf[i] = (acc[channel + i] + filter_count / 2) / filter_count; + } + uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); + buf8 = vmin_u8(buf8, 
vdup_n_u8(output_activation_max)); + buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); + vst1_u8(output_ptr + channel, buf8); + } +#endif + for (; channel < depth; ++channel) { + uint16 a = (acc[channel] + filter_count / 2) / filter_count; + a = std::max(a, output_activation_min); + a = std::min(a, output_activation_max); + output_ptr[channel] = static_cast(a); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float 
output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("MaxPool"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + // Prefill the output to minimum representable float value + out_mat.setConstant(std::numeric_limits::lowest()); + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < input_height; ++h) { + for (int w = 0; w < input_width; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + int hpad = h + pad_height; + int wpad = w + pad_width; + int h_start = + (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1; + int h_end = std::min(hpad / stride_height + 1, output_height); + int w_start = (wpad < kwidth) ? 
0 : (wpad - kwidth) / stride_width + 1; + int w_end = std::min(wpad / stride_width + 1, output_width); + // compute elementwise sum + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + int out_offset = NodeOffset(b, ph, pw, output_height, output_width); + out_mat.col(out_offset) = + out_mat.col(out_offset) + .cwiseMax(in_mat.col( + NodeOffset(b, h, w, input_height, input_width))); + } + } + } + } + } + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + for (int x = 0; x < output_width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + output_data[Offset(output_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 
output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("MaxPool/8bit"); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + // 2048 required by Inception v3 + static constexpr int kAccBufferMaxSize = 2048; + TFLITE_DCHECK_LE(depth, kAccBufferMaxSize); + uint8 acc[kAccBufferMaxSize]; + memset(acc, 0, depth * sizeof(acc[0])); + const uint8* input_ptr = + input_data + input_dims.strides[1] * in_x_origin + + input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + for (int fy = filter_y_start; fy < filter_y_end; fy++) { + const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + + filter_x_start * input_dims.strides[1]; + for (int fx = filter_x_start; fx < filter_x_end; fx++) { + int channel = 0; +#ifdef USE_NEON + for (; channel <= depth - 16; channel += 16) { + uint8x16_t acc_reg = vld1q_u8(acc + channel); + uint8x16_t input_reg = vld1q_u8(input_row_ptr); + input_row_ptr += 16; + acc_reg = vmaxq_u8(acc_reg, input_reg); + vst1q_u8(acc + 
channel, acc_reg); + } + + for (; channel <= depth - 8; channel += 8) { + uint8x8_t acc_reg = vld1_u8(acc + channel); + uint8x8_t input_reg = vld1_u8(input_row_ptr); + input_row_ptr += 8; + acc_reg = vmax_u8(acc_reg, input_reg); + vst1_u8(acc + channel, acc_reg); + } +#endif + for (; channel < depth; ++channel) { + acc[channel] = std::max(acc[channel], *input_row_ptr++); + } + } + } + uint8* output_ptr = + output_data + Offset(output_dims, 0, out_x, out_y, batch); + int channel = 0; +#ifdef USE_NEON + for (; channel <= depth - 16; channel += 16) { + uint8x16_t a = vld1q_u8(acc + channel); + a = vminq_u8(a, vdupq_n_u8(output_activation_max)); + a = vmaxq_u8(a, vdupq_n_u8(output_activation_min)); + vst1q_u8(output_ptr + channel, a); + } + for (; channel <= depth - 8; channel += 8) { + uint8x8_t a = vld1_u8(acc + channel); + a = vmin_u8(a, vdup_n_u8(output_activation_max)); + a = vmax_u8(a, vdup_n_u8(output_activation_min)); + vst1_u8(output_ptr + channel, a); + } +#endif + for (; channel < depth; ++channel) { + uint8 a = acc[channel]; + a = std::max(a, output_activation_min); + a = std::min(a, output_activation_max); + output_ptr[channel] = static_cast(a); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, 
filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("L2Pool"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + // Actually carry out L2 Pool. Code is written in forward mode: we go through + // the input values once, and write to all the pooled regions that it maps to. + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + Eigen::VectorXf in_square(in_mat.rows()); + Eigen::VectorXf out_count(out_mat.cols()); + out_count.setZero(); + // Prefill the output to 0. + out_mat.setZero(); + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < input_height; ++h) { + for (int w = 0; w < input_width; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. 
+ const int hpad = h + pad_height; + const int wpad = w + pad_width; + const int h_start = (hpad < filter_height) + ? 0 + : (hpad - filter_height) / stride_height + 1; + const int h_end = std::min(hpad / stride_height + 1, output_height); + const int w_start = (wpad < filter_width) + ? 0 + : (wpad - filter_width) / stride_width + 1; + const int w_end = std::min(wpad / stride_width + 1, output_width); + // pre-compute square + const int in_offset = w + input_width * (h + input_height * b); + in_square = + in_mat.col(in_offset).array() * in_mat.col(in_offset).array(); + // compute elementwise sum of squares + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + const int out_offset = pw + output_width * (ph + output_height * b); + out_mat.col(out_offset) += in_square; + out_count(out_offset)++; + } + } + } + } + } + + out_count = out_count.array().inverse(); + out_mat = + (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt(); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, 
output_dims); +} + +inline void LocalResponseNormalization(const float* input_data, + const Dims<4>& input_dims, int range, + float bias, float alpha, float beta, + float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization"); + /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0); + + const auto data_in = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto data_out = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + // Carry out local response normalization, vector by vector. + // Since the data are stored column major, making row-wise operation + // probably not memory efficient anyway, we do an explicit for loop over + // the columns. + const int double_range = range * 2; + Eigen::VectorXf padded_square(data_in.rows() + double_range); + padded_square.setZero(); + for (int r = 0; r < data_in.cols(); ++r) { + // Do local response normalization for data_in(:, r) + // first, compute the square and store them in buffer for repeated use + padded_square.block(range, 0, data_in.rows(), 1) = + data_in.col(r).cwiseProduct(data_in.col(r)) * alpha; + // Then, compute the scale and writes them to data_out + float accumulated_scale = 0; + for (int i = 0; i < double_range; ++i) { + accumulated_scale += padded_square(i); + } + for (int i = 0; i < data_in.rows(); ++i) { + accumulated_scale += padded_square(i + double_range); + data_out(i, r) = bias + accumulated_scale; + accumulated_scale -= padded_square(i); + } + } + + // In a few cases, the pow computation could benefit from speedups. 
+ if (beta == 1) { + data_out.array() = data_in.array() * data_out.array().inverse(); + } else if (beta == 0.5) { + data_out.array() = data_in.array() * data_out.array().sqrt().inverse(); + } else { + data_out.array() = data_in.array() * data_out.array().pow(-beta); + } +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Softmax"); + /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0); + + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + // Compute the exponential first, removing the max coefficient for numerical + // stability. + out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta; + // We are separating out the exp function so that exp can be vectorized. + out_mat = out_mat.array().exp(); + // Normalize to get the activations. + Eigen::Array scale = + out_mat.array().colwise().sum().inverse(); + out_mat.array().rowwise() *= scale; +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + // The representation chosen for the input to the exp() function is Q5.26. + // We need to leave extra space since values that we skip might be as large as + // -32 before multiplying by input_beta_multiplier, and therefore as large as + // -16 afterwards. Note that exp(-8) is definitely not insignificant to + // accumulation, but exp(-16) definitely is. 
+ static const int kScaledDiffIntegerBits = 5; + static const int kAccumulationIntegerBits = 12; + using FixedPointScaledDiff = + gemmlowp::FixedPoint; + using FixedPointAccum = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + + gemmlowp::ScopedProfilingLabel label("Softmax"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int x = 0; x < width; ++x) { + for (int y = 0; y < height; ++y) { + uint8 max_in_row = 0; + for (int c = 0; c < depth; ++c) { + max_in_row = + std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]); + } + + FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + sum_of_exps = + sum_of_exps + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_f8)); + } + } + + int32 fixed_sum_of_exps = sum_of_exps.raw(); + // TODO(starka): Use a NEON intrinsic like vclzq_u32 instead. + int headroom_plus_one = + __builtin_clz(static_cast(fixed_sum_of_exps)); + // This is the number of bits to the left of the binary point above 1.0. + // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and + // no later adjustment will be needed. 
+ int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one; + int32 shifted_sum_minus_one = static_cast( + (static_cast(fixed_sum_of_exps) << headroom_plus_one) - + (static_cast(1) << 31)); + + FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1( + FixedPoint0::FromRaw(shifted_sum_minus_one)); + + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + + FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); + int32 unsat_output = gemmlowp::RoundingDivideByPOT( + (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); + + output_data[Offset(output_dims, c, x, y, b)] = + std::max(std::min(unsat_output, 255), 0); + + } else { + output_data[Offset(output_dims, c, x, y, b)] = 0; + } + } + } + } + } +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Logistic"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = + input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op()); +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Logistic"); + /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* height */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* width */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* depth */ 
MatchingArraySize(input_dims, 0, output_dims, 0); + const int size = RequiredBufferSizeForDims(input_dims); + + int c = 0; +#ifdef USE_NEON + // Handle 16 values at a time + for (; c <= size - 16; c += 16) { + // Read input uint8 values, cast to int16 and subtract input_zero_point + uint8x16_t input_val_u8 = vld1q_u8(input_data + c); + int16x8_t input_val_centered_0 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + int16x8_t input_val_centered_1 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + + // Prepare the bit masks that we will use at the end to implement the logic + // that was expressed in the scalar code with branching: + // if (input_val_centered < -input_range_radius) { + // output_val = 0; + // } else if (input_val_centered > input_range_radius) { + // output_val = 255; + // } else { + // ... + uint16x8_t mask_rightclamp_0 = + vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_rightclamp_1 = + vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_leftclamp_0 = + vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius)); + uint16x8_t mask_leftclamp_1 = + vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius)); + uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8), + vshrn_n_u16(mask_rightclamp_1, 8)); + uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8), + vshrn_n_u16(mask_leftclamp_1, 8)); + + // This performs what is expressed in the scalar code as + // const int32 input_val_rescaled = + // MultiplyByQuantizedMultiplierGreaterThanOne( + // input_val_centered, input_multiplier, input_left_shift); + int32x4_t input_val_rescaled_0 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_1 = + 
vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_2 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_3 = + vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + input_val_rescaled_0 = + vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier); + input_val_rescaled_1 = + vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier); + input_val_rescaled_2 = + vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier); + input_val_rescaled_3 = + vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier); + + // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4_0 = + FixedPoint4::FromRaw(input_val_rescaled_0); + const FixedPoint4 input_val_f4_1 = + FixedPoint4::FromRaw(input_val_rescaled_1); + const FixedPoint4 input_val_f4_2 = + FixedPoint4::FromRaw(input_val_rescaled_2); + const FixedPoint4 input_val_f4_3 = + FixedPoint4::FromRaw(input_val_rescaled_3); + const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0); + const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1); + const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2); + const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3); + + // Divide by 2^23 as in the scalar code + using gemmlowp::RoundingDivideByPOT; + int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23); + int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23); + int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23); + int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23); + + // Cast output values to uint8, saturating + int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0), + 
vqmovn_s32(output_val_s32_1)); + int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2), + vqmovn_s32(output_val_s32_3)); + uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0), + vqmovun_s16(output_val_s16_1)); + + // Perform the bit-masking with the bit masks computed at the beginning, + // see the comment there. + output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp); + output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp); + + // Store back to memory + vst1q_u8(output_data + c, output_val_u8); + } +#endif + // Leftover loop: handle one value at a time with scalar code. + for (; c < size; ++c) { + const uint8 input_val_u8 = input_data[c]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered < -input_range_radius) { + output_val = 0; + } else if (input_val_centered > input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4); + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23); + if (output_val_s32 == 256) { + output_val_s32 = 255; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + output_data[c] = output_val; + } +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Tanh"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = input_map.array().tanh(); +} + 
+inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, + int32 zero_point, double scale, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Dequantize"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int32 val = input_data[Offset(input_dims, c, x, y, b)]; + float result = static_cast(scale * (val - zero_point)); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, + float rmin, float rmax, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("FakeQuant"); + + // 0 should always be a representable value. Let's assume that the initial + // min,max range contains 0. + TFLITE_DCHECK_LE(rmin, 0.); + TFLITE_DCHECK_GE(rmax, 0.); + + // Determine quantization parameters: zero_point, scale. + using Integer = uint8; + const Integer qmin = std::numeric_limits::min(); + const Integer qmax = std::numeric_limits::max(); + const float qmin_float = qmin; + const float qmax_float = qmax; + int32 zero_point = 0; + float scale = 0.f; + // If rmin==rmax, both must be zero per the above assertion, + // so we are done. + if (rmin != rmax) { + // First determine the scale. + scale = (rmax - rmin) / (qmax_float - qmin_float); + + // Zero-point computation. + // First the initial floating-point computation. The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). 
+ // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. + const float zero_point_from_min = qmin_float - rmin / scale; + const float zero_point_from_max = qmax_float - rmax / scale; + const float zero_point_from_min_error = + std::abs(qmin_float) + std::abs(rmin / scale); + const float zero_point_from_max_error = + std::abs(qmax_float) + std::abs(rmax / scale); + + const float zero_point_float = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with SAME + // padding). + if (zero_point_float < qmin_float) { + zero_point = qmin; + } else if (zero_point_float > qmax_float) { + zero_point = qmax; + } else { + zero_point = static_cast(TfLiteRound(zero_point_float)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. 
+ TFLITE_DCHECK_GE(zero_point, qmin); + TFLITE_DCHECK_LE(zero_point, qmax); + } + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const float src_val = input_data[Offset(input_dims, c, x, y, b)]; + const float unclamped_quantized_val = + TfLiteRound(zero_point + src_val / scale); + const float quantized_val = std::min( + qmax_float, std::max(qmin_float, unclamped_quantized_val)); + const float dst_val = scale * (quantized_val - zero_point); + output_data[Offset(output_dims, c, x, y, b)] = dst_val; + } + } + } + } +} + +template +inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, + DstT* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Cast"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = input_map.array().template cast(); +} + +inline void Floor(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Floor"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = Eigen::floor(input_map.array()); +} + +template +inline void Gather(const T* input_data, const Dims<4>& input_dims, + int input_rank, const int32* coords_data, + const Dims<4>& coords_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Gather"); + + TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); + int stride = input_dims.strides[input_rank - 1]; + T* out 
= output_data; + + for (int i = 0; i < coords_dims.sizes[0]; i++) { + TFLITE_DCHECK_GE(coords_data[i], 0); + TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); + const T* in = input_data + coords_data[i] * stride; + memcpy(out, in, sizeof(T) * stride); + out += stride; + } +} + +#ifdef USE_NEON +inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, + float scale, float* output_ptr) { + int ic = 0; + // Handle 32 input channels at a time. + for (; ic <= depth - 32; ic += 32) { + float32x4x2_t input[4]; + for (int i = 0; i < 4; i++) { + input[i].val[0] = vld1q_f32(input_ptr + 8 * i); + input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4); + } + float32x4x2_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i].val[0] = vld1q_f32(output_ptr + 8 * i); + acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4); + } + for (int i = 0; i < 4; i++) { + acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale); + acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale); + } + for (int i = 0; i < 4; i++) { + vst1q_f32(output_ptr, acc[i].val[0]); + vst1q_f32(output_ptr + 4, acc[i].val[1]); + output_ptr += 8; + } + input_ptr += 32; + } + // Handle 16 input channels at a time. + for (; ic <= depth - 16; ic += 16) { + float32x4x2_t input[2]; + for (int i = 0; i < 2; i++) { + input[i].val[0] = vld1q_f32(input_ptr + 8 * i); + input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4); + } + float32x4x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_f32(output_ptr + 8 * i); + acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4); + } + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale); + acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale); + } + for (int i = 0; i < 2; i++) { + vst1q_f32(output_ptr, acc[i].val[0]); + vst1q_f32(output_ptr + 4, acc[i].val[1]); + output_ptr += 8; + } + input_ptr += 16; + } + // Handle 8 input channels at a time. 
+ for (; ic <= depth - 8; ic += 8) { + float32x4x2_t input; + input.val[0] = vld1q_f32(input_ptr); + input.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t acc; + acc.val[0] = vld1q_f32(output_ptr); + acc.val[1] = vld1q_f32(output_ptr + 4); + acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale); + acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale); + + vst1q_f32(output_ptr, acc.val[0]); + vst1q_f32(output_ptr + 4, acc.val[1]); + + input_ptr += 8; + output_ptr += 8; + } + // Handle 4 input channels at a time. + for (; ic <= depth - 4; ic += 4) { + float32x4_t input = vld1q_f32(input_ptr); + float32x4_t acc = vld1q_f32(output_ptr); + + acc = vmlaq_n_f32(acc, input, scale); + vst1q_f32(output_ptr, acc); + + input_ptr += 4; + output_ptr += 4; + } + // Handle 1 input channel at a time. + for (; ic < depth; ic++) { + *output_ptr += *input_ptr * scale; + output_ptr++; + input_ptr++; + } +} +#else +inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, + float scale, float* output_ptr) { + for (int32 i = 0; i < depth; i++) { + *output_ptr += *input_ptr * scale; + output_ptr++; + input_ptr++; + } +} +#endif + +inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, + int32 x, int32 y, int32 depth, int32 batch, + const float* input_data, + const Dims<4>& input_dims, + float* output_data, + const Dims<4>& output_dims) { + const int32 input_width = ArraySize(input_dims, 1); + const int32 output_width = ArraySize(output_dims, 1); + + const int32 input_x_offset = (x1 - x0) * depth; + const int32 input_y_offset = (y1 - y0) * depth * input_width; + const int32 output_x_offset = depth; + const int32 output_y_offset = depth * output_width; + +#ifdef USE_NEON + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(x1 >= x0); + TFLITE_DCHECK(y1 >= y0); + + int ic = 0; + // Handle 8 input channels at a time. 
+ for (; ic <= depth - 8; ic += 8) { + const float* input_ptr = nullptr; + + float32x4x2_t x0y0; + input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)]; + x0y0.val[0] = vld1q_f32(input_ptr); + x0y0.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t x1y0; + input_ptr += input_x_offset; + x1y0.val[0] = vld1q_f32(input_ptr); + x1y0.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t x0y1; + input_ptr += -input_x_offset + input_y_offset; + x0y1.val[0] = vld1q_f32(input_ptr); + x0y1.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t x1y1; + input_ptr += input_x_offset; + x1y1.val[0] = vld1q_f32(input_ptr); + x1y1.val[1] = vld1q_f32(input_ptr + 4); + + // Top left corner. + float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)]; + vst1q_f32(output_ptr, x0y0.val[0]); + vst1q_f32(output_ptr + 4, x0y0.val[1]); + + // Top right corner. + output_ptr += output_x_offset; + float32x4x2_t tr; + tr.val[0] = vaddq_f32(x0y0.val[0], x1y0.val[0]); + tr.val[1] = vaddq_f32(x0y0.val[1], x1y0.val[1]); + tr.val[0] = vmulq_n_f32(tr.val[0], 0.5f); + tr.val[1] = vmulq_n_f32(tr.val[1], 0.5f); + + vst1q_f32(output_ptr, tr.val[0]); + vst1q_f32(output_ptr + 4, tr.val[1]); + + // Bottom left corner. + output_ptr += -output_x_offset + output_y_offset; + float32x4x2_t bl; + bl.val[0] = vaddq_f32(x0y0.val[0], x0y1.val[0]); + bl.val[1] = vaddq_f32(x0y0.val[1], x0y1.val[1]); + bl.val[0] = vmulq_n_f32(bl.val[0], 0.5f); + bl.val[1] = vmulq_n_f32(bl.val[1], 0.5f); + vst1q_f32(output_ptr, bl.val[0]); + vst1q_f32(output_ptr + 4, bl.val[1]); + + // Bottom right corner. 
+ output_ptr += output_x_offset; + float32x4x2_t br; + br.val[0] = vaddq_f32(x1y0.val[0], x1y1.val[0]); + br.val[1] = vaddq_f32(x1y0.val[1], x1y1.val[1]); + br.val[0] = vmlaq_n_f32(bl.val[0], br.val[0], 0.5f); + br.val[1] = vmlaq_n_f32(bl.val[1], br.val[1], 0.5f); + br.val[0] = vmulq_n_f32(br.val[0], 0.5f); + br.val[1] = vmulq_n_f32(br.val[1], 0.5f); + vst1q_f32(output_ptr, br.val[0]); + vst1q_f32(output_ptr + 4, br.val[1]); + } + // Handle 4 input channels at a time. + for (; ic <= depth - 4; ic += 4) { + const float* input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)]; + float32x4_t x0y0 = vld1q_f32(input_ptr); + float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset); + float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset); + float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset); + + // Top left corner. + float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)]; + vst1q_f32(output_ptr, x0y0); + + // Top right corner. + output_ptr += output_x_offset; + float32x4_t tr = vaddq_f32(x0y0, x1y0); + tr = vmulq_n_f32(tr, 0.5f); + vst1q_f32(output_ptr, tr); + + // Bottom left corner. + output_ptr += -output_x_offset + output_y_offset; + float32x4_t bl = vaddq_f32(x0y0, x0y1); + bl = vmulq_n_f32(bl, 0.5f); + vst1q_f32(output_ptr, bl); + + // Bottom right corner. + output_ptr += output_x_offset; + float32x4_t br = vaddq_f32(x1y0, x1y1); + br = vmlaq_n_f32(bl, br, 0.5f); + br = vmulq_n_f32(br, 0.5f); + vst1q_f32(output_ptr, br); + } + // Handle one input channel at a time. + for (; ic < depth; ic++) { + const int32 input_offset = Offset(input_dims, ic, x0, y0, batch); + + float x0y0 = input_data[input_offset]; + float x1y0 = input_data[input_offset + input_x_offset]; + float x0y1 = input_data[input_offset + input_y_offset]; + float x1y1 = input_data[input_offset + input_x_offset + input_y_offset]; + + // Top left corner. 
+ const int32 output_offset = Offset(output_dims, ic, x, y, batch); + output_data[output_offset] = x0y0; + + // Top right corner. + output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2; + + // Bottom left corner. + float output = (x0y0 + x0y1) / 2; + output_data[output_offset + output_y_offset] = output; + + // Bottom right corner. + output_data[output_offset + output_x_offset + output_y_offset] = + (output + ((x1y0 + x1y1) / 2)) / 2; + } +#else + for (int ch = 0; ch < depth; ch++) { + const int32 input_offset = Offset(input_dims, ch, x0, y0, batch); + + float x0y0 = input_data[input_offset]; + float x1y0 = input_data[input_offset + input_x_offset]; + float x0y1 = input_data[input_offset + input_y_offset]; + float x1y1 = input_data[input_offset + input_x_offset + input_y_offset]; + + // Top left corner. + const int32 output_offset = Offset(output_dims, ch, x, y, batch); + output_data[output_offset] = x0y0; + + // Top right corner. + output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2; + + // Bottom left corner. + float output = (x0y0 + x0y1) / 2; + output_data[output_offset + output_y_offset] = output; + + // Bottom right corner. 
+ output_data[output_offset + output_x_offset + output_y_offset] = + (output + ((x1y0 + x1y1) / 2)) / 2; + } +#endif +} + +inline void ResizeBilinear2x2(const float* input_data, + const Dims<4>& input_dims, float* output_data, + const Dims<4>& output_dims, int32 batches, + int32 input_height, int32 input_width, + int32 depth, int32 output_height, + int32 output_width) { + for (int b = 0; b < batches; b++) { + for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) { + for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) { + int32 x1 = std::min(x0 + 1, input_width - 1); + int32 y1 = std::min(y0 + 1, input_height - 1); + ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_data, + input_dims, output_data, output_dims); + } + } + } +} + +inline void ResizeBilinearGeneric(const float* input_data, + const Dims<4>& input_dims, float* output_data, + const Dims<4>& output_dims, int32 batches, + int32 input_height, int32 input_width, + int32 depth, int32 output_height, + int32 output_width, float height_scale, + float width_scale) { + memset(output_data, 0, + batches * output_height * output_width * depth * sizeof(float)); + + int32 output_offset = 0; + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + float input_y = y * height_scale; + int32 y0 = static_cast(std::floor(input_y)); + int32 y1 = std::min(y0 + 1, input_height - 1); + for (int x = 0; x < output_width; ++x) { + float input_x = x * width_scale; + int32 x0 = static_cast(input_x); + int32 x1 = std::min(x0 + 1, input_width - 1); + float* output_ptr = &output_data[output_offset]; + + // Run kernel on the 4 corners of the bilinear resize algorithm. 
+ int32 input_offset = Offset(input_dims, 0, x0, y0, b); + float scale = (1 - (input_y - y0)) * (1 - (input_x - x0)); + const float* input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + input_offset = Offset(input_dims, 0, x1, y0, b); + scale = (1 - (input_y - y0)) * (input_x - x0); + input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + input_offset = Offset(input_dims, 0, x0, y1, b); + scale = (input_y - y0) * (1 - (input_x - x0)); + input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + input_offset = Offset(input_dims, 0, x1, y1, b); + scale = (input_y - y0) * (input_x - x0); + input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + output_offset += depth; + } + } + } +} + +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); + int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + int32 input_height = ArraySize(input_dims, 2); + int32 input_width = ArraySize(input_dims, 1); + int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); + int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + + // Specialize for 2x2 upsample. 
+ if (output_height == 2 * input_height && output_width == 2 * input_width) { + ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches, + input_height, input_width, depth, output_height, + output_width); + } else { + float height_scale = static_cast(input_height) / output_height; + float width_scale = static_cast(input_width) / output_width; + + ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims, + batches, input_height, input_width, depth, + output_height, output_width, height_scale, + width_scale); + } +} + +template +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("SpaceToBatchND"); + + const int output_batch_size = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_height = block_shape_data[0]; + const int block_shape_width = block_shape_data[1]; + const int padding_top = paddings_data[0]; + const int padding_left = paddings_data[2]; + + for (int out_b = 0; out_b < output_batch_size; ++out_b) { + int input_batch = out_b % input_batch_size; + int shift_w = (out_b / input_batch_size) % block_shape_width; + int shift_h = (out_b / input_batch_size) / block_shape_width; + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); + if (out_h * block_shape_height < padding_top || + out_h * block_shape_height >= padding_top + input_height || + out_w * 
block_shape_width < padding_left || + out_w * block_shape_width >= padding_left + input_width) { + memset(out, 0, depth * sizeof(T)); + } else { + const T* in = + input_data + + Offset(input_dims, 0, + (out_w * block_shape_width + shift_w) - padding_left, + (out_h * block_shape_height + shift_h) - padding_top, + input_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } + } +} + +template +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BatchToSpaceND"); + + const int output_batch_size = ArraySize(output_dims, 3); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_width = block_shape_data[1]; + const int block_shape_height = block_shape_data[0]; + + for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + int out_batch = in_batch % output_batch_size; + int out_w = in_w * block_shape_width + + (in_batch / output_batch_size) % block_shape_width; + int out_h = in_h * block_shape_height + + (in_batch / output_batch_size) / block_shape_width; + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); + const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } +} + +template +inline void Pad(const T* input_data, const Dims<4>& input_dims, + const std::vector& left_paddings, + const std::vector& right_paddings, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Pad"); + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const 
int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int left_b_padding = left_paddings[3]; + const int left_h_padding = left_paddings[2]; + const int left_w_padding = left_paddings[1]; + const int left_d_padding = left_paddings[0]; + + const int right_b_padding = right_paddings[3]; + const int right_h_padding = right_paddings[2]; + const int right_w_padding = right_paddings[1]; + const int right_d_padding = right_paddings[0]; + + const int input_depth = ArraySize(input_dims, 0); + + if (left_b_padding != 0) { + memset(output_data, 0, + left_b_padding * output_height * output_width * output_depth * + sizeof(T)); + } + for (int out_b = left_b_padding; out_b < output_batch - right_b_padding; + ++out_b) { + if (left_h_padding != 0) { + memset(output_data + Offset(output_dims, 0, 0, 0, out_b), 0, + left_h_padding * output_width * output_depth * sizeof(T)); + } + for (int out_h = left_h_padding; out_h < output_height - right_h_padding; + ++out_h) { + if (left_w_padding != 0) { + memset(output_data + Offset(output_dims, 0, 0, out_h, out_b), 0, + left_w_padding * output_depth * sizeof(T)); + } + for (int out_w = left_w_padding; out_w < output_width - right_w_padding; + ++out_w) { + if (left_d_padding != 0) { + memset(output_data + Offset(output_dims, 0, out_w, out_h, out_b), 0, + left_d_padding * sizeof(T)); + } + + T* out = output_data + + Offset(output_dims, left_d_padding, out_w, out_h, out_b); + const T* in = + input_data + Offset(input_dims, 0, out_w - left_w_padding, + out_h - left_h_padding, out_b - left_b_padding); + memcpy(out, in, input_depth * sizeof(T)); + + if (right_d_padding != 0) { + memset( + output_data + Offset(output_dims, output_depth - right_d_padding, + out_w, out_h, out_b), + 0, right_d_padding * sizeof(T)); + } + } + if (right_w_padding != 0) { + memset( + output_data + Offset(output_dims, 0, output_width - right_w_padding, + out_h, out_b), + 0, right_w_padding * output_depth * sizeof(T)); + 
} + } + if (right_h_padding != 0) { + memset(output_data + Offset(output_dims, 0, 0, + output_height - right_h_padding, out_b), + 0, right_h_padding * output_width * output_depth * sizeof(T)); + } + } + if (right_b_padding != 0) { + memset(output_data + + Offset(output_dims, 0, 0, 0, output_batch - right_b_padding), + 0, + right_b_padding * output_height * output_width * output_depth * + sizeof(T)); + } +} + +template +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, + const std::vector& starts, + const std::vector& stops, + const std::vector& strides, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("StridedSlice"); + const int start_b = (begin_mask & 8) ? 0 : starts[3]; + const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3]; + const int start_h = (begin_mask & 4) ? 0 : starts[2]; + const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2]; + const int start_w = (begin_mask & 2) ? 0 : starts[1]; + const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1]; + const int start_d = (begin_mask & 1) ? 0 : starts[0]; + const int stop_d = (end_mask & 1) ? 
input_dims.sizes[0] : stops[0]; + + T* out_ptr = output_data; + if (strides[0] == 0) { + for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { + for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { + for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { + const int len = stop_d - start_d; + memcpy(out_ptr, + input_data + Offset(input_dims, start_d, in_w, in_h, in_b), + len * sizeof(T)); + out_ptr += len; + } + } + } + } else { + for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { + for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { + for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { + for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) { + *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + } + } + } + } + } +} + +// Copies a 4-D slice of input_data into output_data. begin[i] is the first +// index taken along dimension i and size[i] the number of elements taken; +// size[i] == -1 takes everything from begin[i] to the end of the dimension. +// Index 0 is depth, 1 is width, 2 is height and 3 is batch. +template +inline void Slice(const T* input_data, const Dims<4>& input_dims, + const std::vector& begin, const std::vector& size, + T* output_data, const Dims<4>& output_dims) { + // TODO(dkalenichenko): This op only supports 4D tensors. + TFLITE_DCHECK_EQ(begin.size(), 4); + TFLITE_DCHECK_EQ(size.size(), 4); + // Each stop_* is an exclusive end index in input coordinates, derived from + // the start of its own dimension. + const int start_b = begin[3]; + const int stop_b = + size[3] == -1 ? input_dims.sizes[3] : start_b + size[3]; + const int start_h = begin[2]; + const int stop_h = + size[2] == -1 ? input_dims.sizes[2] : start_h + size[2]; + const int start_w = begin[1]; + const int stop_w = + size[1] == -1 ? input_dims.sizes[1] : start_w + size[1]; + const int start_d = begin[0]; + const int stop_d = + size[0] == -1 ?
input_dims.sizes[0] : start_d + size[0]; + + T* out_ptr = output_data; + for (int in_b = start_b; in_b < stop_b; ++in_b) { + for (int in_h = start_h; in_h < stop_h; ++in_h) { + for (int in_w = start_w; in_w < stop_w; ++in_w) { + const int len = stop_d - start_d; + memcpy(out_ptr, + input_data + Offset(input_dims, start_d, in_w, in_h, in_b), + len * sizeof(T)); + out_ptr += len; + } + } + } +} + +template +inline void Mean(const T* input_data, const Dims<4>& input_dims, + const std::vector& reduction_indices, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mean"); + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + + // The current implementation only supports simultaneous reduction over + // width and height.
+ TFLITE_DCHECK_EQ(reduction_indices.size(), 2); + TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) || + (reduction_indices[0] == 2 && reduction_indices[1] == 1)); + TFLITE_DCHECK_EQ(output_height, 1); + TFLITE_DCHECK_EQ(output_width, 1); + + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + float value = 0; + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)]; + } + } + output_data[Offset(output_dims, out_d, 0, 0, out_b)] = + value / (input_width * input_height); + } + } +} + +template +void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("GenericBroadcastSub"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. 
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input1_data[SubscriptToIndex(desc1, c, x, y, b)] - + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + } + } + } + } +} + +template +void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, + const Dims<4>& input2_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Sub"); + + auto input1_map = MapAsVector(input1_data, input1_dims); + auto input2_map = MapAsVector(input2_data, input2_dims); + auto output_map = MapAsVector(output_data, output_dims); + if (AreSameDims(input1_dims, input2_dims)) { + output_map.array() = input1_map.array() - input2_map.array(); + } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar - input2_map.array(); + } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() - scalar; + } else { + GenericBroadcastSub(input1_data, input1_dims, input2_data, input2_dims, + output_data, output_dims); + } +} + +template +void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("TensorFlowMinimum"); + auto input1_map = MapAsVector(input1_data, input1_dims); + auto output_map = MapAsVector(output_data, output_dims); + auto min_value = input2_data[0]; + output_map.array() = input1_map.array().min(min_value); +} + +template +void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("TensorFlowMaximum"); + auto input1_map = 
MapAsVector(input1_data, input1_dims); + auto output_map = MapAsVector(output_data, output_dims); + auto max_value = input2_data[0]; + output_map.array() = input1_map.array().max(max_value); +} +} // namespace optimized_ops +} // namespace tflite + +#if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS +#undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS +#pragma GCC diagnostic pop +#endif + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..f8be99e82fb8721ced7a3e5da686b20ce241ea2d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ +#define TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ + +// TODO(ghodrat): Remove this header file and the dependency to internal data +// structure.
+#include "tensorflow/contrib/lite/builtin_op_data.h" + +#ifndef USE_NEON +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#endif // defined(__ARM_NEON__) || defined(__ARM_NEON) +#endif // USE_NEON + +namespace tflite { +namespace tensor_utils { + +// Multiply a matrix by a batch vector, and store results in a batch-size +// vector. +void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, + int m_rows, int m_cols, + const float* vector, + int n_batch, float* result, + int result_stride); +void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride); + +// Cwise product of two vectors. +void PortableVectorVectorCwiseProduct(const float* vector1, + const float* vector2, int v_size, + float* result); +void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result); + +// Cwise product and accumulate of two vectors. Since it's a MAC operation, the +// assumption here is that result array is initialized to valid values. +void PortableVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, + int v_size, float* result); +void NeonVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result); + +// Dot product of two vectors. +float PortableVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); +float NeonVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); + +// Dot product of two batch vectors. +void PortableBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); +void NeonBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); + +// Cwise product and accumulate of a vector and a batch-vector. 
Since it's a MAC +// operation, the assumption here is that result array is initialized to valid +// values. +void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, + float* result); +void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, float* result); + +// Compute "1.0f - elements of vector" (used in CIFG). +void PortableSub1Vector(const float* vector, int v_size, float* result); +void NeonSub1Vector(const float* vector, int v_size, float* result); + +// Clip elements of a vector using a abs_limit value. +void PortableClipVector(const float* vector, int v_size, float abs_limit, + float* result); +void NeonClipVector(const float* vector, int v_size, float abs_limit, + float* result); + +// Batch vector initialization with another vector. +void PortableVectorBatchVectorAssign(const float* vector, int v_size, + int n_batch, float* batch_vector); + +// Apply sigmoid to elements of a vector. +void PortableApplySigmoidToVector(const float* vector, int v_size, + float* result); + +// Apply activation function to elements of a vector. +void PortableApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, + float* result); + +// Copy vector to another vector. +void PortableCopyVector(const float* vector, int v_size, float* result); + +// Fill vector with 0.f. +void PortableZeroVector(float* vector, int v_size); + +// Limit a float input f between +abs_limit and -abs_limit. +float PortableClip(float f, float abs_limit); + +// Shift left a vector in place with v_size size. +void PortableVectorShiftLeft(float* vector, int v_size, float shift_value); +void NeonVectorShiftLeft(float* vector, int v_size, float shift_value); + +// Reduce-sum on a float input vector: +// input_vector: float pointer to input vector. +// output_vector: float pointer to vector. +// output_size: output vector size. 
+// reduction_size: number of consecutive elements from input vector which are +// added to get one element of output. +void PortableReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); +void NeonReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); + +} // namespace tensor_utils +} // namespace tflite + +#endif // TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..98f2e365c5249a6c28673fc185ebec34cc2105b2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -0,0 +1,95 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/round.h" + +namespace tflite { + +void QuantizeMultiplierSmallerThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* right_shift) { + TFLITE_CHECK(double_multiplier >= 0.); + TFLITE_CHECK(double_multiplier < 1.); + if (double_multiplier == 0.) { + *quantized_multiplier = 0; + *right_shift = 0; + return; + } + TFLITE_CHECK(double_multiplier > 0.); + const double q = std::frexp(double_multiplier, right_shift); + *right_shift *= -1; + + auto q_fixed = static_cast(TfLiteRound(q * (1ll << 31))); + TFLITE_CHECK(q_fixed <= (1ll << 31)); + if (q_fixed == (1ll << 31)) { + q_fixed /= 2; + --*right_shift; + } + TFLITE_CHECK_GE(*right_shift, 0); + TFLITE_CHECK_LE(q_fixed, std::numeric_limits::max()); + *quantized_multiplier = static_cast(q_fixed); +} + +void QuantizeMultiplierGreaterThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift) { + TFLITE_CHECK(double_multiplier > 1.); + const double q = std::frexp(double_multiplier, left_shift); + auto q_fixed = static_cast(TfLiteRound(q * (1ll << 31))); + TFLITE_CHECK(q_fixed <= (1ll << 31)); + if (q_fixed == (1ll << 31)) { + q_fixed /= 2; + ++*left_shift; + } + TFLITE_CHECK_GE(*left_shift, 0); + TFLITE_CHECK_LE(q_fixed, std::numeric_limits::max()); + *quantized_multiplier = static_cast(q_fixed); +} + +void PreprocessSoftmaxScaling(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, int* left_shift) { + // If the overall multiplier (input and beta) is large, then exp() of an + // input difference of 1 scaled by this will be large. 
In other words, we + // can cap the multiplier and know that, when it is used, the output will be + // (round to) zero wherever the input is not at the maximum value. + + // If the overall scale is less than one, and input_integer_bits=0, then the + // result is double equivalent of Q0.31 (actually with more precision). Thus + // this generates a Q(input_integer_bits).(31-input_integer_bits) + // representation. + const double input_beta_real_multiplier = std::min( + beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0); + + QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier, + quantized_multiplier, left_shift); +} + +int CalculateInputRadius(int input_integer_bits, int input_left_shift) { + const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) * + (1ll << (31 - input_integer_bits)) / + (1ll << input_left_shift); + // Tighten bound using floor. Suppose that we could use the exact value. + // After scaling the difference, the result would be at the maximum. Thus we + // must ensure that our value has lower magnitude. + return static_cast(std::floor(max_input_rescaled)); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h new file mode 100644 index 0000000000000000000000000000000000000000..efb7191c8deb2a23ea5473ab131d2b6537202765 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ +#define PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ + +#include + +namespace tflite { + +// Decompose a double multiplier into a Q0.31 int32 representation of its +// significand, and shift representation of its exponent. +// +// Restricted to the case where the multiplier < 1 (and non-negative). +void QuantizeMultiplierSmallerThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* right_shift); + +// Decompose a double multiplier into a Q0.31 int32 representation of its +// significand, and shift representation of its exponent. +// +// Restricted to the case where the multiplier > 1. +void QuantizeMultiplierGreaterThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift); + +// This first creates a multiplier in a double equivalent of +// Q(input_integer_bits).(31-input_integer_bits) representation, with extra +// precision in the double's fractional bits. It then splits the result into +// significand and exponent. +void PreprocessSoftmaxScaling(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, int* left_shift); + +// Calculate the largest input that will result in a within-bounds intermediate +// result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words, +// it must not overflow before we reduce the value by multiplication by the +// input multiplier. 
The negative radius is used as the minimum difference +// in Softmax. +int CalculateInputRadius(int input_integer_bits, int input_left_shift); + +} // namespace tflite + +#endif // PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d6f306e2cbae3c780b3d773638ba46cd2abf02f5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" + +#include +#include + +namespace tflite { +namespace { + +using ::testing::Pair; + +TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) { + auto quantize = [](double d) { + int32_t q; + int s; + QuantizeMultiplierSmallerThanOne(d, &q, &s); + return std::pair{q, s}; + }; + + EXPECT_DEATH(quantize(-0.1), ""); + EXPECT_THAT(quantize(0.0), Pair(0, 0)); + EXPECT_THAT(quantize(0.25), Pair(1073741824, 1)); + + // Around 0.5 we can see the change in exponent and how we try hard to + // avoid hitting max int32. 
+ EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, 1)); + EXPECT_THAT(quantize(0.50 - 1e-10), Pair(1073741824, 0)); + EXPECT_THAT(quantize(0.50), Pair(1073741824, 0)); + + EXPECT_THAT(quantize(0.75), Pair(1610612736, 0)); + EXPECT_THAT(quantize(1 - 1e-9), Pair(2147483646, 0)); + + // If we get close enough to 1.0 it crashes and dies in one of two ways: + // Either the shift becomes negative or we trigger the 'less-than-one' CHECK. + EXPECT_DEATH(quantize(1 - 1e-15), ""); + EXPECT_DEATH(quantize(1 - 1e-17), ""); + EXPECT_DEATH(quantize(1.0), ""); +} + +TEST(QuantizationUtilTest, QuantizeMultiplierGreaterThanOne) { + auto quantize = [](double d) { + int32_t q; + int s; + QuantizeMultiplierGreaterThanOne(d, &q, &s); + return std::pair{q, s}; + }; + + // If we are close enough to 1.0 it crashes. + EXPECT_DEATH(quantize(1 + 1e-16), ""); + + EXPECT_THAT(quantize(1 + 1e-11), Pair(1073741824, 1)); + EXPECT_THAT(quantize(1.25), Pair(1342177280, 1)); + EXPECT_THAT(quantize(1.50), Pair(1610612736, 1)); + EXPECT_THAT(quantize(1.75), Pair(1879048192, 1)); + + // Around the powers of two we see the change in exponent. Also, + // we try hard to avoid hitting max int32. + EXPECT_THAT(quantize(2 - 1e-9), Pair(2147483647, 1)); + EXPECT_THAT(quantize(2 - 1e-11), Pair(1073741824, 2)); + EXPECT_THAT(quantize(2), Pair(1073741824, 2)); +} + +TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) { + auto quantize = [](double beta, double scale, int integer_bits) { + int32_t q; + int s; + PreprocessSoftmaxScaling(beta, scale, integer_bits, &q, &s); + return std::pair{q, s}; + }; + + // If beta * scale is greater than fits in the number of integer bits, the + // result is moved near the maximum. Otherwise they quantize as expected. + // With 4 integer bits we can represent up to 16.0. + EXPECT_THAT(quantize(1.0, 16.0, 4), Pair(2147483647, 31)); + EXPECT_THAT(quantize(1.0, 8.0, 4), Pair(1073741824, 31)); + // But with 5 bits we can go further. 
+ EXPECT_THAT(quantize(2.0, 16.0, 5), Pair(2147483647, 31)); + EXPECT_THAT(quantize(2.0, 8.0, 5), Pair(1073741824, 31)); +} + +TEST(QuantizationUtilTest, CalculateInputRadius) { + EXPECT_EQ(CalculateInputRadius(4, 27), 15); + EXPECT_EQ(CalculateInputRadius(3, 27), 14); + EXPECT_EQ(CalculateInputRadius(3, 28), 7); + EXPECT_EQ(CalculateInputRadius(4, 2), 503316480); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h new file mode 100644 index 0000000000000000000000000000000000000000..8e0f234545e43dd8b2412e065aaecad8325a1182 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int ic = 0; ic < input_depth; ++ic) { + for (int m = 0; m < depth_multiplier; m++) { + const int oc = m + ic * depth_multiplier; + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + float total = 0.f; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) 
{ + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + float input_value = + input_data[Offset(input_dims, ic, in_x, in_y, b)]; + float filter_value = filter_data[Offset( + filter_dims, oc, filter_x, filter_y, 0)]; + total += (input_value * filter_value); + } + } + } + float bias_value = 0.0f; + if (bias_data) { + bias_value = bias_data[Offset(bias_dims, oc, 0, 0, 0)]; + } + output_data[Offset(output_dims, oc, out_x, out_y, b)] = + ActivationFunctionWithMinMax(total + bias_value, + output_activation_min, + output_activation_max); + } + } + } + } + } +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride_width, stride_height, pad_width, pad_height, + depth_multiplier, output_activation_min, output_activation_max, + output_data, output_dims); +} + +// Legacy, for compatibility with old checked-in code. 
+template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + float* output_data, const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, + depth_multiplier, output_data, output_dims); +} + +} // end namespace reference_ops +} // end namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h new file mode 100644 index 0000000000000000000000000000000000000000..8a80558b32f2858778460956cd9f57617674e21e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ + +#include + +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int ic = 0; ic < input_depth; ++ic) { + for (int m = 0; m < depth_multiplier; m++) { + const int oc = m + ic * depth_multiplier; + const int in_x_origin 
= (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + int32 acc = 0; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + int32 input_val = + input_data[Offset(input_dims, ic, in_x, in_y, b)]; + int32 filter_val = filter_data[Offset(filter_dims, oc, + filter_x, filter_y, 0)]; + acc += + (filter_val + filter_offset) * (input_val + input_offset); + } + } + } + if (bias_data) { + acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)]; + } + acc = MultiplyByQuantizedMultiplierSmallerThanOne( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_dims, oc, out_x, out_y, b)] = + static_cast(acc); + } + } + } + } + } +} + +// Legacy, for compatibility with old checked-in code. 
+template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, + stride_height, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +// Legacy, for compatibility with old checked-in code. 
+template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, input_offset, filter_data, + filter_dims, filter_offset, bias_data, bias_dims, stride, + stride, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +} // end namespace reference_ops +} // end namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..c5b0bccc9da5fa2ff9c3a9d430725b613435abf1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -0,0 +1,165 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace tensor_utils { + +float PortableClip(float f, float abs_limit) { + float result = (abs_limit < f) ? abs_limit : f; + result = (-abs_limit > result) ? -abs_limit : result; + return result; +} + +void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, + int m_rows, int m_cols, + const float* vector, + int n_batch, float* result, + int result_stride) { + float* result_in_batch = result; + for (int b = 0; b < n_batch; b++) { + const float* matrix_ptr = matrix; + for (int r = 0; r < m_rows; r++) { + const float* vector_in_batch = vector + b * m_cols; + for (int c = 0; c < m_cols; c++) { + *result_in_batch += *matrix_ptr++ * *vector_in_batch++; + } + result_in_batch += result_stride; + } + } +} + +void PortableVectorVectorCwiseProduct(const float* vector1, + const float* vector2, int v_size, + float* result) { + for (int v = 0; v < v_size; v++) { + *result++ = *vector1++ * *vector2++; + } +} + +float PortableVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + float result = 0.0; + for (int v = 0; v < v_size; v++) { + result += *vector1++ * *vector2++; + } + return result; +} + +void PortableBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + float* result_ptr = result; + const float* vector1_ptr = vector1; + const float* vector2_ptr = vector2; + for (int b = 0; b < n_batch; b++) { + *result_ptr = + PortableVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size); + vector1_ptr += v_size; + vector2_ptr += v_size; + result_ptr += result_stride; + } +} + +void PortableVectorVectorCwiseProductAccumulate(const float* vector1, + const float* 
vector2, + int v_size, float* result) { + for (int v = 0; v < v_size; v++) { + *result++ += *vector1++ * *vector2++; + } +} + +void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, + float* result) { + for (int b = 0; b < n_batch; b++) { + for (int v = 0; v < v_size; v++) { + *result++ += vector[v] * *batch_vector++; + } + } +} + +void PortableVectorBatchVectorAssign(const float* vector, int v_size, + int n_batch, float* batch_vector) { + for (int b = 0; b < n_batch; b++) { + memcpy(batch_vector + b * v_size, vector, v_size * sizeof(float)); + } +} + +void PortableApplySigmoidToVector(const float* vector, int v_size, + float* result) { + auto sigmoid_func = ActivationFunctor(kTfLiteActSigmoid); + for (int v = 0; v < v_size; v++) { + *result++ = (sigmoid_func)(*vector++); + } +} + +void PortableApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, + float* result) { + auto activation_func = ActivationFunctor(activation); + for (int v = 0; v < v_size; v++) { + *result++ = (activation_func)(*vector++); + } +} + +void PortableCopyVector(const float* vector, int v_size, float* result) { + memcpy(result, vector, v_size * sizeof(float)); +} + +void PortableSub1Vector(const float* vector, int v_size, float* result) { + for (int v = 0; v < v_size; v++) { + *result++ = 1.0f - *vector++; + } +} + +void PortableZeroVector(float* vector, int v_size) { + memset(vector, 0, v_size * sizeof(float)); +} + +void PortableClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + for (int v = 0; v < v_size; v++) { + *result++ = PortableClip(*vector++, abs_limit); + } +} + +void PortableVectorShiftLeft(float* vector, int v_size, float shift_value) { + TF_LITE_ASSERT(v_size > 0); + for (int i = 0; i < v_size - 1; i++) { + vector[i] = vector[i + 1]; + } + vector[v_size - 1] = shift_value; +} + +void PortableReductionSumVector(const float* 
input_vector, float* output_vector, + int output_size, int reduction_size) { + const float* input_vector_ptr = input_vector; + for (int o = 0; o < output_size; o++) { + for (int r = 0; r < reduction_size; r++) { + output_vector[o] += *input_vector_ptr++; + } + } +} + +} // namespace tensor_utils +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..c2ab78000b81485f037c507933cd024e70f39850 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -0,0 +1,189 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ + +// TODO(ghodrat): Remove this header file and the dependency to internal data +// structure. +#include "tensorflow/contrib/lite/builtin_op_data.h" + +namespace tflite { +namespace tensor_utils { + +// Limit a float input f between +abs_limit and -abs_limit. +float PortableClip(float f, float abs_limit); + +// Multiply a matrix by a batch vector, and accumulate results into a batch-size +// vector. 
+void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, + int m_rows, int m_cols, + const float* vector, + int n_batch, float* result, + int result_stride); + +// Cwise product of two vectors. +void PortableVectorVectorCwiseProduct(const float* vector1, + const float* vector2, int v_size, + float* result); + +// Cwise product and accumulate of two vectors. Since it's a MAC operation, the +// assumption here is that result array is initialized to valid values. +void PortableVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, + int v_size, float* result); + +// Dot product of two vectors. +float PortableVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); + +// Dot product of two batch vectors: one dot product per batch, with +// consecutive results written result_stride elements apart. +void PortableBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); + +// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC +// operation, the assumption here is that result array is initialized to valid +// values. +void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, + float* result); + +// Batch vector initialization with another vector: vector is copied into each +// of the n_batch consecutive v_size-element slots of batch_vector. +void PortableVectorBatchVectorAssign(const float* vector, int v_size, + int n_batch, float* batch_vector); + +// Apply sigmoid to elements of a vector. +void PortableApplySigmoidToVector(const float* vector, int v_size, + float* result); + +// Apply activation function to elements of a vector. +void PortableApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, + float* result); + +// Copy vector to another vector. +void PortableCopyVector(const float* vector, int v_size, float* result); + +// Compute "1.0f - elements of vector" (used in CIFG). 
+void PortableSub1Vector(const float* vector, int v_size, float* result); + +// Fill vector with 0.f. +void PortableZeroVector(float* vector, int v_size); + +// Clip elements of a vector using a abs_limit value. +void PortableClipVector(const float* vector, int v_size, float abs_limit, + float* result); + +// Shift left a vector in place with v_size size. +void PortableVectorShiftLeft(float* vector, int v_size, float shift_value); + +// Reduce-sum on a float input vector: +// input_vector: float pointer to input vector. +// output_vector: float pointer to vector. +// output_size: output vector size. +// reduction_size: number of consecutive elements from input vector which are +// added to get one element of output. +void PortableReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); + +float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } + +void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride) { + PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector, + n_batch, result, result_stride); +} + +void VectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result) { + PortableVectorVectorCwiseProduct(vector1, vector2, v_size, result); +} + +void VectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result) { + PortableVectorVectorCwiseProductAccumulate(vector1, vector2, v_size, result); +} + +void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, + const float* batch_vector, + int n_batch, float* result) { + PortableVectorBatchVectorCwiseProductAccumulate(vector, v_size, batch_vector, + n_batch, result); +} + +float VectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + return PortableVectorVectorDotProduct(vector1, vector2, v_size); +} 
+ +void BatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + PortableBatchVectorBatchVectorDotProduct(vector1, vector2, v_size, n_batch, + result, result_stride); +} + +void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, + float* batch_vector) { + PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector); +} + +void ApplySigmoidToVector(const float* vector, int v_size, float* result) { + PortableApplySigmoidToVector(vector, v_size, result); +} + +void ApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, float* result) { + PortableApplyActivationToVector(vector, v_size, activation, result); +} + +void CopyVector(const float* vector, int v_size, float* result) { + PortableCopyVector(vector, v_size, result); +} + +void Sub1Vector(const float* vector, int v_size, float* result) { + PortableSub1Vector(vector, v_size, result); +} + +void ZeroVector(float* vector, int v_size) { + PortableZeroVector(vector, v_size); +} + +void ClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + PortableClipVector(vector, v_size, abs_limit, result); +} + +void VectorShiftLeft(float* vector, int v_size, float shift_value) { + PortableVectorShiftLeft(vector, v_size, shift_value); +} + +void ReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + PortableReductionSumVector(input_vector, output_vector, output_size, + reduction_size); +} + +} // namespace tensor_utils +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..b9ca3d5c626dff4ea8ba52949e8fea8e9b43689f 
--- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -0,0 +1,2455 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/round.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( + int32 x, int32 quantized_multiplier, int right_shift) { + using gemmlowp::RoundingDivideByPOT; + using gemmlowp::SaturatingRoundingDoublingHighMul; + return RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); +} + +inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( + int32 x, int32 quantized_multiplier, int left_shift) { + using gemmlowp::SaturatingRoundingDoublingHighMul; + return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), + quantized_multiplier); +} + +// Returns the number of leading zero bits in an unsigned integer. A zero input +// has no set bit, so the result is then the full bit width of T. +template <typename T> +int 
CountLeadingZeros(T integer_input) { + static_assert(std::is_unsigned<T>::value, + "Only unsigned integer types handled."); + // Shifting zero left never reaches the top bit, so the loop below would + // spin forever; handle that input up front. + if (integer_input == 0) { + return std::numeric_limits<T>::digits; + } + const T one_in_leading_positive = static_cast<T>(1) + << (std::numeric_limits<T>::digits - 1); + int leading_zeros = 0; + while (integer_input < one_in_leading_positive) { + integer_input <<= 1; + ++leading_zeros; + } + return leading_zeros; +} + +// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE +// BROADCASTING. +// +// NdArrayDesc describes the shape and memory layout of an N-dimensional +// rectangular array of numbers. +// +// NdArrayDesc is basically identical to Dims defined in types.h. +// However, as Dims is to be deprecated, this class exists as an adaptor +// to enable simple unoptimized implementations of element-wise broadcasting +// operations. +template <int N> +struct NdArrayDesc { + // The "extent" of each dimension. Indices along dimension d must be in the + // half-open interval [0, extents[d]). + int extents[N]; + + // The number of *elements* (not bytes) between consecutive indices of each + // dimension. + int strides[N]; +}; + +// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING +// ELEMENT-WISE BROADCASTING. +// +// Same as Offset(), except takes an NdArrayDesc instead of Dims. +inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, + int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); + TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); + TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); + TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); + return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + + i3 * desc.strides[3]; +} + +// Given the dimensions of the operands for an element-wise binary broadcast, +// adjusts them so that they can be directly iterated over with simple loops. +// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and +// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. 
+// +// This function assumes that the two input shapes are compatible up to +// broadcasting and the shorter one has already been prepended with 1s to be the +// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), +// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that +// Dims refer to shapes in reverse order. In this case, input0_dims will be +// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). +// +// When two shapes are compatible up to broadcasting, for each dimension d, +// the input extents are either equal, or one of them is 1. +// +// This function performs the following for each dimension d: +// - If the extents are equal, then do nothing since the loop that walks over +// both of the input arrays is correct. +// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1 +// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows +// array0 to be referenced *at any index* in dimension d and still access the +// same slice. +template +inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, + const Dims& input1_dims, + NdArrayDesc* desc0_out, + NdArrayDesc* desc1_out) { + TFLITE_DCHECK(desc0_out != nullptr); + TFLITE_DCHECK(desc1_out != nullptr); + + // Copy dims to desc. + for (int i = 0; i < N; ++i) { + desc0_out->extents[i] = input0_dims.sizes[i]; + desc0_out->strides[i] = input0_dims.strides[i]; + desc1_out->extents[i] = input1_dims.sizes[i]; + desc1_out->strides[i] = input1_dims.strides[i]; + } + + // Walk over each dimension. If the extents are equal do nothing. + // Otherwise, set the desc with extent 1 to have extent equal to the other and + // stride 0. 
+ for (int i = 0; i < N; ++i) { + const int extent0 = ArraySize(input0_dims, i); + const int extent1 = ArraySize(input1_dims, i); + if (extent0 != extent1) { + if (extent0 == 1) { + desc0_out->strides[i] = 0; + desc0_out->extents[i] = extent1; + } else { + TFLITE_DCHECK_EQ(extent1, 1); + desc1_out->strides[i] = 0; + desc1_out->extents[i] = extent0; + } + } + } +} + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + (void)im2col_data; // only used in optimized code. + (void)im2col_dims; // only used in optimized code. + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); + if (bias_data) { + TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0)); + } + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + float total = 0.f; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 
0; filter_x < filter_width; ++filter_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + float input_value = input_data[Offset(input_dims, in_channel, + in_x, in_y, batch)]; + float filter_value = + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; + total += (input_value * filter_value); + } + } + } + } + float bias_value = 0.0f; + if (bias_data) { + bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; + } + output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(total + bias_value, + output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride_width, + int stride_height, int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, + stride_width, stride_height, pad_width, pad_height, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int 
pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, output_data, + output_dims, im2col_data, im2col_dims); +} + +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + (void)im2col_data; // only used in optimized code. + (void)im2col_dims; // only used in optimized code. + (void)gemm_context; // only used in optimized code. 
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = + MatchingArraySize(filter_dims, 3, bias_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + int32 acc = 0; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. 
+ if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + int32 input_val = input_data[Offset(input_dims, in_channel, + in_x, in_y, batch)]; + int32 filter_val = + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; + acc += + (filter_val + filter_offset) * (input_val + input_offset); + } + } + } + } + if (bias_data) { + acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; + } + acc = MultiplyByQuantizedMultiplierSmallerThanOne( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = + static_cast(acc); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, stride_height, + pad_width, pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, 
output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) { + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, stride, pad_width, + pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims, im2col_data, im2col_dims, gemm_context); +} + +template +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int input_batch = ArraySize(input_dims, 3); + + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_batch = ArraySize(output_dims, 3); + + TFLITE_DCHECK_EQ(input_width * block_size, output_width); + TFLITE_DCHECK_EQ(input_height * block_size, output_height); + TFLITE_DCHECK_EQ(input_depth, output_depth * block_size * block_size); + TFLITE_DCHECK_EQ(input_batch, output_batch); + + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + 
for (int out_d = 0; out_d < output_depth; ++out_d) { + const int in_d = + out_d + ((out_h % block_size) * block_size + out_w % block_size) * + output_depth; + const int in_w = out_w / block_size; + const int in_h = out_h / block_size; + const int in_b = out_b; + + const int output_index = + Offset(output_dims, out_d, out_w, out_h, out_b); + const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + + output_data[output_index] = input_data[input_index]; + } + } + } + } +} + +template +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int input_batch = ArraySize(input_dims, 3); + + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_batch = ArraySize(output_dims, 3); + + TFLITE_DCHECK_EQ(input_width, output_width * block_size); + TFLITE_DCHECK_EQ(input_height, output_height * block_size); + TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth); + TFLITE_DCHECK_EQ(input_batch, output_batch); + + for (int in_b = 0; in_b < input_batch; ++in_b) { + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + for (int in_d = 0; in_d < input_depth; ++in_d) { + const int out_d = + in_d + ((in_h % block_size) * block_size + in_w % block_size) * + input_depth; + const int out_w = in_w / block_size; + const int out_h = in_h / block_size; + const int out_b = in_b; + + const int output_index = + Offset(output_dims, out_d, out_w, out_h, out_b); + const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + + output_data[output_index] = input_data[input_index]; + } + } + } + } +} + +inline void FullyConnected(const float* 
input_data, const Dims<4>& input_dims, + const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); + const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0); + const int accum_depth = ArraySize(weights_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims)); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + float total = 0.f; + for (int d = 0; d < accum_depth; ++d) { + total += input_data[b * accum_depth + d] * + weights_data[out_c * accum_depth + d]; + } + float bias_value = 0.0f; + if (bias_data) { + bias_value = bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; + } + output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax( + total + bias_value, output_activation_min, output_activation_max); + } + } +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, const Dims<4>& weights_dims, + const float* bias_data, const Dims<4>& bias_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data, + bias_dims, output_activation_min, 
output_activation_max, + output_data, output_dims); +} + +inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + (void)gemm_context; // only used in optimized code. + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0); + const int accum_depth = ArraySize(filter_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + int32 acc = 0; + for (int d = 0; d < accum_depth; ++d) { + int32 input_val = input_data[b * accum_depth + d]; + int32 filter_val = filter_data[out_c * accum_depth + d]; + acc += (filter_val + filter_offset) * (input_val + input_offset); + } + if (bias_data) { + acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; + } + acc = MultiplyByQuantizedMultiplierSmallerThanOne(acc, output_multiplier, + output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[out_c + output_depth * b] = static_cast(acc); + } + } +} + +// legacy, for 
compatibility with old checked-in code +template +void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims, gemm_context); +} + +template +void NonGlobalBatchNormalization( + const float* input_data, const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, const float* multiplier_data, + const Dims<4>& multiplier_dims, const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2, + offset_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1, + offset_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for 
(int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, x, y, 0)]) * + multiplier_data[Offset(multiplier_dims, c, x, y, 0)] + + offset_data[Offset(offset_dims, c, x, y, 0)]); + } + } + } + } +} + +template +void GlobalBatchNormalization(const float* input_data, + const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, + const float* multiplier_data, + const Dims<4>& multiplier_dims, + const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, 0, 0, 0)]) * + multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] + + offset_data[Offset(offset_dims, c, 0, 0, 0)]); + } + } + } + } +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = 
input_data[Offset(input_dims, c, x, y, b)]; + const float lower = 0; + float clamped = val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +inline void Relu1(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 1; + const float lower = -1; + float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +inline void Relu6(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 6; + const float lower = 0; + float clamped = val > upper ? upper : val < lower ? 
lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone, ""); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + float squared_l2_norm = 0; + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + squared_l2_norm += val * val; + } + float l2_norm = std::sqrt(squared_l2_norm); + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input_data[Offset(input_dims, c, x, y, b)] / l2_norm; + } + } + } + } +} + +inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, + int* output_shift) { + *output_shift = 11; + while (input >= (1 << 29)) { + input /= 4; + ++*output_shift; + } + TFLITE_DCHECK_GT(input, 0); + const unsigned max_left_shift_bits = __builtin_clz(input) - 1; + const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; + const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; + *output_shift -= left_shift_bit_pairs; + input <<= 2 * left_shift_bit_pairs; + TFLITE_DCHECK_GE(input, (1 << 27)); + TFLITE_DCHECK_LT(input, (1 << 29)); + using gemmlowp::FixedPoint; + using gemmlowp::Rescale; + using gemmlowp::SaturatingRoundingMultiplyByPOT; + // Using 3 integer bits gives us enough room for the internal arithmetic in + // this Newton-Raphson iteration. 
+ using F3 = FixedPoint; + using F0 = FixedPoint; + const F3 fixedpoint_input = F3::FromRaw(input >> 1); + const F3 fixedpoint_half_input = + SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input); + const F3 fixedpoint_half_three = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5); + // Newton-Raphson iteration + // Naive unoptimized starting guess: x = 1 + F3 x = F3::One(); + // Naive unoptimized number of iterations: 5 + for (int i = 0; i < 5; i++) { + const F3 x3 = Rescale<3>(x * x * x); + x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3); + } + const F0 fixedpoint_half_sqrt_2 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.); + x = x * fixedpoint_half_sqrt_2; + *output_inv_sqrt = x.raw(); + if (*output_shift < 0) { + *output_inv_sqrt <<= -*output_shift; + *output_shift = 0; + } +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + TFLITE_DCHECK_EQ(batches, 1); + TFLITE_DCHECK_EQ(height, 1); + TFLITE_DCHECK_EQ(width, 1); + int32 square_l2_norm = 0; + for (int i = 0; i < depth; i++) { + int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int i = 0; i < depth; i++) { + int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 unclamped_output_val 
= 128 + rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + output_data[Offset(output_dims, i, 0, 0, 0)] = + static_cast(output_val); + } +} + +inline void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] + + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +template +inline void Add(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, int input2_shift, + int32 output_offset, int32 
output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const int32 input1_val = + input1_offset + input1_data[Offset(input1_dims, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[Offset(input2_dims, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + 
output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +template +void BroadcastAdd(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. 
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] + + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. 
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +template +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == 
FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset, + input1_multiplier, input1_shift, input2_data, input2_dims, + input2_offset, input2_multiplier, input2_shift, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] * + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// TODO(jiawen): We can implement BroadcastMul on 
buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +template +void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest + // stride, typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for + // the best cache behavior. 
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] * + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest + // stride, typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for + // the best cache behavior. 
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 unclamped_result = + output_offset + + MultiplyByQuantizedMultiplierSmallerThanOne( + input1_val * input2_val, output_multiplier, output_shift); + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, unclamped_result)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, + input2_dims, input2_offset, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_data, output_dims); +} + +template +void Concatenation(int concat_dim, const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + TFLITE_DCHECK_GT(inputs_count, 1); + int concat_size = 0; + for (int i = 0; i < inputs_count; i++) { + for (int j = 0; j < 4; j++) { + if (j != concat_dim) { + MatchingArraySize(*input_dims[i], j, output_dims, j); + } + } + concat_size += ArraySize(*input_dims[i], concat_dim); + } + TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); + 
TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + int outer_size = 1; + for (int i = concat_dim + 1; i < 4; i++) { + outer_size *= output_dims.sizes[i]; + } + Scalar* output_ptr = output_data; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < inputs_count; ++i) { + const int copy_size = + input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; + memcpy(output_ptr, input_data[i] + k * copy_size, + copy_size * sizeof(Scalar)); + output_ptr += copy_size; + } + } +} + +template +void DepthConcatenation(const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + Concatenation(0, input_data, input_dims, inputs_count, + output_data, output_dims); +} + +inline void LstmCell(const float* input_data, const Dims<4>& input_dims, + const float* prev_activ_data, + const Dims<4>& prev_activ_dims, const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, const float* prev_state_data, + const Dims<4>& prev_state_dims, float* output_state_data, + const Dims<4>& output_state_dims, float* output_activ_data, + const Dims<4>& output_activ_dims, float* concat_temp_data, + const Dims<4>& concat_temp_dims, float* activ_temp_data, + const Dims<4>& activ_temp_dims) { + const int batches = + MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, + output_state_dims, 3, output_activ_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, + output_state_dims, 2, output_activ_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, + output_state_dims, 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int 
total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + + // Concatenate prev_activ and input data together + std::vector concat_input_arrays_data; + std::vector const*> concat_input_arrays_dims; + concat_input_arrays_data.push_back(input_data); + concat_input_arrays_data.push_back(prev_activ_data); + concat_input_arrays_dims.push_back(&input_dims); + concat_input_arrays_dims.push_back(&prev_activ_dims); + Concatenation( + 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]), + concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims); + + // Fully connected + FullyConnected( + concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data, + bias_dims, activ_temp_data, activ_temp_dims); + + // Memory state update (the LSTM "guts") + for (int b = 0; b < batches; ++b) { + for (int w = 0; w < width; ++w) { + for (int h = 0; h < height; ++h) { + for (int c = 0; c < output_depth; ++c) { + const float input_gate = + 1.f / + (1.f + std::exp(-activ_temp_data[Offset( + activ_temp_dims, 0 * output_depth + c, w, h, b)])); + const float new_input = std::tanh(activ_temp_data[Offset( + activ_temp_dims, 1 * output_depth + c, w, h, b)]); + const float forget_gate = + 1.f / + (1.f + std::exp(-activ_temp_data[Offset( + activ_temp_dims, 2 * output_depth + c, w, h, b)])); + const float output_gate = + 1.f / + (1.f + std::exp(-activ_temp_data[Offset( + activ_temp_dims, 3 * output_depth + c, w, h, b)])); + const float new_state = + input_gate * new_input + + forget_gate * + 
prev_state_data[Offset(prev_state_dims, c, w, h, b)]; + output_state_data[Offset(output_state_dims, c, w, h, b)] = new_state; + output_activ_data[Offset(output_activ_dims, c, w, h, b)] = + output_gate * std::tanh(new_state); + } + } + } + } +} + +template +void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, + int outputs_count, Scalar* const* output_data, + const Dims<4>* const* output_dims) { + TFLITE_DCHECK_GE(outputs_count, 1); + for (int i = 0; i < outputs_count; i++) { + /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3); + /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2); + /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1); + } + const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3); + const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2); + const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1); + // for now we dont have a model with a TensorFlowSplit + // with fused activation function. + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + int in_c = 0; + for (int i = 0; i < outputs_count; ++i) { + const int depth = ArraySize(*output_dims[i], 0); + for (int c = 0; c < depth; ++c) { + output_data[i][Offset(*output_dims[i], c, x, y, b)] = + input_data[Offset(input_dims, in_c, x, y, b)]; + in_c++; + } + } + TFLITE_DCHECK(in_c == ArraySize(input_dims, 0)); + } + } + } +} + +// TODO(benoitjacob) make this a proper reference impl without Eigen! 
+template +using MatrixMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, Eigen::Dynamic>>, + Eigen::Map>>::type; + +template +MatrixMap MapAsMatrixWithFirstDimAsRows(Scalar* data, + const Dims& dims) { + const int rows = dims.sizes[0]; + int cols = 1; + for (int d = 1; d < N; d++) { + cols *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +template +MatrixMap MapAsMatrixWithLastDimAsCols(Scalar* data, + const Dims& dims) { + const int cols = dims.sizes[N - 1]; + int rows = 1; + for (int d = 0; d < N - 1; d++) { + rows *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +inline int NodeOffset(int b, int h, int w, int height, int width) { + return (b * height + h) * width + w; +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. 
+ const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + float total = 0.f; + float filter_count = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + total += + input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + filter_count++; + } + } + const float average = total / filter_count; + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(average, output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& 
input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. 
+ const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + int32 acc = 0; + int filter_count = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + acc += input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + filter_count++; + } + } + acc = (acc + filter_count / 2) / filter_count; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + static_cast(acc); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int 
filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. 
+ const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + float sum_squares = 0.f; + int filter_count = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + const float val = + input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + sum_squares += val * val; + filter_count++; + } + } + const float l2pool_result = std::sqrt(sum_squares / filter_count); + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(l2pool_result, output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void 
MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. 
+ const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + float max = std::numeric_limits::lowest(); + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + max = std::max( + max, + input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + } + } + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(max, output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int 
filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LE(output_activation_max, 255); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. 
+ const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + uint8 max = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + max = std::max( + max, + input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + } + } + max = std::max(max, output_activation_min); + max = std::min(max, output_activation_max); + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + static_cast(max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 
output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void LocalResponseNormalization(const float* input_data, + const Dims<4>& input_dims, int range, + float bias, float alpha, float beta, + float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const int begin_input_c = std::max(0, c - range); + const int end_input_c = std::min(depth, c + range); + float accum = 0.f; + for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) { + const float input_val = + input_data[Offset(input_dims, input_c, x, y, b)]; + accum += input_val * input_val; + } + const float multiplier = std::pow(bias + alpha * accum, -beta); + output_data[Offset(output_dims, c, x, y, b)] = + input_data[Offset(input_dims, c, x, y, b)] * multiplier; + } + } + } + } +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + // Find max element value which we'll use to ensure numerical 
stability + // taking advantage of the following equality: + // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C)) + float max = std::numeric_limits::lowest(); + for (int c = 0; c < depth; ++c) { + max = std::max(max, input_data[Offset(input_dims, c, x, y, b)]); + } + + // Compute sum. + float sum = 0.f; + for (int c = 0; c < depth; ++c) { + sum += std::exp((input_data[Offset(input_dims, c, x, y, b)] - max) * + beta); + } + + // Compute result. + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + std::exp((input_data[Offset(input_dims, c, x, y, b)] - max) * + beta) / + sum; + } + } + } + } +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + // The representation chosen for the input to the exp() function is Q5.26. + // We need to leave extra space since values that we skip might be as large as + // -32 before multiplying by input_beta_multiplier, and therefore as large as + // -16 afterwards. Note that exp(-8) is definitely not insignificant to + // accumulation, but exp(-16) definitely is. 
+ static const int kScaledDiffIntegerBits = 5; + static const int kAccumulationIntegerBits = 12; + using FixedPointScaledDiff = + gemmlowp::FixedPoint; + using FixedPointAccum = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int x = 0; x < width; ++x) { + for (int y = 0; y < height; ++y) { + uint8 max_in_row = 0; + for (int c = 0; c < depth; ++c) { + max_in_row = + std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]); + } + + FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + sum_of_exps = + sum_of_exps + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_f8)); + } + } + + int32 fixed_sum_of_exps = sum_of_exps.raw(); + int headroom_plus_one = + CountLeadingZeros(static_cast(fixed_sum_of_exps)); + // This is the number of bits to the left of the binary point above 1.0. + // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and + // no later adjustment will be needed. 
+ int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one; + int32 shifted_sum_minus_one = static_cast( + (static_cast(fixed_sum_of_exps) << headroom_plus_one) - + (static_cast(1) << 31)); + + FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1( + FixedPoint0::FromRaw(shifted_sum_minus_one)); + + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + + FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); + int32 unsat_output = gemmlowp::RoundingDivideByPOT( + (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); + + output_data[Offset(output_dims, c, x, y, b)] = static_cast( + std::max(std::min(unsat_output, static_cast(255)), 0)); + + } else { + output_data[Offset(output_dims, c, x, y, b)] = 0; + } + } + } + } + } +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + float result = 1.f / (1.f + std::exp(-val)); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, 
+ int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const uint8 input_val_u8 = input_data[Offset(input_dims, c, x, y, b)]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered <= -input_range_radius) { + output_val = 0; + } else if (input_val_centered >= input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = + FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4); + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23); + if (output_val_s32 == 256) { + output_val_s32 = 255; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + output_data[Offset(output_dims, c, x, y, b)] = output_val; + } + } + } + } +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = 
MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + float result = std::tanh(val); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, + int32 zero_point, double scale, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int32 val = input_data[Offset(input_dims, c, x, y, b)]; + float result = static_cast(scale * (val - zero_point)); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, + float rmin, float rmax, float* output_data, + const Dims<4>& output_dims) { + // 0 should always be a representable value. Let's assume that the initial + // min,max range contains 0. + TFLITE_DCHECK_LE(rmin, 0.); + TFLITE_DCHECK_GE(rmax, 0.); + + // Determine quantization parameters: zero_point, scale. + using Integer = uint8; + const Integer qmin = std::numeric_limits::min(); + const Integer qmax = std::numeric_limits::max(); + const float qmin_float = qmin; + const float qmax_float = qmax; + int32 zero_point = 0; + float scale = 0.f; + // If rmin==rmax, both must be zero per the above assertion, + // so we are done. + if (rmin != rmax) { + // First determine the scale. + scale = (rmax - rmin) / (qmax_float - qmin_float); + + // Zero-point computation. 
+ // First the initial floating-point computation. The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). + // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. + const float zero_point_from_min = qmin_float - rmin / scale; + const float zero_point_from_max = qmax_float - rmax / scale; + const float zero_point_from_min_error = + std::abs(qmin_float) + std::abs(rmin / scale); + const float zero_point_from_max_error = + std::abs(qmax_float) + std::abs(rmax / scale); + + const float zero_point_float = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with SAME + // padding). + if (zero_point_float < qmin_float) { + zero_point = qmin; + } else if (zero_point_float > qmax_float) { + zero_point = qmax; + } else { + zero_point = static_cast(TfLiteRound(zero_point_float)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. 
+ TFLITE_DCHECK_GE(zero_point, qmin); + TFLITE_DCHECK_LE(zero_point, qmax); + } + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const float src_val = input_data[Offset(input_dims, c, x, y, b)]; + const float unclamped_quantized_val = + TfLiteRound(zero_point + src_val / scale); + const float quantized_val = std::min( + qmax_float, std::max(qmin_float, unclamped_quantized_val)); + const float dst_val = scale * (quantized_val - zero_point); + output_data[Offset(output_dims, c, x, y, b)] = dst_val; + } + } + } + } +} + +template +inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, + DstT* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int offset = Offset(input_dims, c, x, y, b); + output_data[offset] = static_cast(input_data[offset]); + } + } + } + } +} + +inline void Floor(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); 
+ for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int offset = Offset(input_dims, c, x, y, b); + output_data[offset] = std::floor(input_data[offset]); + } + } + } + } +} + +template +inline void Gather(const T* input_data, const Dims<4>& input_dims, + int input_rank, const int32* coords_data, + const Dims<4>& coords_dims, T* output_data, + const Dims<4>& output_dims) { + TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); + int stride = input_dims.strides[input_rank - 1]; + T* out = output_data; + + for (int i = 0; i < coords_dims.sizes[0]; i++) { + TFLITE_DCHECK_GE(coords_data[i], 0); + TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); + const T* in = input_data + coords_data[i] * stride; + memcpy(out, in, sizeof(T) * stride); + out += stride; + } +} + +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + int32 input_height = ArraySize(input_dims, 2); + int32 input_width = ArraySize(input_dims, 1); + int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); + int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + float height_scale = static_cast(input_height) / output_height; + float width_scale = static_cast(input_width) / output_width; + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + float input_y = y * height_scale; + int32 y0 = 
static_cast(std::floor(input_y)); + int32 y1 = std::min(y0 + 1, input_height - 1); + for (int x = 0; x < output_width; ++x) { + float input_x = x * width_scale; + int32 x0 = static_cast(std::floor(input_x)); + int32 x1 = std::min(x0 + 1, input_width - 1); + for (int c = 0; c < depth; ++c) { + float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] * + (1 - (input_y - y0)) * + (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x0, y1, b)] * + (input_y - y0) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x1, y0, b)] * + (1 - (input_y - y0)) * (input_x - x0) + + input_data[Offset(input_dims, c, x1, y1, b)] * + (input_y - y0) * (input_x - x0); + output_data[Offset(output_dims, c, x, y, b)] = interpolation; + } + } + } + } +} + +template +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + const int output_batch_size = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_height = block_shape_data[0]; + const int block_shape_width = block_shape_data[1]; + const int padding_top = paddings_data[0]; + const int padding_left = paddings_data[2]; + + for (int out_b = 0; out_b < output_batch_size; ++out_b) { + int input_batch = out_b % input_batch_size; + int shift_w = (out_b / input_batch_size) % block_shape_width; + int shift_h = (out_b / input_batch_size) / block_shape_width; + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + T* out = output_data + Offset(output_dims, 0, 
out_w, out_h, out_b); + if (out_h * block_shape_height < padding_top || + out_h * block_shape_height >= padding_top + input_height || + out_w * block_shape_width < padding_left || + out_w * block_shape_width >= padding_left + input_width) { + memset(out, 0, depth * sizeof(T)); + } else { + const T* in = + input_data + + Offset(input_dims, 0, + (out_w * block_shape_width + shift_w) - padding_left, + (out_h * block_shape_height + shift_h) - padding_top, + input_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } + } +} + +template +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, T* output_data, + const Dims<4>& output_dims) { + const int output_batch_size = ArraySize(output_dims, 3); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_width = block_shape_data[1]; + const int block_shape_height = block_shape_data[0]; + + for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + int out_batch = in_batch % output_batch_size; + int out_w = in_w * block_shape_width + + (in_batch / output_batch_size) % block_shape_width; + int out_h = in_h * block_shape_height + + (in_batch / output_batch_size) / block_shape_width; + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); + const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } +} + +template +inline void Pad(const T* input_data, const Dims<4>& input_dims, + const std::vector& left_paddings, + const std::vector& right_paddings, T* output_data, + const Dims<4>& output_dims) { + const int output_batch = ArraySize(output_dims, 3); + const int output_height = 
ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int left_b_padding = left_paddings[3]; + const int left_h_padding = left_paddings[2]; + const int left_w_padding = left_paddings[1]; + const int left_d_padding = left_paddings[0]; + + const int right_b_padding = right_paddings[3]; + const int right_h_padding = right_paddings[2]; + const int right_w_padding = right_paddings[1]; + const int right_d_padding = right_paddings[0]; + + const T* in_ptr = input_data; + T* out_ptr = output_data; + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + if (out_b < left_b_padding || + out_b >= output_batch - right_b_padding || + out_h < left_h_padding || + out_h >= output_height - right_h_padding || + out_w < left_w_padding || + out_w >= output_width - right_w_padding || + out_d < left_d_padding || + out_d >= output_depth - right_d_padding) { + *out_ptr++ = 0; + } else { + *out_ptr++ = *in_ptr++; + } + } + } + } + } +} + +template +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, + const std::vector& starts, + const std::vector& stops, + const std::vector& strides, T* output_data, + const Dims<4>& output_dims) { + const int start_b = (begin_mask & 8) ? 0 : starts[3]; + const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3]; + const int start_h = (begin_mask & 4) ? 0 : starts[2]; + const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2]; + const int start_w = (begin_mask & 2) ? 0 : starts[1]; + const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1]; + const int start_d = (begin_mask & 1) ? 0 : starts[0]; + const int stop_d = (end_mask & 1) ? 
input_dims.sizes[0] : stops[0]; + + T* out_ptr = output_data; + for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { + for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { + for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { + for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) { + *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + } + } + } + } +} + +template +inline void Slice(const T* input_data, const Dims<4>& input_dims, + const std::vector& begin, const std::vector& size, + T* output_data, const Dims<4>& output_dims) { + // Copies the sub-box of the input that starts at begin[i] and spans size[i] elements along each dimension i (indexed like Dims<4>: 0 = depth, 1 = width, 2 = height, 3 = batch); size[i] == -1 means "through the end of dimension i". TODO(dkalenichenko): This op only supports 4D tensors. + TFLITE_DCHECK_EQ(begin.size(), 4); + TFLITE_DCHECK_EQ(size.size(), 4); + const int start_b = begin[3]; + const int stop_b = + size[3] == -1 ? input_dims.sizes[3] : start_b + size[3]; + const int start_h = begin[2]; + const int stop_h = + size[2] == -1 ? input_dims.sizes[2] : start_h + size[2]; + const int start_w = begin[1]; + const int stop_w = + size[1] == -1 ? input_dims.sizes[1] : start_w + size[1]; + const int start_d = begin[0]; + const int stop_d = + size[0] == -1 ? 
input_dims.sizes[0] : start_d + size[0]; + + T* out_ptr = output_data; + for (int in_b = start_b; in_b < stop_b; ++in_b) { + for (int in_h = start_h; in_h < stop_h; ++in_h) { + for (int in_w = start_w; in_w < stop_w; ++in_w) { + for (int in_d = start_d; in_d < stop_d; ++in_d) { + *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + } + } + } + } +} + +template +inline void Mean(const T* input_data, const Dims<4>& input_dims, + const std::vector& reduction_indices, T* output_data, + const Dims<4>& output_dims) { + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + + // The current implementation only supports simultaneous reduction over + // width and height. + TFLITE_DCHECK_EQ(reduction_indices.size(), 2); + TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) || + (reduction_indices[0] == 2 && reduction_indices[1] == 1)); + TFLITE_DCHECK_EQ(output_height, 1); + TFLITE_DCHECK_EQ(output_width, 1); + + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + float value = 0; + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)]; + } + } + output_data[Offset(output_dims, out_d, 0, 0, out_b)] = + value / (input_width * input_height); + } + } +} + +template +void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, + const Dims<4>& input2_dims, T* output_data, + const Dims<4>& output_dims) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are 
canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input1_data[SubscriptToIndex(desc1, c, x, y, b)] - + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + } + } + } + } +} + +template +void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + int batches = MatchingArraySize(input1_dims, 3, output_dims, 3); + int input_height = MatchingArraySize(input1_dims, 2, output_dims, 2); + int input_width = MatchingArraySize(input1_dims, 1, output_dims, 1); + int depth = MatchingArraySize(input1_dims, 0, output_dims, 0); + + auto min_value = input2_data[0]; + + for (int b = 0; b < batches; b++) { + for (int y = 0; y < input_height; y++) { + for (int x = 0; x < input_width; x++) { + for (int c = 0; c < depth; c++) { + int offset = Offset(input1_dims, c, x, y, b); + output_data[offset] = + input1_data[offset] > min_value ? 
min_value : input1_data[offset]; + } + } + } + } +} + +template +void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + int batches = MatchingArraySize(input1_dims, 3, output_dims, 3); + int input_height = MatchingArraySize(input1_dims, 2, output_dims, 2); + int input_width = MatchingArraySize(input1_dims, 1, output_dims, 1); + int depth = MatchingArraySize(input1_dims, 0, output_dims, 0); + + auto max_value = input2_data[0]; + + for (int b = 0; b < batches; b++) { + for (int y = 0; y < input_height; y++) { + for (int x = 0; x < input_width; x++) { + for (int c = 0; c < depth; c++) { + int offset = Offset(input1_dims, c, x, y, b); + output_data[offset] = + input1_data[offset] < max_value ? max_value : input1_data[offset]; + } + } + } + } +} + +} // namespace reference_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/round.h b/tensorflow/contrib/lite/kernels/internal/round.h new file mode 100644 index 0000000000000000000000000000000000000000..38525b0e208b852343849096ac68cbfc9ef3e389 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/round.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ + +#include + +namespace tflite { + +// TODO(aselle): See if we can do this only on jdk. Also mikecase, check +// if you need this for java host build. +#if defined(__ANDROID__) && !defined(__NDK_MAJOR__) +template +inline float TfLiteRound(const float x) { + return ::round(x); +} +inline double TfLiteRound(const double x) { return ::round(x); } +#else +template +inline T TfLiteRound(const T x) { + return std::round(x); +} +#endif + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..ee4111e0416560d94d513c528971bdf3bf819662 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ + +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +template +inline T* GetTensorData(TfLiteTensor* tensor); + +template <> +inline float* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.f : nullptr; +} + +template <> +inline uint8_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.uint8 : nullptr; +} + +template <> +inline int32_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i32 : nullptr; +} + +template <> +inline int64_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? reinterpret_cast(tensor->data.raw) + : nullptr; +} + +inline int RemapDim(int max_dimensions, int d) { + return max_dimensions - d - 1; +} + +// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object +// even if the original tensors were not 4D. We should consider rewriting them +// to take a more generic 'shape' object. 
+inline Dims<4> GetTensorDims(const int data[], const int size) { + Dims<4> d; + for (int i = 0; i < 4; ++i) { + int src = size - i - 1; + if (src >= 0) { + d.sizes[i] = data[src]; + } else { + d.sizes[i] = 1; + } + } + d.strides[0] = 1; + for (int i = 1; i < 4; i++) { + d.strides[i] = d.strides[i - 1] * d.sizes[i - 1]; + } + return d; +} + +inline Dims<4> GetTensorDims(std::vector data) { + return GetTensorDims(data.data(), data.size()); +} + +inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { + if (tensor == nullptr) { + return Dims<4>(); + } + + auto* dims = tensor->dims; + return GetTensorDims(dims->data, dims->size); +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf2068d320f65cf0195abbc181f4ef4ff8f20679 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include +#include + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +TEST(TensorTest, GetTensorDims4D) { + Dims<4> d = GetTensorDims({2, 3, 4, 5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 2)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60)); +} + +TEST(TensorTest, GetTensorDims3D) { + Dims<4> d = GetTensorDims({3, 4, 5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 1)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60)); +} + +TEST(TensorTest, GetTensorDims2D) { + Dims<4> d = GetTensorDims({4, 5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 4, 1, 1)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 20)); +} + +TEST(TensorTest, GetTensorDims1D) { + Dims<4> d = GetTensorDims({5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 1, 1, 1)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 5, 5)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..904a97803a6a9ba369c1e64c711b12d19ffc10c4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc @@ -0,0 +1,27 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" + +#ifndef USE_NEON +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#endif // defined(__ARM_NEON__) || defined(__ARM_NEON) +#endif // USE_NEON + +#ifdef USE_NEON +#include "tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h" +#else +#include "tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h" +#endif // USE_NEON diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..0e69ef5982f01e364d865684652d1dfecab6fee3 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" + +namespace tflite { +namespace tensor_utils { + +// Limit a float input f betweeen +abs_limit and -abs_limit. 
+float Clip(float f, float abs_limit); + +// Multiply a matrix by a batch vector, and store results in a batch-size +// vector using a stride value provided in result_stride. 'result_stride' shows +// how the number of elements between consecutive result values. For example +// result_stride = 1, will cause the output to look like this: +// [O_1, 0_2, ... O_rows] in memory, but result_stride = 3, will cause it to be +// arranged like this in memory: [O_1, x, x, 0_2, x, x, ..., O_rows] +void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride); + +// Cwise product of two vectors. +void VectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result); + +// Cwise product and accumulate of two vectors. Since it's a MAC opertation, the +// assumption here is that result array is initialized to valid values. +void VectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result); + +// Dot product of two vectors. +float VectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); + +// Dot product of two batch vectors of size n_batch * v_size: +// vector1 = [x_1_1, x_1_2, ..., x_1_vsize, +// x_2_1, x_2_2, ..., x_2_vsize, +// ... +// x_nbatch_1,..., x_nbatch_vsize] +// vector2 = [y_1_1, y_1_2, ..., y_1_vsize, +// y_2_1, y_2_2, ..., y_2_vsize, +// ... +// y_nbatch_1,..., y_nbatch_vsize] +// Then result will be a vector of n_batch size which will be saved with a +// stride of result_stride in memory starting from 'result': +// [x_1_1 * y_1_1 + x_1_2 * y_1_2 + ... + x_1_vsize * y_1_vsize, +// x_2_1 * y_2_1 + x_2_2 * y_2_2 + ... + x_2_vsize * y_2_vsize, +// ... +// x_nbatch_1 * y_nbatch_1 + ... 
+ x_nbatch_vsize * y_nbatch_vsize] +void BatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); + +// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC +// operation, the assumption here is that result array is initialized to valid +// values. +void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, + const float* batch_vector, + int n_batch, float* result); + +// Batch vector initialization with another vector. +void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, + float* batch_vector); + +// Apply sigmoid to elements of a vector. +void ApplySigmoidToVector(const float* vector, int v_size, float* result); + +// Apply activation function to elements of a vector. +void ApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, float* result); + +// Copy vector to another vector. +void CopyVector(const float* vector, int v_size, float* result); + +// Compute "1.0f - elements of vector" (used in CIFG). +void Sub1Vector(const float* vector, int v_size, float* result); + +// Fill vector with 0.f. +void ZeroVector(float* vector, int v_size); + +// Clip elements of a vector using a abs_limit value. +void ClipVector(const float* vector, int v_size, float abs_limit, + float* result); + +// Shift left a vector in place with v_size size. +void VectorShiftLeft(float* vector, int v_size, float shift_value); + +// Reduce-sum on a float input vector: +// input_vector: float pointer to input vector. +// output_vector: float pointer to vector. +// output_size: output vector size. +// reduction_size: number of consecutive elements from input vector which are +// added to get one element of output. 
+void ReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); +} // namespace tensor_utils +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..588f1a428b8c84367d659c2c5bb59a411cd8bb34 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc @@ -0,0 +1,192 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" + +namespace tflite { +namespace tensor_utils { + +TEST(uKernels, ClipTest) { + constexpr int kVectorSize = 10; + constexpr float kAbsLimit = 2.0; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + std::vector output(kVectorSize); + ClipVector(input, kVectorSize, kAbsLimit, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0}))); +} + +TEST(uKernels, MatrixBatchVectorMultiplyAccumulateTest) { + constexpr int kRow = 3; + constexpr int kCol = 4; + constexpr int kBatch = 2; + static float matrix[kRow * kCol] = {1.0, 2.0, 3.0, 4.0, // + -1.0, -2.0, -3.0, -4.0, // + 1.0, -2.0, 3.0, -4.0}; + static float vector[kCol * kBatch] = {1.0, -1.0, 1.0, -1.0, // + 2.0, -2.0, 2.0, -2.0}; + std::vector output(kRow * kBatch); + std::fill(output.begin(), output.end(), 3.0); + MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch, + output.data(), /*result_stride=*/1); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({1., 5., 13., // + -1., 7., 23.}))); + + std::vector output_with_stride2(kRow * kBatch * 2); + std::fill(output_with_stride2.begin(), output_with_stride2.end(), 3.0); + MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch, + output_with_stride2.data(), + /*result_stride=*/2); + EXPECT_THAT(output_with_stride2, + ElementsAreArray(ArrayFloatNear({1., 3., 5., 3., 13., 3., // + -1., 3., 7., 3., 23., 3.}))); +} + +TEST(uKernels, VectorVectorCwiseProductTest) { + constexpr int kVectorSize = 10; + static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1, 
+ -0.1, 0.1, -0.1, 0.1, -0.1}; + std::vector output(kVectorSize); + VectorVectorCwiseProduct(input1, input2, kVectorSize, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45}))); +} + +TEST(uKernels, VectorVectorCwiseProductAccumulateTest) { + constexpr int kVectorSize = 10; + static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1, + -0.1, 0.1, -0.1, 0.1, -0.1}; + std::vector output(kVectorSize); + std::fill(output.begin(), output.end(), 1.0); + VectorVectorCwiseProductAccumulate(input1, input2, kVectorSize, + output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45}))); +} + +TEST(uKernels, VectorBatchVectorAssignTest) { + constexpr int kVectorSize = 5; + constexpr int kBatchSize = 3; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize * kBatchSize); + VectorBatchVectorAssign(input, kVectorSize, kBatchSize, output.data()); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear( + {0.0, -0.5, 1.0, -1.5, 2.0, 0.0, -0.5, 1.0, -1.5, 2.0, + 0.0, -0.5, 1.0, -1.5, 2.0}))); +} + +TEST(uKernels, ApplySigmoidToVectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + ApplySigmoidToVector(input, kVectorSize, output.data()); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear( + {0.5, 0.377541, 0.731059, 0.182426, 0.880797}))); +} + +TEST(uKernels, ApplyActivationToVectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + ApplyActivationToVector(input, kVectorSize, kTfLiteActRelu, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({0.0, 0.0, 1.0, 0.0, 2.0}))); + + 
ApplyActivationToVector(input, kVectorSize, kTfLiteActTanh, output.data()); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear( + {0.0, -0.462117, 0.761594, -0.905148, 0.964028}))); +} + +TEST(uKernels, CopyVectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + CopyVector(input, kVectorSize, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({0.0, -0.5, 1.0, -1.5, 2.0}))); +} + +TEST(uKernels, Sub1VectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + Sub1Vector(input, kVectorSize, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({1.0, 1.5, 0.0, 2.5, -1.0}))); +} + +TEST(uKernels, ZeroVectorTest) { + constexpr int kVectorSize = 5; + std::vector output(kVectorSize); + ZeroVector(output.data(), kVectorSize); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({0.0, 0.0, 0.0, 0.0, 0.0}))); +} + +TEST(uKernels, BatchVectorBatchVectorDotProductTest) { + constexpr int kVectorSize = 5; + constexpr int kBatch = 2; + static float input1[kVectorSize * kBatch] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + static float input2[kVectorSize * kBatch] = {0.1, -0.1, 0.1, -0.1, 0.1, + -0.1, 0.1, -0.1, 0.1, -0.1}; + std::vector output(kBatch); + BatchVectorBatchVectorDotProduct(input1, input2, kVectorSize, kBatch, + output.data(), /*result_stride=*/1); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({0.5, 1.75}))); +} + +TEST(uKernels, VectorShiftLeftTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector result(kVectorSize); + VectorShiftLeft(input, kVectorSize, 3.0); + result.assign(input, input + kVectorSize); + EXPECT_THAT(result, + ElementsAreArray(ArrayFloatNear({-0.5, 1.0, -1.5, 2.0, 3.0}))); +} + +TEST(uKernels, ReductionSumVectorTest) { + constexpr int 
kInputVectorSize = 10; + constexpr int kOutputVectorSize1 = 5; + constexpr int kReductionSize1 = 2; + static float input[kInputVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + 0.0, -0.5, 1.0, 1.0, 2.0}; + std::vector result1(kOutputVectorSize1); + ReductionSumVector(input, result1.data(), kOutputVectorSize1, + kReductionSize1); + EXPECT_THAT(result1, + ElementsAreArray(ArrayFloatNear({-0.5, -0.5, 2.0, 0.5, 3.0}))); + + constexpr int kOutputVectorSize2 = 2; + constexpr int kReductionSize2 = 5; + std::vector result2(kOutputVectorSize2); + ReductionSumVector(input, result2.data(), kOutputVectorSize2, + kReductionSize2); + EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5}))); +} + +} // namespace tensor_utils +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h new file mode 100644 index 0000000000000000000000000000000000000000..07f1cb40045fff3ae47ed4efa6ec43b0cb88a0a7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +namespace tflite { + +enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu }; + +template +struct Dims { + int sizes[N]; + int strides[N]; +}; + +inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]); + TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]); + TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]); + TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]); + return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] + + i3 * dims.strides[3]; +} + +// Get array size, DCHECKing that the dim index is in range. +template +int ArraySize(const Dims& array, int index) { + TFLITE_DCHECK(index >= 0 && index < N); + return array.sizes[index]; +} + +// Get common array size, DCHECKing that they all agree. +template +int MatchingArraySize(const ArrayType1& array1, int index1, + const ArrayType2& array2, int index2) { + TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); + return ArraySize(array1, index1); +} + +template +int MatchingArraySize(const ArrayType1& array1, int index1, + const ArrayType2& array2, int index2, Args... 
args) { + TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); + return MatchingArraySize(array1, index1, args...); +} + +inline int RequiredBufferSizeForDims(const Dims<4>& dims) { + int max_offset = 0; + for (int i = 0; i < 4; i++) { + max_offset += (dims.sizes[i] - 1) * dims.strides[i]; + } + return max_offset + 1; +} + +template +bool IsPackedWithoutStrides(const Dims& dims) { + int expected_stride = 1; + for (int d = 0; d < N; d++) { + if (dims.strides[d] != expected_stride) return false; + expected_stride *= dims.sizes[d]; + } + return true; +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..b0546c00cf977af5f722a802866448b0cb293b8d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/kernel_util.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include +#include +#include "tensorflow/contrib/lite/kernels/internal/round.h" + +namespace tflite { + +TfLiteStatus GetQuantizedConvolutionMultipler( + TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output, double* multiplier) { + const double input_product_scale = input->params.scale * filter->params.scale; + const double bias_scale = bias->params.scale; + const double output_scale = output->params.scale; + + // TODO(ahentz): The following conditions must be guaranteed by the training + // pipeline. + TF_LITE_ENSURE(context, std::abs(input_product_scale - bias_scale) <= + 1e-6 * std::min(input_product_scale, bias_scale)); + TF_LITE_ENSURE(context, input_product_scale >= 0); + TF_LITE_ENSURE(context, input_product_scale < output_scale); + + *multiplier = input_product_scale / output_scale; + + return kTfLiteOk; +} + +void CalculateActivationRangeUint8(TfLiteFusedActivation activation, + TfLiteTensor* output, int32_t* act_min, + int32_t* act_max) { + const int32_t qmin = std::numeric_limits::min(); + const int32_t qmax = std::numeric_limits::max(); + + const auto scale = output->params.scale; + const auto zero_point = output->params.zero_point; + + auto quantize = [scale, zero_point](float f) { + return zero_point + static_cast(TfLiteRound(f / scale)); + }; + + if (activation == kTfLiteActRelu) { + *act_min = std::max(qmin, quantize(0.0)); + *act_max = qmax; + } else if (activation == kTfLiteActRelu6) { + *act_min = std::max(qmin, quantize(0.0)); + *act_max = std::min(qmax, quantize(6.0)); + } else if (activation == kTfLiteActRelu1) { + *act_min = std::max(qmin, quantize(-1.0)); + *act_max = std::min(qmax, quantize(1.0)); + } else { + *act_min = qmin; + *act_max = qmax; + } +} + +void CalculateActivationRangeFloat(TfLiteFusedActivation activation, + float* 
activation_min, + float* activation_max) { + if (activation == kTfLiteActRelu) { + *activation_min = 0.f; + *activation_max = std::numeric_limits::max(); + } else if (activation == kTfLiteActRelu6) { + *activation_min = 0.f; + *activation_max = 6.f; + } else if (activation == kTfLiteActRelu1) { + *activation_min = -1.f; + *activation_max = 1.f; + } else { + *activation_min = std::numeric_limits::lowest(); + *activation_max = std::numeric_limits::max(); + } +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h new file mode 100644 index 0000000000000000000000000000000000000000..25556ae4567aca45b3bfe4ba02b1cb58331d239d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; } +inline int SizeOfDimension(const TfLiteTensor* t, int dim) { + return t->dims->data[dim]; +} +inline TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node, + int index) { + return &context->tensors[node->inputs->data[index]]; +} +inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node, + int index) { + return &context->tensors[node->outputs->data[index]]; +} +inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; } +inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; } + +inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, + const TfLiteNode* node, int index) { + const bool use_tensor = node->inputs->data[index] != kOptionalTensor; + if (use_tensor) { + return &context->tensors[node->inputs->data[index]]; + } + return nullptr; +} + +// Calculates the multiplication factor for a quantized convolution (or +// quantized depthwise convolution) involving the given tensors. Returns an +// error if the scales of the tensors are not compatible. +TfLiteStatus GetQuantizedConvolutionMultipler( + TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output, double* multiplier); + +// Calculates the useful range of an activation layer given its activation +// tensor. 
+void CalculateActivationRangeUint8(TfLiteFusedActivation activation, + TfLiteTensor* output, int32_t* act_min, + int32_t* act_max); +void CalculateActivationRangeFloat(TfLiteFusedActivation activation, + float* activation_min, + float* activation_max); + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc new file mode 100644 index 0000000000000000000000000000000000000000..f43aa372b6398a38e57dd38f3d7c7db2bd3aefc1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace l2norm { + +// This file has two implementation of L2Norm. 
+enum KernelType { + kReference, + kGenericOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(ahentz): Our current implementations rely on the inputs being 4D. + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + // TODO(ahentz): Our current implementations only support float32. + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + // TODO(ahentz): For some reason our implementations don't support + // activations. + TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = input->dims->data[1]; + output_size->data[2] = input->dims->data[2]; + output_size->data[3] = input->dims->data[3]; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { +#define TF_LITE_L2NORM(type) \ + type::L2Normalization( \ + GetTensorData(input), GetTensorDims(input), \ + GetTensorData(output), GetTensorDims(output)) + + if (kernel_type == kReference) { + TF_LITE_L2NORM(reference_ops); + } + if (kernel_type == kGenericOptimized) { + TF_LITE_L2NORM(optimized_ops); + } +#undef TF_LITE_L2NORM + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // 
namespace l2norm + +TfLiteRegistration* Register_L2NORM_REF() { + static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare, + l2norm::Eval}; + return &r; +} + +TfLiteRegistration* Register_L2NORM_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare, + l2norm::Eval}; + return &r; +} + +TfLiteRegistration* Register_L2_NORMALIZATION() { + return Register_L2NORM_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..30e103f3303484c339ef98e6a68e0438291c102f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class L2NormOpModel : public SingleOpModel { + public: + L2NormOpModel(std::initializer_list input_shape, + ActivationFunctionType activation_type) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions, + CreateL2NormOptions(builder_, activation_type).Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(L2NormOpTest, SimpleTest) { + L2NormOpModel m({1, 1, 1, 6}, ActivationFunctionType_NONE); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc new file mode 100644 index 0000000000000000000000000000000000000000..c1c70d0dfa0050dee3815aa15f5d16d2e7ddc721 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace local_response_norm { + +// This file has two implementations of LocalResponseNorm. 
+enum KernelType { + kReference, + kGenericOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = input->dims->data[1]; + output_size->data[2] = input->dims->data[2]; + output_size->data[3] = input->dims->data[3]; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { +#define TF_LITE_LOCAL_RESPONSE_NORM(type) \ + type::LocalResponseNormalization( \ + GetTensorData(input), GetTensorDims(input), params->radius, \ + params->bias, params->alpha, params->beta, GetTensorData(output), \ + GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_LOCAL_RESPONSE_NORM(reference_ops); + } + if (kernel_type == kGenericOptimized) { + TF_LITE_LOCAL_RESPONSE_NORM(optimized_ops); + } +#undef TF_LITE_LOCAL_RESPONSE_NORM + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace local_response_norm + +TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, local_response_norm::Prepare, + 
local_response_norm::Eval}; + return &r; +} + +TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, local_response_norm::Prepare, + local_response_norm::Eval}; + return &r; +} + +TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION() { + return Register_LOCAL_RESPONSE_NORM_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d75ce258a04c820d8f82735988c01d0154ef36f2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class LocalResponseNormOpModel : public SingleOpModel { + public: + LocalResponseNormOpModel(std::initializer_list input_shape, int radius, + float bias, float alpha, float beta) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + BuiltinOptions_LocalResponseNormalizationOptions, + CreateLocalResponseNormalizationOptions(builder_, radius, bias, + alpha, beta) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(LocalResponseNormOpTest, SameAsL2Norm) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/1.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 2. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}))); +} + +TEST(LocalResponseNormOpTest, WithAlpha) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 3. 
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + {-0.275, 0.15, 0.175, 0.3, -0.175, 0.025}))); +} + +TEST(LocalResponseNormOpTest, WithBias) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 5. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.22, 0.12, 0.14, 0.24, -0.14, 0.02}))); +} + +TEST(LocalResponseNormOpTest, SmallRadius) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-0.264926, 0.125109, 0.140112, 0.267261, -0.161788, 0.0244266}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f73b56ed9790b216adc788490faebaabd2bc756 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc @@ -0,0 +1,204 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// LSH Projection projects an input to a bit vector via locality sensitive +// hashing. +// +// Options: +// Sparse: +// Computed bit vector is considered to be sparse. +// Each output element is an int32 made up by multiple bits computed from +// hash functions. +// +// Dense: +// Computed bit vector is considered to be dense. Each output element is +// either 0 or 1 that represents a bit. +// +// Input: +// Tensor[0]: Hash functions. Dim.size == 2, DataType: Float. +// Tensor[0].Dim[0]: Num of hash functions. +// Tensor[0].Dim[1]: Num of projected output bits generated by +// each hash function. +// In sparse case, Tensor[0].Dim[1] + ceil( log2(Tensor[0].Dim[0] )) <= 32. +// +// Tensor[1]: Input. Dim.size >= 1, No restriction on DataType. +// Tensor[2]: Optional, Weight. Dim.size == 1, DataType: Float. +// If not set, each element of input is considered to have same +// weight of 1.0 Tensor[1].Dim[0] == Tensor[2].Dim[0] +// +// Output: +// Sparse: +// Output.Dim == { Tensor[0].Dim[0] } +// A tensor of int32 that represents hash signatures, +// +// NOTE: To avoid collisions across hash functions, an offset value of +// k * (1 << Tensor[0].Dim[1]) will be added to each signature, +// k is the index of the hash function. +// Dense: +// Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] } +// A flattened tensor represents projected bit vectors. 
+ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include + +namespace tflite { +namespace ops { +namespace builtin { +namespace lsh_projection { + +TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* hash = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2); + // Support up to 32 bits. + TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32); + + TfLiteTensor* input = GetInput(context, node, 1); + TF_LITE_ENSURE(context, NumDimensions(input) >= 1); + + if (NumInputs(node) == 3) { + TfLiteTensor* weight = GetInput(context, node, 2); + TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0), + SizeOfDimension(input, 0)); + } + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1); + switch (params->type) { + case kTfLiteLshProjectionSparse: + outputSize->data[0] = SizeOfDimension(hash, 0); + break; + case kTfLiteLshProjectionDense: + outputSize->data[0] = SizeOfDimension(hash, 0) * SizeOfDimension(hash, 1); + break; + default: + return kTfLiteError; + } + return context->ResizeTensor(context, output, outputSize); +} + +// Compute sign bit of dot product of hash(seed, input) and weight. +// NOTE: use float as seed, and convert it to double as a temporary solution +// to match the trained model. This is going to be changed once the new +// model is trained in an optimized method. 
+// +int RunningSignBit(const TfLiteTensor* input, const TfLiteTensor* weight, + float seed) { + double score = 0.0; + int input_item_bytes = input->bytes / SizeOfDimension(input, 0); + char* input_ptr = input->data.raw; + + const size_t seed_size = sizeof(float); + const size_t key_bytes = sizeof(float) + input_item_bytes; + std::unique_ptr key(new char[key_bytes]); + + for (int i = 0; i < SizeOfDimension(input, 0); ++i) { + // Create running hash id and value for current dimension. + memcpy(key.get(), &seed, seed_size); + memcpy(key.get() + seed_size, input_ptr, input_item_bytes); + + int64_t hash_signature = ::util::Fingerprint64(key.get(), key_bytes); + double running_value = static_cast(hash_signature); + input_ptr += input_item_bytes; + if (weight == nullptr) { + score += running_value; + } else { + score += weight->data.f[i] * running_value; + } + } + + return (score > 0) ? 1 : 0; +} + +void SparseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input, + const TfLiteTensor* weight, int32_t* out_buf) { + int num_hash = SizeOfDimension(hash, 0); + int num_bits = SizeOfDimension(hash, 1); + for (int i = 0; i < num_hash; i++) { + int32_t hash_signature = 0; + for (int j = 0; j < num_bits; j++) { + float seed = hash->data.f[i * num_bits + j]; + int bit = RunningSignBit(input, weight, seed); + hash_signature = (hash_signature << 1) | bit; + } + *out_buf++ = hash_signature + i * (1 << num_bits); + } +} + +void DenseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input, + const TfLiteTensor* weight, int32_t* out_buf) { + int num_hash = SizeOfDimension(hash, 0); + int num_bits = SizeOfDimension(hash, 1); + for (int i = 0; i < num_hash; i++) { + for (int j = 0; j < num_bits; j++) { + float seed = hash->data.f[i * num_bits + j]; + int bit = RunningSignBit(input, weight, seed); + *out_buf++ = bit; + } + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + int32_t* 
out_buf = GetOutput(context, node, 0)->data.i32; + TfLiteTensor* hash = GetInput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 1); + TfLiteTensor* weight = + NumInputs(node) == 2 ? nullptr : GetInput(context, node, 2); + + switch (params->type) { + case kTfLiteLshProjectionDense: + DenseLshProjection(hash, input, weight, out_buf); + break; + case kTfLiteLshProjectionSparse: + SparseLshProjection(hash, input, weight, out_buf); + break; + default: + return kTfLiteError; + } + + return kTfLiteOk; +} +} // namespace lsh_projection + +TfLiteRegistration* Register_LSH_PROJECTION() { + static TfLiteRegistration r = {nullptr, nullptr, lsh_projection::Resize, + lsh_projection::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..414d728dfc153058ec878d3c766f58e86815cd3f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +class LSHProjectionOpModel : public SingleOpModel { + public: + LSHProjectionOpModel(LSHProjectionType type, + std::initializer_list hash_shape, + std::initializer_list input_shape, + std::initializer_list weight_shape) { + hash_ = AddInput(TensorType_FLOAT32); + input_ = AddInput(TensorType_INT32); + if (weight_shape.size() > 0) { + weight_ = AddInput(TensorType_FLOAT32); + } + output_ = AddOutput(TensorType_INT32); + + SetBuiltinOp(BuiltinOperator_LSH_PROJECTION, + BuiltinOptions_LSHProjectionOptions, + CreateLSHProjectionOptions(builder_, type).Union()); + if (weight_shape.size() > 0) { + BuildInterpreter({hash_shape, input_shape, weight_shape}); + } else { + BuildInterpreter({hash_shape, input_shape}); + } + + output_size_ = 1; + for (int i : hash_shape) { + output_size_ *= i; + if (type == LSHProjectionType_SPARSE) { + break; + } + } + } + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetHash(std::initializer_list data) { + PopulateTensor(hash_, data); + } + + void SetWeight(std::initializer_list f) { PopulateTensor(weight_, f); } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int hash_; + int weight_; + int output_; + + int output_size_; +}; + +TEST(LSHProjectionOpTest2, Dense1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_DENSE, {3, 2}, {5}, {5}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0)); +} + 
+TEST(LSHProjectionOpTest2, Sparse1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5}, {}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0)); +} + +TEST(LSHProjectionOpTest2, Sparse3DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5, 2, 2}, {5}); + + m.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912, + 9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c06264d845c24e71647b6fd2374734be32383ef --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -0,0 +1,515 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + 
namespace tflite {
namespace ops {
namespace builtin {
namespace lstm {

// Input Tensors of size {n_batch, n_input}
constexpr int kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr int kInputToInputWeightsTensor = 1;  // Optional
constexpr int kInputToForgetWeightsTensor = 2;
constexpr int kInputToCellWeightsTensor = 3;
constexpr int kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kRecurrentToForgetWeightsTensor = 6;
constexpr int kRecurrentToCellWeightsTensor = 7;
constexpr int kRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kCellToInputWeightsTensor = 9;    // Optional
constexpr int kCellToForgetWeightsTensor = 10;  // Optional
constexpr int kCellToOutputWeightsTensor = 11;  // Optional

// Gates bias tensors of size {n_cell}
constexpr int kInputGateBiasTensor = 12;  // Optional
constexpr int kForgetGateBiasTensor = 13;
constexpr int kCellGateBiasTensor = 14;
constexpr int kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17;  // Optional

// Output tensors.
constexpr int kScratchBufferTensor = 0;
constexpr int kOutputStateTensor = 1;
constexpr int kCellStateTensor = 2;
constexpr int kOutputTensor = 3;

// Returns true iff `tensor` is non-null and has shape {rows, cols}.
static bool IsMatrixOfShape(const TfLiteTensor* tensor, int rows, int cols) {
  return tensor != nullptr && tensor->dims->size == 2 &&
         tensor->dims->data[0] == rows && tensor->dims->data[1] == cols;
}

// Returns true iff `tensor` is non-null and has shape {length}.
static bool IsVectorOfLength(const TfLiteTensor* tensor, int length) {
  return tensor != nullptr && tensor->dims->size == 1 &&
         tensor->dims->data[0] == length;
}

// Check that input tensor dimensions matches with each other.
//
// Verifies the clipping parameters are non-negative, that every weight and
// bias tensor present has the shape implied by n_input / n_output / n_cell,
// and that optional tensors appear in consistent combinations (CIFG,
// peephole, projection). Returns kTfLiteOk when all checks pass.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell) {
  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);

  // Making sure clipping parameters have valid values.
  // == 0 means no clipping
  //  > 0 means clipping
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  // Input weights: {n_cell, n_input}.
  TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  if (input_to_input_weights) {
    TF_LITE_ENSURE(context,
                   IsMatrixOfShape(input_to_input_weights, n_cell, n_input));
  }

  TfLiteTensor* input_to_forget_weights =
      GetInput(context, node, kInputToForgetWeightsTensor);
  TF_LITE_ENSURE(context,
                 IsMatrixOfShape(input_to_forget_weights, n_cell, n_input));

  TfLiteTensor* input_to_cell_weights =
      GetInput(context, node, kInputToCellWeightsTensor);
  TF_LITE_ENSURE(context,
                 IsMatrixOfShape(input_to_cell_weights, n_cell, n_input));

  // Recurrent weights: {n_cell, n_output}.
  TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights) {
    TF_LITE_ENSURE(
        context, IsMatrixOfShape(recurrent_to_input_weights, n_cell, n_output));
  }

  TfLiteTensor* recurrent_to_forget_weights =
      GetInput(context, node, kRecurrentToForgetWeightsTensor);
  TF_LITE_ENSURE(
      context, IsMatrixOfShape(recurrent_to_forget_weights, n_cell, n_output));

  TfLiteTensor* recurrent_to_cell_weights =
      GetInput(context, node, kRecurrentToCellWeightsTensor);
  TF_LITE_ENSURE(
      context, IsMatrixOfShape(recurrent_to_cell_weights, n_cell, n_output));

  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or not at all (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
      (input_to_input_weights == nullptr) ==
      (recurrent_to_input_weights == nullptr);
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  // Peephole weights: {n_cell}.
  TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
  if (cell_to_input_weights) {
    TF_LITE_ENSURE(context, IsVectorOfLength(cell_to_input_weights, n_cell));
  }

  TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
  if (cell_to_forget_weights) {
    TF_LITE_ENSURE(context, IsVectorOfLength(cell_to_forget_weights, n_cell));
  }

  TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
  if (cell_to_output_weights) {
    TF_LITE_ENSURE(context, IsVectorOfLength(cell_to_output_weights, n_cell));
  }

  // Making sure the peephole weights are there all or none.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    // The bias is optional in the graph, so it may be null here; the helper
    // rejects that instead of dereferencing it.
    TF_LITE_ENSURE(context, IsVectorOfLength(input_gate_bias, n_cell));
  }

  TfLiteTensor* forget_gate_bias =
      GetInput(context, node, kForgetGateBiasTensor);
  TF_LITE_ENSURE(context, IsVectorOfLength(forget_gate_bias, n_cell));

  TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
  TF_LITE_ENSURE(context, IsVectorOfLength(cell_bias, n_cell));

  TfLiteTensor* output_gate_bias =
      GetInput(context, node, kOutputGateBiasTensor);
  TF_LITE_ENSURE(context, IsVectorOfLength(output_gate_bias, n_cell));

  TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
  if (projection_weights) {
    TF_LITE_ENSURE(context,
                   IsMatrixOfShape(projection_weights, n_output, n_cell));
  }

  TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
  if (projection_bias) {
    TF_LITE_ENSURE(context, IsVectorOfLength(projection_bias, n_output));
  }

  // Making sure the projection tensors are consistent:
  // 1) If projection weight is not present, then projection bias should not be
  // present.
  // 2) If projection weight is present, then projection bias is optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projecton_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projecton_tensors_consistent == true);

  return kTfLiteOk;
}

// Resize the output, state and scratch tensors based on the sizes of the input
// tensors. Also check that the size of the input tensors match each other.
//
// n_batch and n_input come from the input tensor, n_cell from the
// input-to-output weights and n_output from the recurrent-to-output weights.
// Returns the first failing status, or kTfLiteOk.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);

  // Inferring batch size, number of outputs and number of cells from the
  // input tensors.
  TfLiteTensor* input = GetInput(context, node, kInputTensor);
  TF_LITE_ENSURE(context, input->dims->size > 1);
  const int n_batch = input->dims->data[0];
  const int n_input = input->dims->data[1];

  // Check the rank before reading the dimensions.
  TfLiteTensor* input_to_output_weights =
      GetInput(context, node, kInputToOutputWeightsTensor);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  const int n_cell = input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

  TfLiteTensor* recurrent_to_output_weights =
      GetInput(context, node, kRecurrentToOutputWeightsTensor);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                    n_cell);
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Check that input tensor dimensions matches with each other. The status
  // must be propagated, otherwise mismatched shapes would reach Eval.
  const TfLiteStatus dims_status =
      CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
  TF_LITE_ENSURE_OK(context, dims_status);

  // Get the pointer to output, state and scratch buffer tensors.
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
  TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
  // TODO(ghodrat): Modify this as soon as we have a finalized method for
  // scratch buffers.
  TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);

  // Resize the output and output_state tensors.
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
  output_size->data[0] = n_batch;
  output_size->data[1] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
  output_state_size->data[0] = n_batch;
  output_state_size->data[1] = n_output;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, output_state, output_state_size));

  // Resize the cell state tensor.
  TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
  cell_size->data[0] = n_batch;
  cell_size->data[1] = n_cell;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, cell_state, cell_size));

  // Mark state tensors as persistent tensors.
  output_state->allocation_type = kTfLiteArenaRwPersistent;
  cell_state->allocation_type = kTfLiteArenaRwPersistent;

  // Reserve scratch space per batch row: Cell, Forget and Output gates for a
  // CIFG-LSTM (no input-to-input weights), plus the Input gate otherwise.
  TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
  const int n_gates = use_cifg ? 3 : 4;
  TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
  scratch_buffer_size->data[0] = n_batch;
  scratch_buffer_size->data[1] = n_cell * n_gates;
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                   scratch_buffer_size));
  return kTfLiteOk;
}

// The LSTM Op engine.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state 
= GetOutput(context, node, kCellStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor); + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell, + n_batch, input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell, + n_batch, forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell, + n_batch, output_gate_scratch); + + // For each batch and cell: compute input_weight * input. 
+ if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights->data.f, n_cell, n_input, input->data.f, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights->data.f, n_cell, n_input, input->data.f, n_batch, + forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights->data.f, n_cell, n_input, input->data.f, n_batch, + cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights->data.f, n_cell, n_input, input->data.f, n_batch, + output_gate_scratch, /*result_stride=*/1); + + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights->data.f, n_cell, n_output, output_state->data.f, + n_batch, cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, output_gate_scratch, /*result_stride=*/1); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. 
+ if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, + cell_state->data.f, n_batch * n_cell, + cell_state->data.f); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, + cell_state->data.f); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state->data.f); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell, + params->cell_clip, cell_state->data.f); + } + + // For each batch and cell: update the output gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch, + output_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. 
+ const bool use_projection_weight = (projection_weights != nullptr); + const bool use_projection_bias = (projection_bias != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output, + n_batch, output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, n_batch * n_output); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights->data.f, n_output, n_cell, output_gate_scratch, + n_batch, output->data.f, /*result_stride=*/1); + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output->data.f, n_batch * n_output, + params->proj_clip, output->data.f); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output->data.f); + } + tensor_utils::CopyVector(output->data.f, n_batch * n_output, + output_state->data.f); + + return kTfLiteOk; +} + +} // namespace lstm + +TfLiteRegistration* Register_LSTM() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + lstm::Prepare, lstm::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c068286b0d84bcb51ebb0e239350a42863de6523 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -0,0 +1,1087 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite LSTM op. + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; +

// Test harness that builds a single-op model containing one LSTM operator.
//
// Tensors are added in the exact order the kernel expects its inputs and
// outputs, so the order of the AddInput/AddNullInput/AddOutput calls in the
// constructor must not be changed. Optional tensors that a configuration does
// not use are added as null inputs.
class LSTMOpModel : public SingleOpModel {
 public:
  // n_batch/n_input/n_cell/n_output: sizes remembered for the accessors and
  //     for zero-filling the state tensors.
  // use_cifg: when true, all input-gate tensors are null inputs.
  // use_peephole: when true, cell-to-gate weights are provided
  //     (cell-to-input only when CIFG is off).
  // use_projection_weights / use_projection_bias: whether projection tensors
  //     are provided; the bias is only added when the weights are.
  // cell_clip / proj_clip: forwarded to the LSTM options.
  // input_shapes: one shape per input tensor, in input order.
  LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
              bool use_peephole, bool use_projection_weights,
              bool use_projection_bias, float cell_clip, float proj_clip,
              const std::vector<std::vector<int>>& input_shapes)
      : n_batch_(n_batch),
        n_input_(n_input),
        n_cell_(n_cell),
        n_output_(n_output) {
    input_ = AddInput(TensorType_FLOAT32);

    if (use_cifg) {
      input_to_input_weights_ = AddNullInput();
    } else {
      input_to_input_weights_ = AddInput(TensorType_FLOAT32);
    }

    input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
    input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
    input_to_output_weights_ = AddInput(TensorType_FLOAT32);

    if (use_cifg) {
      recurrent_to_input_weights_ = AddNullInput();
    } else {
      recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
    }

    recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
    recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
    recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);

    if (use_peephole) {
      if (use_cifg) {
        cell_to_input_weights_ = AddNullInput();
      } else {
        cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
      }
      cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
      cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
    } else {
      cell_to_input_weights_ = AddNullInput();
      cell_to_forget_weights_ = AddNullInput();
      cell_to_output_weights_ = AddNullInput();
    }

    if (use_cifg) {
      input_gate_bias_ = AddNullInput();
    } else {
      input_gate_bias_ = AddInput(TensorType_FLOAT32);
    }
    forget_gate_bias_ = AddInput(TensorType_FLOAT32);
    cell_bias_ = AddInput(TensorType_FLOAT32);
    output_gate_bias_ = AddInput(TensorType_FLOAT32);

    if (use_projection_weights) {
      projection_weights_ = AddInput(TensorType_FLOAT32);
      if (use_projection_bias) {
        projection_bias_ = AddInput(TensorType_FLOAT32);
      } else {
        projection_bias_ = AddNullInput();
      }
    } else {
      projection_weights_ = AddNullInput();
      projection_bias_ = AddNullInput();
    }

    scratch_buffer_ = AddOutput(TensorType_FLOAT32);
    // TODO(ghodrat): Modify these states when we have a permanent solution for
    // persistent buffer.
    output_state_ = AddOutput(TensorType_FLOAT32);
    cell_state_ = AddOutput(TensorType_FLOAT32);
    output_ = AddOutput(TensorType_FLOAT32);

    SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
                 CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
                                   cell_clip, proj_clip)
                     .Union());
    BuildInterpreter(input_shapes);
  }

  // Setters: each one copies the given values into the corresponding tensor.

  void SetInputToInputWeights(std::initializer_list<float> f) {
    PopulateTensor(input_to_input_weights_, f);
  }

  void SetInputToForgetWeights(std::initializer_list<float> f) {
    PopulateTensor(input_to_forget_weights_, f);
  }

  void SetInputToCellWeights(std::initializer_list<float> f) {
    PopulateTensor(input_to_cell_weights_, f);
  }

  void SetInputToOutputWeights(std::initializer_list<float> f) {
    PopulateTensor(input_to_output_weights_, f);
  }

  void SetRecurrentToInputWeights(std::initializer_list<float> f) {
    PopulateTensor(recurrent_to_input_weights_, f);
  }

  void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
    PopulateTensor(recurrent_to_forget_weights_, f);
  }

  void SetRecurrentToCellWeights(std::initializer_list<float> f) {
    PopulateTensor(recurrent_to_cell_weights_, f);
  }

  void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
    PopulateTensor(recurrent_to_output_weights_, f);
  }

  void SetCellToInputWeights(std::initializer_list<float> f) {
    PopulateTensor(cell_to_input_weights_, f);
  }

  void SetCellToForgetWeights(std::initializer_list<float> f) {
    PopulateTensor(cell_to_forget_weights_, f);
  }

  void SetCellToOutputWeights(std::initializer_list<float> f) {
    PopulateTensor(cell_to_output_weights_, f);
  }

  void SetInputGateBias(std::initializer_list<float> f) {
    PopulateTensor(input_gate_bias_, f);
  }

  void SetForgetGateBias(std::initializer_list<float> f) {
    PopulateTensor(forget_gate_bias_, f);
  }

  void SetCellBias(std::initializer_list<float> f) {
    PopulateTensor(cell_bias_, f);
  }

  void SetOutputGateBias(std::initializer_list<float> f) {
    PopulateTensor(output_gate_bias_, f);
  }

  void SetProjectionWeights(std::initializer_list<float> f) {
    PopulateTensor(projection_weights_, f);
  }

  void SetProjectionBias(std::initializer_list<float> f) {
    PopulateTensor(projection_bias_, f);
  }

  // Zero-fills the output state. The kernel sizes this tensor as
  // {n_batch, n_output}, so exactly n_output * n_batch elements are written;
  // writing n_cell * n_batch would overrun the tensor whenever projection
  // makes n_output smaller than n_cell.
  void ResetOutputState() { ZeroTensor(output_state_, n_output_ * n_batch_); }

  // Zero-fills the cell state, which the kernel sizes as {n_batch, n_cell}.
  void ResetCellState() { ZeroTensor(cell_state_, n_cell_ * n_batch_); }

  // Copies [begin, end) into the input tensor starting at `offset`.
  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  // Returns a copy of the output tensor's contents.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int num_inputs() { return n_input_; }
  int num_outputs() { return n_output_; }
  int num_cells() { return n_cell_; }
  int num_batches() { return n_batch_; }

 private:
  // Writes `num_elements` zeros to the start of the tensor at index `tensor`.
  void ZeroTensor(int tensor, int num_elements) {
    std::vector<float> zeros(num_elements, 0.0f);
    PopulateTensor(tensor, 0, zeros.data(), zeros.data() + zeros.size());
  }

  // Tensor indices as returned by AddInput/AddNullInput/AddOutput.
  int input_;
  int input_to_input_weights_;
  int input_to_forget_weights_;
  int input_to_cell_weights_;
  int input_to_output_weights_;

  int recurrent_to_input_weights_;
  int recurrent_to_forget_weights_;
  int recurrent_to_cell_weights_;
  int recurrent_to_output_weights_;

  int cell_to_input_weights_;
  int cell_to_forget_weights_;
  int cell_to_output_weights_;

  int input_gate_bias_;
  int forget_gate_bias_;
  int cell_bias_;
  int output_gate_bias_;

  int projection_weights_;
  int projection_bias_;

  int output_;
  int output_state_;
  int cell_state_;
  int scratch_buffer_;

  // Model sizes passed to the constructor.
  int n_batch_;
  int n_input_;
  int n_cell_;
  int n_output_;
};

+TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, 
+ -0.34856534, 0.43890524}); + + lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, + -0.29909778}); + + lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}); + + lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, -0.1556896, + 0.19487578}); + + lstm.SetInputGateBias({0., 0., 0., 0.}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToInputWeights( + {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, + -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, + -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); + + lstm.SetRecurrentToCellWeights( + {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, + -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + + lstm.SetRecurrentToForgetWeights( + {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, + -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + + lstm.SetRecurrentToOutputWeights( + {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, + 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}); + + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, + -0.15358765, -0.03716109, 0.12507336, + 0.41193449, -0.20860538, -0.15053082, + 0.09120187, 0.24278517, -0.12222792}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = + sizeof(lstm_input) 
/ sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* golden_start = lstm_golden_output + i * lstm.num_outputs(); + float* golden_end = golden_start + lstm.num_outputs(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, + 0.04717243, 0.48944736, -0.38535351, + -0.17212132}); + + 
lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, 0.24407166, + 0.33826375}); + + lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToCellWeights( + {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, + 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, + 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, + 0.21193194}); + + lstm.SetRecurrentToForgetWeights( + {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, + 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, + -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + + lstm.SetRecurrentToOutputWeights( + {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, + -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + + lstm.SetCellToForgetWeights( + {0.47485286, -0.51955009, -0.24458408, 0.31544167}); + lstm.SetCellToOutputWeights( + {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, + -0.05163646, -0.42312205, -0.01218222, + 0.24201041, -0.08124574, -0.358325, + -0.04621704, 0.21641694, -0.06471302}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = + sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + 
float* golden_start = lstm_golden_output + i * lstm.num_outputs(); + float* golden_end = golden_start + lstm.num_outputs(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights( + {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, 
-0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); + + lstm.SetInputToForgetWeights( + {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, + -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, + -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, + 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, + -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, + -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, + 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, + 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, + 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, + -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, + -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, + -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, + 0.10405491, 0.052960336, 
0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, + 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, + -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, + 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); + + lstm.SetInputToCellWeights( + {-0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); + + lstm.SetInputToOutputWeights( + {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, 
-0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); + + lstm.SetInputGateBias( + {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, + -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, + -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, + 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); + + lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}); + + lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}); + + lstm.SetOutputGateBias( 
+ {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, + 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, + 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, + -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); + + lstm.SetRecurrentToInputWeights( + {-0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + 
-0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, 
-0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}); + + lstm.SetRecurrentToForgetWeights( + {-0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, 
-0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, 
+ 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}); + + lstm.SetRecurrentToCellWeights( + {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + 
-0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 
0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); + + lstm.SetRecurrentToOutputWeights({ + 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, + -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, + -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, + -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, + -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, + -0.048351806, -0.03850004, 0.07266485, -0.022414139, 
0.05940088, + 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, + 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, + -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, + -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, + 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, + -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, + 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, + 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, + 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, + 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, + 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, + -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, + 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, + 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, + -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, + -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, + -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, + -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, + -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, + 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, + -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, + 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, 
-0.14345716, 0.010955264, -0.10234828, 0.1183656, + -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, + -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, + 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, + 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, + -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, + 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, + -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, + -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, + -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, + -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, + 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, + -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, + -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, + -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, + 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, + -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, + 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, + 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, + 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, + 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, + 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }); + + lstm.SetCellToInputWeights( + {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, 
-0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); + + lstm.SetCellToForgetWeights( + {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); + + lstm.SetCellToOutputWeights( + {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); + + lstm.SetProjectionWeights( + {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, + 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, + -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, + -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, + 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, + 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, + 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, + -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, + -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, + 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, + 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, + 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, + 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, + 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, 
+ -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, + 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, + -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, + 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, + -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, + -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, + -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, + 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, + -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, + -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, + 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, + -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, + 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, + 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, + 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, + 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, + -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, + 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, + -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, + -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, + 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, + 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, + -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, 
-0.004228674, 0.019708522, -0.16305895, -0.07468996, + -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, + 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, + -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, + 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, + -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, + -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, + 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, + -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, + -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, + 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, + 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, + 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); + + static float lstm_input[][20] = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, + 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, + 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, + 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, + 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; + + static float lstm_golden_output[][64] = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 
0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = + sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); + float* batch1_end = batch1_start + lstm.num_inputs(); + lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end); + + lstm.Invoke(); + + float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); + 
float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); + float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); + float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); + std::vector expected; + expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); + expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +// Test binary entry point. Calls ::tflite::LogToStderr() first (presumably so +// TFLite log output lands on stderr -- confirm against its definition), hands +// argc/argv to GoogleTest so it can parse its own flags, then runs every +// registered test and returns the RUN_ALL_TESTS() result as the exit status. +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc new file mode 100644 index 0000000000000000000000000000000000000000..81c73f2523186c2d4072d56bdc8980fcdbb588a3 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -0,0 +1,167 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace mul { + +// This file has three implementation of Mul. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); + for (int i = 0; i < NumDimensions(input1); ++i) { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), + SizeOfDimension(input2, i)); + } + + TF_LITE_ENSURE_EQ(context, input1->type, output->type); + TF_LITE_ENSURE_EQ(context, input2->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); + return context->ResizeTensor(context, output, output_size); +} + +template +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteMulParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + 
CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_MUL(type) \ + type::Mul(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops); + } else { + TF_LITE_MUL(optimized_ops); + } +#undef TF_LITE_MUL +} + +template +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteMulParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + auto input1_offset = -input1->params.zero_point; + auto input2_offset = -input2->params.zero_point; + auto output_offset = output->params.zero_point; + + int32_t output_multiplier; + int output_shift; + + double real_multiplier = + input1->params.scale * input2->params.scale / output->params.scale; + QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, + &output_shift); + + int32 output_activation_min, output_activation_max; + CalculateActivationRangeUint8(params->activation, output, + &output_activation_min, &output_activation_max); + +#define TF_LITE_MUL(type) \ + type::BroadcastMul(GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, GetTensorData(input2), \ + GetTensorDims(input2), input2_offset, output_offset, \ + output_multiplier, output_shift, output_activation_min, \ + output_activation_max, GetTensorData(output), \ + GetTensorDims(output)); + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops); + } else { + TF_LITE_MUL(optimized_ops); + } +#undef TF_LITE_MUL +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + 
+ if (output->type == kTfLiteFloat32) { + EvalFloat(context, node, params, input1, input2, output); + } else if (output->type == kTfLiteUInt8) { + EvalQuantized(context, node, params, input1, input2, output); + } else { + context->ReportError(context, + "Mul only supports FLOAT32 and quantized UINT8 now."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace mul + +TfLiteRegistration* Register_MUL_REF() { + static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + mul::Eval}; + return &r; +} + +TfLiteRegistration* Register_MUL_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + mul::Eval}; + return &r; +} + +TfLiteRegistration* Register_MUL_NEON_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + mul::Eval}; + return &r; +} + +TfLiteRegistration* Register_MUL() { +#ifdef USE_NEON + return Register_MUL_NEON_OPT(); +#else + return Register_MUL_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4255cfe18a043c55f3ce7292afdedb6e988a28a2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/mul_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+// Single-op test model for MUL: two inputs that share the same TensorData
+// description, one output, and a fused activation.
+class BaseMulOpModel : public SingleOpModel {
+ public:
+  BaseMulOpModel(TensorData input, TensorData output,
+                 ActivationFunctionType activation_type) {
+    // Both inputs are created from the same `input` description.
+    input1_ = AddInput(input);
+    input2_ = AddInput(input);
+    output_ = AddOutput(output);
+    SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions,
+                 CreateMulOptions(builder_, activation_type).Union());
+    BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+  }
+
+  // Tensor indices of the two inputs, for use with the populate helpers.
+  int input1() { return input1_; }
+  int input2() { return input2_; }
+
+ protected:
+  int input1_;
+  int input2_;
+  int output_;
+};
+
+// Float variant: exposes the raw output values.
+// NOTE(review): the element types of std::vector / ExtractVector are missing
+// in this text (template arguments appear stripped).
+class FloatMulOpModel : public BaseMulOpModel {
+ public:
+  using BaseMulOpModel::BaseMulOpModel;
+
+  std::vector GetOutput() { return ExtractVector(output_); }
+};
+
+// For quantized Mul, the error shouldn't exceed (2*step + step^2).
+// The param min=-1.0 & max=1.0 is used in the following tests.
+// The tolerance value is ~0.0157.
+// Width of one quantization step when [-1.0, 1.0] is mapped onto 255 steps.
+const float kQuantizedStep = 2.0 / 255.0;
+const float kQuantizedTolerance =
+    2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep;
+
+// Quantized variant: exposes the output dequantized with the output tensor's
+// scale and zero point.
+class QuantizedMulOpModel : public BaseMulOpModel {
+ public:
+  using BaseMulOpModel::BaseMulOpModel;
+
+  std::vector GetDequantizedOutput() {
+    return Dequantize(ExtractVector(output_),
+                      GetScale(output_), GetZeroPoint(output_));
+  }
+};
+
+TEST(FloatMulOpTest, NoActivation) {
+  FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+                    {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+  m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+  m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4})));
+}
+
+// The last product (0.8 * 5 = 4.0) is expected to be clamped to 1.0.
+TEST(FloatMulOpTest, ActivationRELU1) {
+  FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+                    {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1);
+  m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+  m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 5});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 1.0})));
+}
+
+// Runs the same six-element multiplication with shapes of rank 1 to 4.
+TEST(FloatMulOpTest, VariousInputShapes) {
+  std::vector> test_shapes = {
+      {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+  // size_t index: avoids comparing a signed int against size().
+  for (size_t i = 0; i < test_shapes.size(); ++i) {
+    FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]},
+                      {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+    m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0});
+    m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1});
+    m.Invoke();
+    EXPECT_THAT(
+        m.GetOutput(),
+        ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4, 1.21, 0.2})))
+        << "With shape number " << i;
+  }
+}
+
+TEST(QuantizedMulOpTest, NoActivation) {
+  QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0},
+                        {TensorType_UINT8, {}, -1.0, 1.0},
+                        ActivationFunctionType_NONE);
+  m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7});
+  m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8});
+  m.Invoke();
+  EXPECT_THAT(m.GetDequantizedOutput(),
+              ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56},
+                                              kQuantizedTolerance)));
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h
new file mode 100644
index 0000000000000000000000000000000000000000..63670efcb1e6349317aa5c75756707fb7a7fa2aa
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/op_macros.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_
+
+#include <cstdio>   // fprintf, stderr
+#include <cstdlib>  // exit
+
+// Prints `msg` followed by a newline to stderr and terminates the process
+// with exit status 1.
+#define TF_LITE_FATAL(msg)          \
+  do {                              \
+    fprintf(stderr, "%s\n", (msg)); \
+    exit(1);                        \
+  } while (0)
+// Terminates via TF_LITE_FATAL, printing the stringified expression, when `x`
+// evaluates to false. `x` is evaluated exactly once.
+#define TF_LITE_ASSERT(x)        \
+  do {                           \
+    if (!(x)) TF_LITE_FATAL(#x); \
+  } while (0)
+// Terminates via TF_LITE_FATAL when `x != y`. Each argument is evaluated
+// exactly once; the message contains the expressions, not their values.
+#define TF_LITE_ASSERT_EQ(x, y)                            \
+  do {                                                     \
+    if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \
+  } while (0)
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_
diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..17166715ca30ff3d8ba3d384110e403f8910e39d
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
@@ -0,0 +1,340 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for optional (null) input tensors, exercised through the TFLite
+// LSTM op.
+
+#include
+#include
+#include
+
+#include
+#include
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+// Test harness that builds a model containing a single LSTM op.
+//
+// Inputs that a configuration does not use are registered with AddNullInput():
+//   - use_cifg: the input-gate weights (input/recurrent/cell) and the input
+//     gate bias are null.
+//   - !use_peephole: all cell_to_* weights are null.
+//   - !use_projection_weights: projection weights and bias are null;
+//     with projection weights, the bias is null unless use_projection_bias.
+// The order of the AddInput/AddNullInput calls fixes the input order, which
+// `input_shapes` must follow.
+class LSTMOpModel : public SingleOpModel {
+ public:
+  LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
+              bool use_peephole, bool use_projection_weights,
+              bool use_projection_bias, float cell_clip, float proj_clip,
+              const std::vector>& input_shapes)
+      : n_batch_(n_batch),
+        n_input_(n_input),
+        n_cell_(n_cell),
+        n_output_(n_output) {
+    input_ = AddInput(TensorType_FLOAT32);
+
+    if (use_cifg) {
+      input_to_input_weights_ = AddNullInput();
+    } else {
+      input_to_input_weights_ = AddInput(TensorType_FLOAT32);
+    }
+
+    input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+    input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
+    input_to_output_weights_ = AddInput(TensorType_FLOAT32);
+
+    if (use_cifg) {
+      recurrent_to_input_weights_ = AddNullInput();
+    } else {
+      recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
+    }
+
+    recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+    recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
+    recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);
+
+    if (use_peephole) {
+      if (use_cifg) {
+        cell_to_input_weights_ = AddNullInput();
+      } else {
+        cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
+      }
+      cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
+      cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
+    } else {
+      cell_to_input_weights_ = AddNullInput();
+      cell_to_forget_weights_ = AddNullInput();
+      cell_to_output_weights_ = AddNullInput();
+    }
+
+    if (use_cifg) {
+      input_gate_bias_ = AddNullInput();
+    } else {
+      input_gate_bias_ = AddInput(TensorType_FLOAT32);
+    }
+    forget_gate_bias_ = AddInput(TensorType_FLOAT32);
+    cell_bias_ = AddInput(TensorType_FLOAT32);
+    output_gate_bias_ = AddInput(TensorType_FLOAT32);
+
+    if (use_projection_weights) {
+      projection_weights_ = AddInput(TensorType_FLOAT32);
+      if (use_projection_bias) {
+        projection_bias_ = AddInput(TensorType_FLOAT32);
+      } else {
+        projection_bias_ = AddNullInput();
+      }
+    } else {
+      projection_weights_ = AddNullInput();
+      projection_bias_ = AddNullInput();
+    }
+
+    scratch_buffer_ = AddOutput(TensorType_FLOAT32);
+    // TODO(ghodrat): Modify these states when we have a permanent solution for
+    // persistent buffer.
+    output_state_ = AddOutput(TensorType_FLOAT32);
+    cell_state_ = AddOutput(TensorType_FLOAT32);
+    output_ = AddOutput(TensorType_FLOAT32);
+
+    SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
+                 CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
+                                   cell_clip, proj_clip)
+                     .Union());
+    BuildInterpreter(input_shapes);
+  }
+
+  // The setters below copy `f` into the correspondingly named input tensor.
+  void SetInputToInputWeights(std::initializer_list f) {
+    PopulateTensor(input_to_input_weights_, f);
+  }
+
+  void SetInputToForgetWeights(std::initializer_list f) {
+    PopulateTensor(input_to_forget_weights_, f);
+  }
+
+  void SetInputToCellWeights(std::initializer_list f) {
+    PopulateTensor(input_to_cell_weights_, f);
+  }
+
+  void SetInputToOutputWeights(std::initializer_list f) {
+    PopulateTensor(input_to_output_weights_, f);
+  }
+
+  void SetRecurrentToInputWeights(std::initializer_list f) {
+    PopulateTensor(recurrent_to_input_weights_, f);
+  }
+
+  void SetRecurrentToForgetWeights(std::initializer_list f) {
+    PopulateTensor(recurrent_to_forget_weights_, f);
+  }
+
+  void SetRecurrentToCellWeights(std::initializer_list f) {
+    PopulateTensor(recurrent_to_cell_weights_, f);
+  }
+
+  void SetRecurrentToOutputWeights(std::initializer_list f) {
+    PopulateTensor(recurrent_to_output_weights_, f);
+  }
+
+  void SetCellToInputWeights(std::initializer_list f) {
+    PopulateTensor(cell_to_input_weights_, f);
+  }
+
+  void SetCellToForgetWeights(std::initializer_list f) {
+    PopulateTensor(cell_to_forget_weights_, f);
+  }
+
+  void SetCellToOutputWeights(std::initializer_list f) {
+    PopulateTensor(cell_to_output_weights_, f);
+  }
+
+  void SetInputGateBias(std::initializer_list f) {
+    PopulateTensor(input_gate_bias_, f);
+  }
+
+  void SetForgetGateBias(std::initializer_list f) {
+    PopulateTensor(forget_gate_bias_, f);
+  }
+
+  void SetCellBias(std::initializer_list f) {
+    PopulateTensor(cell_bias_, f);
+  }
+
+  void SetOutputGateBias(std::initializer_list f) {
+    PopulateTensor(output_gate_bias_, f);
+  }
+
+  void SetProjectionWeights(std::initializer_list f) {
+    PopulateTensor(projection_weights_, f);
+  }
+
+  void SetProjectionBias(std::initializer_list f) {
+    PopulateTensor(projection_bias_, f);
+  }
+
+  // Writes n_cell_ * n_batch_ zeros into the output-state tensor.
+  // NOTE(review): the buffer is sized with n_cell_, the same as the cell
+  // state; presumably the output state holds n_output_ values per batch -
+  // confirm for configurations where n_cell != n_output.
+  void ResetOutputState() {
+    const int zero_buffer_size = n_cell_ * n_batch_;
+    std::unique_ptr zero_buffer(new float[zero_buffer_size]);
+    memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+    PopulateTensor(output_state_, 0, zero_buffer.get(),
+                   zero_buffer.get() + zero_buffer_size);
+  }
+
+  // Writes n_cell_ * n_batch_ zeros into the cell-state tensor.
+  void ResetCellState() {
+    const int zero_buffer_size = n_cell_ * n_batch_;
+    std::unique_ptr zero_buffer(new float[zero_buffer_size]);
+    memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
+    PopulateTensor(cell_state_, 0, zero_buffer.get(),
+                   zero_buffer.get() + zero_buffer_size);
+  }
+
+  // Copies [begin, end) into the input tensor starting at `offset`.
+  void SetInput(int offset, float* begin, float* end) {
+    PopulateTensor(input_, offset, begin, end);
+  }
+
+  std::vector GetOutput() { return ExtractVector(output_); }
+  // Unpacks the serialized model from the builder and expects a non-null
+  // result.
+  void Verify() {
+    auto model = tflite::UnPackModel(builder_.GetBufferPointer());
+    EXPECT_NE(model, nullptr);
+  }
+
+  int num_inputs() { return n_input_; }
+  int num_outputs() { return n_output_; }
+  int num_cells() { return n_cell_; }
+  int num_batches() { return n_batch_; }
+
+ private:
+  // Tensor indices returned by AddInput / AddNullInput / AddOutput.
+  int input_;
+  int input_to_input_weights_;
+  int input_to_forget_weights_;
+  int input_to_cell_weights_;
+  int input_to_output_weights_;
+
+  int recurrent_to_input_weights_;
+  int recurrent_to_forget_weights_;
+  int recurrent_to_cell_weights_;
+  int recurrent_to_output_weights_;
+
+  int cell_to_input_weights_;
+  int cell_to_forget_weights_;
+  int cell_to_output_weights_;
+
+  int input_gate_bias_;
+  int forget_gate_bias_;
+  int cell_bias_;
+  int output_gate_bias_;
+
+  int projection_weights_;
+  int projection_bias_;
+
+  int output_;
+  int output_state_;
+  int cell_state_;
+  int scratch_buffer_;
+
+  // Dimensions passed to the constructor.
+  int n_batch_;
+  int n_input_;
+  int n_cell_;
+  int n_output_;
+};
+
+
+// Builds a CIFG + peephole model without projection, fills the weights that
+// are present, and checks that the serialized model unpacks. The op itself is
+// never invoked in this test.
+TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
+  const int n_batch = 1;
+  const int n_input = 2;
+  // n_cell and n_output have the same size when there is no projection.
+  const int n_cell = 4;
+  const int n_output = 4;
+
+  LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+                   /*use_cifg=*/true, /*use_peephole=*/true,
+                   /*use_projection_weights=*/false,
+                   /*use_projection_bias=*/false,
+                   /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+                   {
+                       {n_batch, n_input},  // input tensor
+
+                       {0, 0},             // input_to_input_weight tensor
+                       {n_cell, n_input},  // input_to_forget_weight tensor
+                       {n_cell, n_input},  // input_to_cell_weight tensor
+                       {n_cell, n_input},  // input_to_output_weight tensor
+
+                       {0, 0},              // recurrent_to_input_weight tensor
+                       {n_cell, n_output},  // recurrent_to_forget_weight tensor
+                       {n_cell, n_output},  // recurrent_to_cell_weight tensor
+                       {n_cell, n_output},  // recurrent_to_output_weight tensor
+
+                       {0},       // cell_to_input_weight tensor
+                       {n_cell},  // cell_to_forget_weight tensor
+                       {n_cell},  // cell_to_output_weight tensor
+
+                       {0},       // input_gate_bias tensor
+                       {n_cell},  // forget_gate_bias tensor
+                       {n_cell},  // cell_bias tensor
+                       {n_cell},  // output_gate_bias tensor
+
+                       {0, 0},  // projection_weight tensor
+                       {0},     // projection_bias tensor
+                   });
+
+
+  lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
+                              0.04717243, 0.48944736, -0.38535351,
+                              -0.17212132});
+
+  lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
+                                -0.3633365, -0.22755712, 0.28253698, 0.24407166,
+                                0.33826375});
+
+  lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
+                                -0.09426838, -0.44257352, 0.54939759,
+                                0.01533556, 0.42751634});
+
+  lstm.SetCellBias({0., 0., 0., 0.});
+
+  lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+  lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+  lstm.SetRecurrentToCellWeights(
+      {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
+       0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
+       0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
+       0.21193194});
+
+  lstm.SetRecurrentToForgetWeights(
+      {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
+       0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
+       -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+
+  lstm.SetRecurrentToOutputWeights(
+      {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
+       -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+       0.50248802, 0.26114327, -0.43736315, 0.33149987});
+
+  lstm.SetCellToForgetWeights(
+      {0.47485286, -0.51955009, -0.24458408, 0.31544167});
+  lstm.SetCellToOutputWeights(
+      {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+
+  // Resetting cell_state and output_state
+  lstm.ResetCellState();
+  lstm.ResetOutputState();
+
+  // Verify the model by unpacking it.
+  lstm.Verify();
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5e90282a43b1b6caf7918b3874fd4273f59e31b7
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -0,0 +1,139 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include
+#include
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace pad {
+
+// This file has two implementations of Pad.
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
+// Bundles the op parameters with the first input and first output tensor.
+// TODO(nupurgarg): Padding represented as a tensor is ignored. Only use the
+// `left_padding` and `right_padding` specified in `params`.
+struct PadContext {
+  PadContext(TfLiteContext* context, TfLiteNode* node) {
+    params = reinterpret_cast(node->builtin_data);
+    input = GetInput(context, node, 0);
+    output = GetOutput(context, node, 0);
+  }
+  TfLitePadParams* params;
+  TfLiteTensor* input;
+  TfLiteTensor* output;
+};
+
+// Validates the node and resizes the output tensor.
+//
+// Requires one or two inputs, exactly one output, a 4-D input whose rank
+// equals params->num_dimensions, and non-negative padding values. Each output
+// dimension is input + before_padding + after_padding.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  // Determines size of output tensor.
+  PadContext op_context(context, node);
+  int dims = NumDimensions(op_context.input);
+  TF_LITE_ENSURE_EQ(context, dims, op_context.params->num_dimensions);
+
+  // TODO(nupurgarg): Our current implementations rely on the inputs being 4D.
+  TF_LITE_ENSURE_EQ(context, dims, 4);
+
+  // Validate every padding value before allocating output_size: a failing
+  // TF_LITE_ENSURE_MSG returns from this function, which would leak an array
+  // allocated earlier.
+  for (int idx = 0; idx < dims; ++idx) {
+    TF_LITE_ENSURE_MSG(context,
+                       (op_context.params->before_padding[idx] >= 0 &&
+                        op_context.params->after_padding[idx] >= 0),
+                       "Pad value has to be greater than equal to 0.");
+  }
+
+  const TfLiteIntArray* input_size = op_context.input->dims;
+  TfLiteIntArray* output_size = TfLiteIntArrayCreate(dims);
+  for (int idx = 0; idx < dims; ++idx) {
+    output_size->data[idx] =
+        (input_size->data[idx] + op_context.params->before_padding[idx] +
+         op_context.params->after_padding[idx]);
+  }
+
+  return context->ResizeTensor(context, op_context.output, output_size);
+}
+
+// Pads the float32 input into the output using the reference or optimized
+// kernel; any other output type reports an error and returns kTfLiteError.
+//
+// NOTE(review): `kernel_type` is used below but never declared in this text;
+// the parameter list after `template` appears to have been lost.
+template
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  PadContext op_context(context, node);
+
+  // TODO(nupurgarg): Support different data types.
+  if (op_context.output->type == kTfLiteFloat32) {
+    std::vector before_padding(
+        op_context.params->before_padding,
+        op_context.params->before_padding + op_context.params->num_dimensions);
+    std::vector after_padding(
+        op_context.params->after_padding,
+        op_context.params->after_padding + op_context.params->num_dimensions);
+
+    // TODO(nupurgarg): Change TOCO's implementation to use padding arrays
+    // in forward order (depth, width, height, batch).
+    // Converts from int[] = {depth, width, height, batch} to int[] = {batch,
+    // height, width, depth} to match TOCO's implementation of pad in
+    // reference_ops.h and optimized_ops.h.
+    std::reverse(before_padding.begin(), before_padding.end());
+    std::reverse(after_padding.begin(), after_padding.end());
+
+#define TF_LITE_PAD(type)                                                   \
+  type::Pad(GetTensorData(op_context.input),                                \
+            GetTensorDims(op_context.input), before_padding, after_padding, \
+            GetTensorData(op_context.output),                               \
+            GetTensorDims(op_context.output))
+
+    if (kernel_type == kReference) {
+      TF_LITE_PAD(reference_ops);
+    }
+    if (kernel_type == kGenericOptimized) {
+      TF_LITE_PAD(optimized_ops);
+    }
+#undef TF_LITE_PAD
+  } else {
+    context->ReportError(context, "Inputs and outputs not all float types.");
+    return kTfLiteError;
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace pad
+
+TfLiteRegistration* Register_PAD_REF() {
+  static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
+                                 pad::Eval};
+  return &r;
+}
+
+TfLiteRegistration* Register_PAD_GENERIC_OPT() {
+  static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
+                                 pad::Eval};
+  return &r;
+}
+
+// Default registration: the generic optimized kernel.
+TfLiteRegistration* Register_PAD() {
+  return Register_PAD_GENERIC_OPT();
+  // return Register_PAD_REF();
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f3ea9417df0e61dcff7a877726ab91c9b22691ba
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pad_test.cc
@@ -0,0 +1,99 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+// Single-op test model for PAD with one float32 input and one float32
+// output. The padding amounts are passed as builtin options, not as a tensor.
+class PadOpModel : public SingleOpModel {
+ public:
+  PadOpModel(std::initializer_list input_shape,
+             std::initializer_list before_padding,
+             std::initializer_list after_padding) {
+    input_ = AddInput(TensorType_FLOAT32);
+    output_ = AddOutput(TensorType_FLOAT32);
+    SetBuiltinOp(
+        BuiltinOperator_PAD, BuiltinOptions_PadOptions,
+        CreatePadOptions(builder_, builder_.CreateVector(before_padding),
+                         builder_.CreateVector(after_padding))
+            .Union());
+    BuildInterpreter({input_shape});
+  }
+
+  void SetInput(std::initializer_list data) {
+    PopulateTensor(input_, data);
+  }
+
+  std::vector GetOutput() { return ExtractVector(output_); }
+  std::vector GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+  // Tensor indices returned by AddInput / AddOutput.
+  int input_;
+  int output_;
+};
+
+// A 9-D input is expected to abort with the "dims != 4" check message.
+TEST(PadOpTest, TooManyDimensions) {
+  EXPECT_DEATH(
+      PadOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 3, 4, 5, 6, 7, 8, 9},
+                 {1, 2, 3, 4, 5, 6, 7, 8, 9}),
+      "dims != 4");
+}
+
+// TODO(nupurgarg): Test case where before padding and after padding arrays
+// don't contain the same number of dimensions.
+// A 4-D input with 3-element padding arrays is expected to abort on the
+// rank / num_dimensions mismatch.
+TEST(PadOpTest, UnequalDimensions) {
+  EXPECT_DEATH(PadOpModel({1, 1, 2, 1}, {1, 2, 3}, {1, 2, 3}),
+               "dims != op_context.params->num_dimensions");
+}
+
+// Negative padding values are expected to abort with the validation message.
+TEST(PadOpTest, InvalidPadValue) {
+  EXPECT_DEATH(PadOpModel({1, 1, 2, 1}, {0, 1, 2, 0}, {0, -1, -1, 0}),
+               "Pad value has to be greater than equal to 0.");
+}
+
+// Pads a 2x2 image by one element on every side of height and width,
+// giving a 4x4 result with the input in the centre.
+TEST(PadOpTest, SimpleTest) {
+  PadOpModel m({1, 2, 2, 1}, {0, 1, 1, 0}, {0, 1, 1, 0});
+  m.SetInput({1, 2, 3, 4});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4,
+                                               0, 0, 0, 0, 0}));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
+// Asymmetric padding: height 2 -> 4 (0 before, 2 after), width 3 -> 7
+// (1 before, 3 after).
+TEST(PadOpTest, AdvancedTest) {
+  // The padding is input in the order of batch, height, width, depth.
+  PadOpModel m({1, 2, 3, 1}, {0, 0, 1, 0}, {0, 2, 3, 0});
+  m.SetInput({1, 2, 3, 4, 5, 6});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
+                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h
new file mode 100644
index 0000000000000000000000000000000000000000..3a60274524c468ef29e522de5569e0d8354974c2
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/padding.h
@@ -0,0 +1,28 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
+
+namespace tflite {
+
+// Returns the padding for one side of one spatial dimension.
+//
+// The total padding needed for `out_size` windows of `filter_size` placed
+// `stride` apart to cover `in_size` elements is
+// (out_size - 1) * stride + filter_size - in_size. This returns half of that
+// (integer division, so an odd total is rounded toward zero), clamped to be
+// non-negative.
+inline int ComputePadding(int stride, int in_size, int filter_size,
+                          int out_size) {
+  int padding = ((out_size - 1) * stride + filter_size - in_size) / 2;
+  return padding > 0 ? padding : 0;
+}
+
+}  // namespace tflite
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b79880110897a1438a589d97363fd861c61667e7
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -0,0 +1,355 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/kernels/padding.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace pooling {
+
+// This file has two implementations of each pooling op.
+enum KernelType {
+  kReference,
+  kGenericOptimized,
+};
+
+enum PoolType {
+  kAverage,
+  kMax,
+  kL2,
+};
+
+// Per-node state: the padding computed in GenericPrepare and read by the
+// Eval functions.
+struct OpData {
+  TfLitePaddingValues padding;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  // This is a builtin op, so we don't use the contents in 'buffer', if any.
+  // Instead, we allocate a new object to carry information from Prepare() to
+  // Eval().
+  return new OpData;
+}
+
+// Releases the OpData allocated by Init.
+void Free(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast(buffer);
+}
+
+// Validates the node, computes the output height/width and the padding, and
+// resizes the output tensor to {batches, outHeight, outWidth, channels}.
+//
+// Requires one 4-D input and one output of the same type. For uint8 input,
+// average and max pooling require identical input/output scale and zero
+// point, and L2 pooling is rejected.
+//
+// NOTE(review): `pool_type` is used below but never declared in this text;
+// the parameter list after `template` appears to have been lost.
+template
+TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast(node->builtin_data);
+  OpData* data = reinterpret_cast(node->user_data);
+
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* input = GetInput(context, node, 0);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
+  TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+  int batches = input->dims->data[0];
+  int height = input->dims->data[1];
+  int width = input->dims->data[2];
+  int channels_out = input->dims->data[3];
+
+  // Matching GetWindowedOutputSize in TensorFlow.
+  // SAME: ceil(imageSize / stride). VALID: ceil((imageSize - filterSize + 1) /
+  // stride). Any other padding value yields 0.
+  auto padding = params->padding;
+  auto computeOutSize = [padding](int imageSize, int filterSize,
+                                  int stride) -> int {
+    return padding == kTfLitePaddingSame
+               ? (imageSize + stride - 1) / stride
+               : padding == kTfLitePaddingValid
+                     ? (imageSize - filterSize + stride) / stride
+                     : 0;
+  };
+
+  int outWidth =
+      computeOutSize(width, params->filter_width, params->stride_width);
+  int outHeight =
+      computeOutSize(height, params->filter_height, params->stride_height);
+
+  data->padding.height = ComputePadding(params->stride_height, height,
+                                        params->filter_height, outHeight);
+  data->padding.width = ComputePadding(params->stride_width, width,
+                                       params->filter_width, outWidth);
+
+  if (input->type == kTfLiteUInt8) {
+    if (pool_type == kAverage || pool_type == kMax) {
+      TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
+      TF_LITE_ENSURE_EQ(context, input->params.zero_point,
+                        output->params.zero_point);
+    }
+    if (pool_type == kL2) {
+      // We currently don't have a quantized implementation of L2Pool
+      // Inside this uint8 branch the comparison below cannot hold, so it
+      // acts as the rejection path.
+      TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+    }
+  }
+
+  TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4);
+  outputSize->data[0] = batches;
+  outputSize->data[1] = outHeight;
+  outputSize->data[2] = outWidth;
+  outputSize->data[3] = channels_out;
+  return context->ResizeTensor(context, output, outputSize);
+}
+
+// The *EvalFloat / *EvalQuantized helpers below share one shape: derive the
+// clamp range from the fused activation, then call the reference or optimized
+// kernel with the strides, the padding from OpData and the filter size.
+template
+void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node,
+                      TfLitePoolParams* params, OpData* data,
+                      TfLiteTensor* input, TfLiteTensor* output) {
+  float activation_min, activation_max;
+  CalculateActivationRangeFloat(params->activation, &activation_min,
+                                &activation_max);
+#define TF_LITE_AVERAGE_POOL(type)                                       \
+  type::AveragePool(                                                     \
+      GetTensorData(input), GetTensorDims(input), params->stride_width,  \
+      params->stride_height, data->padding.width, data->padding.height,  \
+      params->filter_width, params->filter_height, activation_min,       \
+      activation_max, GetTensorData(output), GetTensorDims(output))
+  if (kernel_type == kReference) {
+    TF_LITE_AVERAGE_POOL(reference_ops);
+  } else {
+    TF_LITE_AVERAGE_POOL(optimized_ops);
+  }
+#undef TF_LITE_AVERAGE_POOL
+}
+
+template
+void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node,
+                          TfLitePoolParams* params, OpData* data,
+                          TfLiteTensor* input, TfLiteTensor* output) {
+  int32_t activation_min;
+  int32_t activation_max;
+  CalculateActivationRangeUint8(params->activation, output, &activation_min,
+                                &activation_max);
+#define TF_LITE_AVERAGE_POOL(type)                                 \
+  type::AveragePool(GetTensorData(input), GetTensorDims(input),    \
+                    params->stride_width, params->stride_height,   \
+                    data->padding.width, data->padding.height,     \
+                    params->filter_width, params->filter_height,   \
+                    activation_min, activation_max,                \
+                    GetTensorData(output), GetTensorDims(output))
+  if (kernel_type == kReference) {
+    TF_LITE_AVERAGE_POOL(reference_ops);
+  } else {
+    TF_LITE_AVERAGE_POOL(optimized_ops);
+  }
+#undef TF_LITE_AVERAGE_POOL
+}
+
+template
+void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
+                  TfLitePoolParams* params, OpData* data, TfLiteTensor* input,
+                  TfLiteTensor* output) {
+  float activation_min, activation_max;
+  CalculateActivationRangeFloat(params->activation, &activation_min,
+                                &activation_max);
+#define TF_LITE_MAX_POOL(type)                                           \
+  type::MaxPool(                                                         \
+      GetTensorData(input), GetTensorDims(input), params->stride_width,  \
+      params->stride_height, data->padding.width, data->padding.height,  \
+      params->filter_width, params->filter_height, activation_min,       \
+      activation_max, GetTensorData(output), GetTensorDims(output))
+  if (kernel_type == kReference) {
+    TF_LITE_MAX_POOL(reference_ops);
+  } else {
+    TF_LITE_MAX_POOL(optimized_ops);
+  }
+#undef TF_LITE_MAX_POOL
+}
+
+template
+void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
+                      TfLitePoolParams* params, OpData* data,
+                      TfLiteTensor* input, TfLiteTensor* output) {
+  int32_t activation_min;
+  int32_t activation_max;
+  CalculateActivationRangeUint8(params->activation, output, &activation_min,
+                                &activation_max);
+#define TF_LITE_MAX_POOL(type)                                              \
+  type::MaxPool(GetTensorData(input), GetTensorDims(input),                 \
+                params->stride_width, params->stride_height,                \
+                data->padding.width, data->padding.height,                  \
+                params->filter_width, params->filter_height, activation_min, \
+                activation_max, GetTensorData(output),                      \
+                GetTensorDims(output))
+  if (kernel_type == kReference) {
+    TF_LITE_MAX_POOL(reference_ops);
+  } else {
+    TF_LITE_MAX_POOL(optimized_ops);
+  }
+#undef TF_LITE_MAX_POOL
+}
+
+template
+void L2EvalFloat(TfLiteContext* context, TfLiteNode* node,
+                 TfLitePoolParams* params, OpData* data, TfLiteTensor* input,
+                 TfLiteTensor* output) {
+  float activation_min, activation_max;
+  CalculateActivationRangeFloat(params->activation, &activation_min,
+                                &activation_max);
+#define TF_LITE_L2_POOL(type)                                            \
+  type::L2Pool(                                                          \
+      GetTensorData(input), GetTensorDims(input), params->stride_width,  \
+      params->stride_height, data->padding.width, data->padding.height,  \
+      params->filter_width, params->filter_height, activation_min,       \
+      activation_max, GetTensorData(output), GetTensorDims(output))
+  if (kernel_type == kReference) {
+    TF_LITE_L2_POOL(reference_ops);
+  } else {
+    TF_LITE_L2_POOL(optimized_ops);
+  }
+#undef TF_LITE_L2_POOL
+}
+
+// NOTE(review): no matching #define of TF_LITE_KERNEL_TYPE_DISPATCH is
+// visible in this part of the file; this looks like a leftover.
+#undef TF_LITE_KERNEL_TYPE_DISPATCH
+
+// Dispatches average pooling on the input type (float32 or uint8); any other
+// type reports an error and returns kTfLiteError.
+template
+TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast(node->builtin_data);
+  OpData* data = reinterpret_cast(node->user_data);
+
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* input = GetInput(context, node, 0);
+  switch (input->type) {  // Already know in/out types are same.
+    case kTfLiteFloat32:
+      AverageEvalFloat(context, node, params, data, input, output);
+      break;
+    case kTfLiteUInt8:
+      AverageEvalQuantized(context, node, params, data, input,
+                           output);
+      break;
+    default:
+      context->ReportError(context, "Type not currently supported.");
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+// Dispatches max pooling on the input type (float32 or uint8); any other
+// type reports an error and returns kTfLiteError.
+template
+TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast(node->builtin_data);
+  OpData* data = reinterpret_cast(node->user_data);
+
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* input = GetInput(context, node, 0);
+  switch (input->type) {  // Already know in/out types are same.
+    case kTfLiteFloat32:
+      MaxEvalFloat(context, node, params, data, input, output);
+      break;
+    case kTfLiteUInt8:
+      MaxEvalQuantized(context, node, params, data, input, output);
+      break;
+    default:
+      context->ReportError(context, "Type not currently supported.");
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+template
+TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast(node->builtin_data);
+  OpData* data = reinterpret_cast(node->user_data);
+
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  TfLiteTensor* input = GetInput(context, node, 0);
+  switch (input->type) {  // Already know in/out types are same.
+    case kTfLiteFloat32:
+      L2EvalFloat(context, node, params, data, input, output);
+      break;
+    case kTfLiteUInt8:
+    // We don't have a quantized implementation, so just fall through to the
+    // 'default' case.
+ default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace pooling + +TfLiteRegistration* Register_AVERAGE_POOL_REF() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::AverageEval}; + return &r; +} + +TfLiteRegistration* Register_MAX_POOL_REF() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::MaxEval}; + return &r; +} + +TfLiteRegistration* Register_L2_POOL_REF() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::L2Eval}; + return &r; +} + +TfLiteRegistration* Register_AVERAGE_POOL_GENERIC_OPT() { + static TfLiteRegistration r = { + pooling::Init, pooling::Free, pooling::GenericPrepare, + pooling::AverageEval}; + return &r; +} + +TfLiteRegistration* Register_MAX_POOL_GENERIC_OPT() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::MaxEval}; + return &r; +} + +TfLiteRegistration* Register_L2_POOL_GENERIC_OPT() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::L2Eval}; + return &r; +} + +TfLiteRegistration* Register_AVERAGE_POOL_2D() { + return Register_AVERAGE_POOL_GENERIC_OPT(); +} + +TfLiteRegistration* Register_MAX_POOL_2D() { + return Register_MAX_POOL_GENERIC_OPT(); +} + +TfLiteRegistration* Register_L2_POOL_2D() { + return Register_L2_POOL_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pooling_test.cc b/tensorflow/contrib/lite/kernels/pooling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..01c91b2ba905e249c36af19f175c68a7e7f17f6d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pooling_test.cc @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BasePoolingOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, bias, padding types, + // stride values. 
+ BasePoolingOpModel(BuiltinOperator type, const TensorData& input, + int filter_width, int filter_height, + const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + + SetBuiltinOp( + type, BuiltinOptions_Pool2DOptions, + CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width, + filter_height, ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_)}); + } + + protected: + int input_; + int output_; +}; + +class FloatPoolingOpModel : public BasePoolingOpModel { + public: + using BasePoolingOpModel::BasePoolingOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedPoolingOpModel : public BasePoolingOpModel { + public: + using BasePoolingOpModel::BasePoolingOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(FloatPoolingOpTest, AveragePool) { + FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75})); +} + +TEST(QuantizedPoolingOpTest, AveragePool) { + // Choose the input ranges carefully so that the dequantized output matches + // the results of the float model above. 
+ QuantizedPoolingOpModel m( + BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 0, 15.9375}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_UINT8, {}, 0, 15.9375}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({2.75, 5.75}))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({44, 92})); +} + +TEST(FloatPoolingOpTest, MaxPool) { + FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10})); +} + +TEST(QuantizedPoolingOpTest, MaxPool) { + // Choose the input ranges carefully so that the dequantized output matches + // the results of the float model above. + QuantizedPoolingOpModel m( + BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 0, 15.9375}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_UINT8, {}, 0, 15.9375}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({6, 10}))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({96, 160})); +} + +TEST(FloatPoolingOpTest, L2Pool) { + FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc 
b/tensorflow/contrib/lite/kernels/register.cc new file mode 100644 index 0000000000000000000000000000000000000000..bef6967a90f9cfeef2dd8cda98c887beb46983b2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -0,0 +1,111 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/register.h" + +namespace tflite { +namespace ops { +namespace builtin { + +TfLiteRegistration* Register_RELU(); +TfLiteRegistration* Register_RELU1(); +TfLiteRegistration* Register_RELU6(); +TfLiteRegistration* Register_TANH(); +TfLiteRegistration* Register_LOGISTIC(); +TfLiteRegistration* Register_AVERAGE_POOL_2D(); +TfLiteRegistration* Register_MAX_POOL_2D(); +TfLiteRegistration* Register_L2_POOL_2D(); +TfLiteRegistration* Register_CONV_2D(); +TfLiteRegistration* Register_DEPTHWISE_CONV_2D(); +TfLiteRegistration* Register_SVDF(); +TfLiteRegistration* Register_RNN(); +TfLiteRegistration* Register_EMBEDDING_LOOKUP(); +TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE(); +TfLiteRegistration* Register_FULLY_CONNECTED(); +TfLiteRegistration* Register_LSH_PROJECTION(); +TfLiteRegistration* Register_HASHTABLE_LOOKUP(); +TfLiteRegistration* Register_SOFTMAX(); +TfLiteRegistration* Register_CONCATENATION(); +TfLiteRegistration* Register_ADD(); +TfLiteRegistration* Register_MUL(); +TfLiteRegistration* 
Register_L2_NORMALIZATION(); +TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION(); +TfLiteRegistration* Register_LSTM(); +TfLiteRegistration* Register_PAD(); +TfLiteRegistration* Register_RESHAPE(); +TfLiteRegistration* Register_RESIZE_BILINEAR(); +TfLiteRegistration* Register_SKIP_GRAM(); +TfLiteRegistration* Register_SPACE_TO_DEPTH(); + +BuiltinOpResolver::BuiltinOpResolver() { + AddBuiltin(BuiltinOperator_RELU, Register_RELU()); + AddBuiltin(BuiltinOperator_RELU1, Register_RELU1()); + AddBuiltin(BuiltinOperator_RELU6, Register_RELU6()); + AddBuiltin(BuiltinOperator_TANH, Register_TANH()); + AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC()); + AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D()); + AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D()); + AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D()); + AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D()); + AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D()); + AddBuiltin(BuiltinOperator_SVDF, Register_SVDF()); + AddBuiltin(BuiltinOperator_RNN, Register_RNN()); + AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP()); + AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, + Register_EMBEDDING_LOOKUP_SPARSE()); + AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED()); + AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION()); + AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP()); + AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX()); + AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION()); + AddBuiltin(BuiltinOperator_ADD, Register_ADD()); + AddBuiltin(BuiltinOperator_MUL, Register_MUL()); + AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION()); + AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + Register_LOCAL_RESPONSE_NORMALIZATION()); + AddBuiltin(BuiltinOperator_LSTM, Register_LSTM()); + 
AddBuiltin(BuiltinOperator_PAD, Register_PAD()); + AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE()); + AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR()); + AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); + AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH()); +} + +TfLiteRegistration* BuiltinOpResolver::FindOp( + tflite::BuiltinOperator op) const { + auto it = builtins_.find(op); + return it != builtins_.end() ? it->second : nullptr; +} + +TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op) const { + auto it = custom_ops_.find(op); + return it != custom_ops_.end() ? it->second : nullptr; +} + +void BuiltinOpResolver::AddBuiltin(tflite::BuiltinOperator op, + TfLiteRegistration* registration) { + registration->builtin_code = op; + builtins_.insert(std::make_pair(op, registration)); +} + +void BuiltinOpResolver::AddCustom(const char* name, + TfLiteRegistration* registration) { + registration->builtin_code = BuiltinOperator_CUSTOM; + custom_ops_.insert(std::make_pair(std::string(name), registration)); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h new file mode 100644 index 0000000000000000000000000000000000000000..28f5e0fcc80a14cf9fb6fb19b795d0c0d55e0df9 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/register.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ + +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace builtin { + +class BuiltinOpResolver : public OpResolver { + public: + BuiltinOpResolver(); + TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override; + TfLiteRegistration* FindOp(const char* op) const override; + void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration); + void AddCustom(const char* name, TfLiteRegistration* registration); + + private: + struct BuiltinOperatorHasher { + size_t operator()(const tflite::BuiltinOperator& x) const { + return std::hash()(static_cast(x)); + } + }; + std::unordered_map + builtins_; + std::unordered_map custom_ops_; +}; + +} // namespace builtin +} // namespace ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3e6ddc9f480e3863cac52157ae28b7329ee2088 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/reshape.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace reshape { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + // TODO(ahentz): we are often given a tensor with the shape but we only pay + // attention to what the shape specified in 'params'. + TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // Tensorflow's Reshape allows one of the shape components to have the + // special -1 value, meaning it will be calculated automatically based on the + // input. Here we calculate what that dimension should be so that the number + // of output elements in the same as the number of input elements. 
+ int num_input_elements = 1; + for (int i = 0; i < NumDimensions(input); ++i) { + num_input_elements *= SizeOfDimension(input, i); + } + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions); + int num_output_elements = 1; + int strech_dim = -1; + for (int i = 0; i < params->num_dimensions; ++i) { + int value = params->shape[i]; + if (value == -1) { + TF_LITE_ENSURE_EQ(context, strech_dim, -1); + strech_dim = i; + } else { + num_output_elements *= value; + output_size->data[i] = value; + } + } + if (strech_dim != -1) { + output_size->data[strech_dim] = num_input_elements / num_output_elements; + num_output_elements *= output_size->data[strech_dim]; + } + + TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + memcpy(output->data.raw, input->data.raw, input->bytes); + + return kTfLiteOk; +} + +} // namespace reshape + +TfLiteRegistration* Register_RESHAPE() { + static TfLiteRegistration r = {nullptr, nullptr, reshape::Prepare, + reshape::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0fbcf6e6aa311d2cac491336ee54ccf58bbda8fd --- /dev/null +++ b/tensorflow/contrib/lite/kernels/reshape_test.cc @@ -0,0 +1,89 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ReshapeOpModel : public SingleOpModel { + public: + ReshapeOpModel(std::initializer_list input_shape, + std::initializer_list new_shape) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, + CreateReshapeOptions(builder_, builder_.CreateVector(new_shape)) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(ReshapeOpTest, MismatchedDimensions) { + EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {2, 1}), + "num_input_elements != num_output_elements"); +} + +TEST(ReshapeOpTest, TooManyDimensions) { + EXPECT_DEATH( + ReshapeOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}), + "Found too many dimensions"); +} + +TEST(ReshapeOpTest, TooManySpecialDimensions) { + EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {-1, -1, 2, 4}), + "strech_dim != -1"); +} + +TEST(ReshapeOpTest, SimpleTest) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}); + 
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); +} + +TEST(ReshapeOpTest, WithStretchDimension) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 1, -1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 4})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc new file mode 100644 index 0000000000000000000000000000000000000000..1613c9a89faa3579b913408cc09cdad7f942cb99 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace resize_bilinear { + +// This file has three implementation of RESIZE_BILINEAR. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(ahentz): Our current implementations rely on the inputs being 4D. + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + // TODO(ahentz): Our current implementations only support float32. 
+ TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = params->new_height; + output_size->data[2] = params->new_width; + output_size->data[3] = input->dims->data[3]; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // We have to fake a tensor here, to satisfy ResizeBilinear(). + int32 output_size_data[2] = {params->new_height, params->new_width}; + + if (output->type == kTfLiteFloat32) { +#define TF_LITE_RESIZE_BILINEAR(type) \ + type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ + output_size_data, GetTensorDims({1, 1, 1, 2}), \ + GetTensorData(output), GetTensorDims(output)) + + if (kernel_type == kReference) { + TF_LITE_RESIZE_BILINEAR(reference_ops); + } + if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { + TF_LITE_RESIZE_BILINEAR(optimized_ops); + } +#undef TF_LITE_RESIZE_BILINEAR + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace resize_bilinear + +TfLiteRegistration* Register_RESIZE_BILINEAR_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, resize_bilinear::Prepare, + resize_bilinear::Eval}; + return &r; +} + +TfLiteRegistration* Register_RESIZE_BILINEAR_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, resize_bilinear::Prepare, + resize_bilinear::Eval}; + return &r; +} + +TfLiteRegistration* Register_RESIZE_BILINEAR_NEON_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, resize_bilinear::Prepare, + 
resize_bilinear::Eval}; + return &r; +} + +TfLiteRegistration* Register_RESIZE_BILINEAR() { +#ifdef USE_NEON + return Register_RESIZE_BILINEAR_NEON_OPT(); +#else + return Register_RESIZE_BILINEAR_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..314a71e210d9b5ea75bb137ef228273ef48f28b5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ResizeBilinearOpModel : public SingleOpModel { + public: + ResizeBilinearOpModel(std::initializer_list input_shape, int new_height, + int new_width) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, + CreateResizeBilinearOptions(builder_, new_height, new_width).Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(ResizeBilinearOpTest, HorizontalResize) { + ResizeBilinearOpModel m({1, 1, 2, 1}, 1, 3); + m.SetInput({3, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} + +TEST(ResizeBilinearOpTest, VerticalResize) { + ResizeBilinearOpModel m({1, 2, 1, 1}, 3, 1); + m.SetInput({3, 9}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResize) { + ResizeBilinearOpModel m({1, 2, 2, 1}, 3, 3); + m.SetInput({ + 3, 6, // + 9, 12 // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { + ResizeBilinearOpModel m({2, 2, 2, 1}, 3, 3); + m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 
11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); +} + +TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { + ResizeBilinearOpModel m({1, 2, 2, 2}, 3, 3); + m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc new file mode 100644 index 0000000000000000000000000000000000000000..c90a15b3a2e79028128260e579f41742a46289f6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/skip_gram.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Generate a list of skip grams from an input. +// +// Options: +// ngram_size: num of words for each output item. +// max_skip_size: max num of words to skip. +// The op generates ngrams when it is 0. +// include_all_ngrams: include all ngrams with size up to ngram_size. +// +// Input: +// A string tensor to generate n-grams. +// Dim = {1} +// +// Output: +// A list of strings, each of which contains ngram_size words. 
+// Dim = {num_ngram} + +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { + +namespace { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TF_LITE_ENSURE_EQ(context, GetInput(context, node, 0)->type, kTfLiteString); + TF_LITE_ENSURE_EQ(context, GetOutput(context, node, 0)->type, kTfLiteString); + return kTfLiteOk; +} + +bool ShouldIncludeCurrentNgram(const TfLiteSkipGramParams* params, int size) { + if (size <= 0) { + return false; + } + if (params->include_all_ngrams) { + return size <= params->ngram_size; + } else { + return size == params->ngram_size; + } +} + +bool ShouldStepInRecursion(const TfLiteSkipGramParams* params, + const std::vector& stack, int stack_idx, + int num_words) { + // If current stack size and next word enumeration are within valid range. + if (stack_idx < params->ngram_size && stack[stack_idx] + 1 < num_words) { + // If this stack is empty, step in for first word enumeration. + if (stack_idx == 0) { + return true; + } + // If next word enumeration are within the range of max_skip_size. + // NOTE: equivalent to + // next_word_idx = stack[stack_idx] + 1 + // next_word_idx - stack[stack_idx-1] <= max_skip_size + 1 + if (stack[stack_idx] - stack[stack_idx - 1] <= params->max_skip_size) { + return true; + } + } + return false; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + // Split sentence to words. 
+ std::vector words; + tflite::StringRef strref = tflite::GetString(GetInput(context, node, 0), 0); + int prev_idx = 0; + for (int i = 1; i < strref.len; i++) { + if (isspace(*(strref.str + i))) { + if (i > prev_idx && !isspace(*(strref.str + prev_idx))) { + words.push_back({strref.str + prev_idx, i - prev_idx}); + } + prev_idx = i + 1; + } + } + if (strref.len > prev_idx) { + words.push_back({strref.str + prev_idx, strref.len - prev_idx}); + } + + // Generate n-grams recursively. + tflite::DynamicBuffer buf; + if (words.size() < params->ngram_size) { + buf.WriteToTensor(GetOutput(context, node, 0)); + return kTfLiteOk; + } + + // Stack stores the index of word used to generate ngram. + // The size of stack is the size of ngram. + std::vector stack(params->ngram_size, 0); + // Stack index that indicates which depth the recursion is operating at. + int stack_idx = 1; + int num_words = words.size(); + + while (stack_idx >= 0) { + if (ShouldStepInRecursion(params, stack, stack_idx, num_words)) { + // When current depth can fill with a new word + // and the new word is within the max range to skip, + // fill this word to stack, recurse into next depth. + stack[stack_idx]++; + stack_idx++; + if (stack_idx < params->ngram_size) { + stack[stack_idx] = stack[stack_idx - 1]; + } + } else { + if (ShouldIncludeCurrentNgram(params, stack_idx)) { + // Add n-gram to tensor buffer when the stack has filled with enough + // words to generate the ngram. + std::vector gram(stack_idx); + for (int i = 0; i < stack_idx; i++) { + gram[i] = words[stack[i]]; + } + buf.AddJoinedString(gram, ' '); + } + // When current depth cannot fill with a valid new word, + // and not in last depth to generate ngram, + // step back to previous depth to iterate to next possible word. 
+ stack_idx--; + } + } + + buf.WriteToTensor(GetOutput(context, node, 0)); + return kTfLiteOk; +} +} // namespace + +TfLiteRegistration* Register_SKIP_GRAM() { + static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/skip_gram_test.cc b/tensorflow/contrib/lite/kernels/skip_gram_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..185b64cb44969b57588ea5d0b40f55b6ddf8e11f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/skip_gram_test.cc @@ -0,0 +1,257 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +static char kSentence[] = "The quick\t brown fox\n jumps over\n the lazy dog!"; + +class SkipGramOp : public SingleOpModel { + public: + SkipGramOp(int ngram_size, int max_skip_size, bool include_all_ngrams) { + input_ = AddInput(TensorType_STRING); + output_ = AddOutput(TensorType_STRING); + + SetBuiltinOp(BuiltinOperator_SKIP_GRAM, BuiltinOptions_SkipGramOptions, + CreateSkipGramOptions(builder_, ngram_size, max_skip_size, + include_all_ngrams) + .Union()); + BuildInterpreter({{1}}); + } + void SetInput(const string& content) { + PopulateStringTensor(input_, {content}); + } + + std::vector GetOutput() { + std::vector ans; + TfLiteTensor* tensor = interpreter_->tensor(output_); + + int num = GetStringCount(tensor); + for (int i = 0; i < num; i++) { + StringRef strref = GetString(tensor, i); + ans.push_back(string(strref.str, strref.len)); + } + return ans; + } + + private: + int input_; + int output_; +}; + +TEST(SkipGramTest, TestUnigram) { + SkipGramOp m(1, 0, false); + + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), testing::UnorderedElementsAreArray( + {"The", "quick", "brown", "fox", "jumps", + "over", "the", "lazy", "dog!"})); +} + +TEST(SkipGramTest, TestBigram) { + SkipGramOp m(2, 0, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick", "quick brown", "brown fox", "fox jumps", + "jumps over", "over the", "the lazy", "lazy dog!"})); +} + +TEST(SkipGramTest, TestAllBigram) { + SkipGramOp m(2, 0, true); + m.SetInput(kSentence); + m.Invoke(); 
+ EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {// Unigram + "The", "quick", "brown", "fox", "jumps", "over", "the", + "lazy", "dog!", + // Bigram + "The quick", "quick brown", "brown fox", "fox jumps", + "jumps over", "over the", "the lazy", "lazy dog!"})); +} + +TEST(SkipGramTest, TestAllTrigram) { + SkipGramOp m(3, 0, true); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {// Unigram + "The", "quick", "brown", "fox", "jumps", "over", "the", + "lazy", "dog!", + // Bigram + "The quick", "quick brown", "brown fox", "fox jumps", + "jumps over", "over the", "the lazy", "lazy dog!", + // Trigram + "The quick brown", "quick brown fox", "brown fox jumps", + "fox jumps over", "jumps over the", "over the lazy", + "the lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip1Bigram) { + SkipGramOp m(2, 1, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick", "The brown", "quick brown", "quick fox", "brown fox", + "brown jumps", "fox jumps", "fox over", "jumps over", "jumps the", + "over the", "over lazy", "the lazy", "the dog!", "lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip2Bigram) { + SkipGramOp m(2, 2, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick", "The brown", "The fox", "quick brown", + "quick fox", "quick jumps", "brown fox", "brown jumps", + "brown over", "fox jumps", "fox over", "fox the", + "jumps over", "jumps the", "jumps lazy", "over the", + "over lazy", "over dog!", "the lazy", "the dog!", + "lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip1Trigram) { + SkipGramOp m(3, 1, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick brown", "The quick fox", "The brown fox", + "The brown jumps", "quick brown fox", "quick brown jumps", + "quick fox jumps", "quick 
fox over", "brown fox jumps", + "brown fox over", "brown jumps over", "brown jumps the", + "fox jumps over", "fox jumps the", "fox over the", + "fox over lazy", "jumps over the", "jumps over lazy", + "jumps the lazy", "jumps the dog!", "over the lazy", + "over the dog!", "over lazy dog!", "the lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip2Trigram) { + SkipGramOp m(3, 2, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick brown", "The quick fox", "The quick jumps", + "The brown fox", "The brown jumps", "The brown over", + "The fox jumps", "The fox over", "The fox the", + "quick brown fox", "quick brown jumps", "quick brown over", + "quick fox jumps", "quick fox over", "quick fox the", + "quick jumps over", "quick jumps the", "quick jumps lazy", + "brown fox jumps", "brown fox over", "brown fox the", + "brown jumps over", "brown jumps the", "brown jumps lazy", + "brown over the", "brown over lazy", "brown over dog!", + "fox jumps over", "fox jumps the", "fox jumps lazy", + "fox over the", "fox over lazy", "fox over dog!", + "fox the lazy", "fox the dog!", "jumps over the", + "jumps over lazy", "jumps over dog!", "jumps the lazy", + "jumps the dog!", "jumps lazy dog!", "over the lazy", + "over the dog!", "over lazy dog!", "the lazy dog!"})); +} + +TEST(SkipGramTest, TestAllSkip2Trigram) { + SkipGramOp m(3, 2, true); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + testing::UnorderedElementsAreArray( + {// Unigram + "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", + "dog!", + // Bigram + "The quick", "The brown", "The fox", "quick brown", "quick fox", + "quick jumps", "brown fox", "brown jumps", "brown over", "fox jumps", + "fox over", "fox the", "jumps over", "jumps the", "jumps lazy", + "over the", "over lazy", "over dog!", "the lazy", "the dog!", + "lazy dog!", + // Trigram + "The quick brown", "The quick fox", "The quick jumps", + "The brown fox", 
"The brown jumps", "The brown over", + "The fox jumps", "The fox over", "The fox the", "quick brown fox", + "quick brown jumps", "quick brown over", "quick fox jumps", + "quick fox over", "quick fox the", "quick jumps over", + "quick jumps the", "quick jumps lazy", "brown fox jumps", + "brown fox over", "brown fox the", "brown jumps over", + "brown jumps the", "brown jumps lazy", "brown over the", + "brown over lazy", "brown over dog!", "fox jumps over", + "fox jumps the", "fox jumps lazy", "fox over the", "fox over lazy", + "fox over dog!", "fox the lazy", "fox the dog!", "jumps over the", + "jumps over lazy", "jumps over dog!", "jumps the lazy", + "jumps the dog!", "jumps lazy dog!", "over the lazy", + "over the dog!", "over lazy dog!", "the lazy dog!"})); +} + +TEST(SkipGramTest, TestSingleWord) { + SkipGramOp m(1, 1, false); + m.SetInput("Hi"); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre("Hi")); +} + +TEST(SkipGramTest, TestWordsLessThanGram) { + SkipGramOp m(3, 1, false); + m.SetInput("Hi hi"); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), std::vector()); +} + +TEST(SkipGramTest, TestEmptyInput) { + SkipGramOp m(1, 1, false); + m.SetInput(""); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre()); +} + +TEST(SkipGramTest, TestWhitespaceInput) { + SkipGramOp m(1, 1, false); + m.SetInput(" "); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre()); +} + +TEST(SkipGramTest, TestInputWithExtraSpace) { + SkipGramOp m(1, 1, false); + m.SetInput(" Hello world ! 
"); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre("Hello", "world", "!")); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c5338ff0fd26337c9adc8e0b94a0a88edfde37f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/softmax_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite SOFTMAX op. 
+ +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +class SoftmaxOpModel : public SingleOpModel { + public: + SoftmaxOpModel(int batches, int size, float beta) + : batches_(batches), input_size_(size), beta_(beta) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, + CreateSoftmaxOptions(builder_, beta_).Union()); + BuildInterpreter({{batches_, input_size_}}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; + + int batches_; + int input_size_; + float beta_; +}; + +TEST(SoftmaxOpTest, SimpleTest) { + SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0); + m.SetInput({ + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0 + }); + + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647, + 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231}, + 1e-6))); +} + +TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) { + const int batch_size = 2; + const int input_size = 5; + const float beta = 1.0; + static float input_buffer[] = { + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1 + }; + + SoftmaxOpModel m(batch_size, input_size, beta); + + m.SetInput(0, input_buffer, input_buffer + input_size * batch_size); + + m.Invoke(); + + std::unique_ptr 
output_buffer(new float[input_size * batch_size]); + static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, + {1, 0, 0, input_size}}; + tflite::reference_ops::Softmax(input_buffer, input_dims, beta, + output_buffer.get(), input_dims); + + std::vector expected; + expected.insert(expected.end(), output_buffer.get(), + output_buffer.get() + input_size * batch_size); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6))); +} + +TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) { + const int batch_size = 2; + const int input_size = 5; + const float beta = 0.5; + static float input_buffer[] = { + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1 + }; + + SoftmaxOpModel m(batch_size, input_size, beta); + + m.SetInput(0, input_buffer, input_buffer + input_size * batch_size); + + m.Invoke(); + + std::unique_ptr output_buffer(new float[input_size * batch_size]); + static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, + {1, 0, 0, input_size}}; + tflite::reference_ops::Softmax(input_buffer, input_dims, beta, + output_buffer.get(), input_dims); + + std::vector expected; + expected.insert(expected.end(), output_buffer.get(), + output_buffer.get() + input_size * batch_size); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb2e509c9811b1469c4d3f676532edff570a6c4a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc @@ -0,0 +1,146 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace space_to_depth { + +// This file has two implementation of SpaceToDepth. Note that SpaceToDepth +// only works on 4D tensors. 
+enum KernelType { + kReference, + kGenericOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + auto data_type = output->type; + TF_LITE_ENSURE(context, + data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 || + data_type == kTfLiteInt32 || data_type == kTfLiteInt64); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + const int block_size = params->block_size; + const int input_height = input->dims->data[1]; + const int input_width = input->dims->data[2]; + int output_height = input_height / block_size; + int output_width = input_width / block_size; + + TF_LITE_ENSURE_EQ(context, input_height, output_height * block_size); + TF_LITE_ENSURE_EQ(context, input_width, output_width * block_size); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = output_height; + output_size->data[2] = output_width; + output_size->data[3] = input->dims->data[3] * block_size * block_size; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + +#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \ + type::SpaceToDepth( \ + GetTensorData(input), GetTensorDims(input), params->block_size, \ + GetTensorData(output), GetTensorDims(output)) + switch (input->type) { // Already know in/out types 
are same. + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, float); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, uint8_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, int64_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, int64_t); + } + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } +#undef TF_LITE_SPACE_TO_DEPTH + + return kTfLiteOk; +} + +} // namespace space_to_depth + +TfLiteRegistration* Register_SPACE_TO_DEPTH_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, space_to_depth::Prepare, + space_to_depth::Eval}; + return &r; +} + +TfLiteRegistration* Register_SPACE_TO_DEPTH_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, space_to_depth::Prepare, + space_to_depth::Eval}; + return &r; +} + +TfLiteRegistration* Register_SPACE_TO_DEPTH() { + return Register_SPACE_TO_DEPTH_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..997f354861a235fb511235e4d64544dc8c3ddb34 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class SpaceToDepthOpModel : public SingleOpModel { + public: + SpaceToDepthOpModel(const TensorData& tensor_data, int block_size) { + input_ = AddInput(tensor_data); + output_ = AddOutput(tensor_data); + SetBuiltinOp(BuiltinOperator_SPACE_TO_DEPTH, + BuiltinOptions_SpaceToDepthOptions, + CreateSpaceToDepthOptions(builder_, block_size).Union()); + BuildInterpreter({GetShape(input_)}); + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + template + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(SpaceToDepthOpModel, BadBlockSize) { + EXPECT_DEATH(SpaceToDepthOpModel({TensorType_FLOAT32, {1, 2, 2, 1}}, 3), + "Cannot allocate tensors"); +} + +TEST(SpaceToDepthOpModel, Float32) { + SpaceToDepthOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, 2); + m.SetInput({1.4, 2.3, 3.2, 4.1, 5.4, 6.3, 7.2, 8.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1.4, 2.3, 3.2, 4.1, 5.4, 6.3, 7.2, 8.1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 8)); +} + +TEST(SpaceToDepthOpModel, Uint8) { 
+ SpaceToDepthOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, 2); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(SpaceToDepthOpModel, Int32) { + SpaceToDepthOpModel m({TensorType_INT32, {1, 2, 2, 3}}, 2); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 12)); +} + +TEST(SpaceToDepthOpModel, Int64) { + SpaceToDepthOpModel m({TensorType_INT64, {1, 4, 4, 1}}, 2); + m.SetInput({1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 2, 2, 4)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc new file mode 100644 index 0000000000000000000000000000000000000000..72f705fe4242b01c1516c99d3500484e8729fd9a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -0,0 +1,222 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace svdf { + +constexpr int kInputTensor = 0; +constexpr int kWeightsFeatureTensor = 1; +constexpr int kWeightsTimeTensor = 2; +constexpr int kBiasTensor = 3; +constexpr int kStateTensor = 0; +constexpr int KOutputTensor = 1; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* scratch_tensor_index = new int; + context->AddTensors(context, 1, scratch_tensor_index); + return scratch_tensor_index; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + int* scratch_tensor_index = reinterpret_cast(node->user_data); + + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* weights_feature = + &context->tensors[node->inputs->data[kWeightsFeatureTensor]]; + TfLiteTensor* weights_time = + &context->tensors[node->inputs->data[kWeightsTimeTensor]]; + + // Check all the parameters of tensor match within themselves and match the + // input configuration. 
+ const int rank = params->rank; + const int batch_size = input->dims->data[0]; + const int num_filters = weights_feature->dims->data[0]; + TF_LITE_ASSERT_EQ(num_filters % rank, 0); + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + TF_LITE_ASSERT_EQ(input->dims->data[1], weights_feature->dims->data[1]); + TF_LITE_ASSERT_EQ(weights_time->dims->data[0], num_filters); + + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + if (bias) { + TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units); + } + + TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[KOutputTensor]]; + + // Resize state. + // For each batch, the state is a 2-D tensor: memory_size * num_filters + // The left most column is used to save current cycle activation. + // The right most column is used to save temporary output which will be + // reduced to num_units outputs. + TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2); + state_size_array->data[0] = batch_size; + state_size_array->data[1] = memory_size * num_filters; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, state, state_size_array)); + + // Mark state as a persistent tensor. + state->allocation_type = kTfLiteArenaRwPersistent; + + // Resize output. + TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); + output_size_array->data[0] = batch_size; + output_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size_array)); + + // Resize scratch. 
+ TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(1); + node->temporaries->data[0] = *scratch_tensor_index; + + TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2); + scratch_size_array->data[0] = batch_size; + scratch_size_array->data[1] = num_filters; + + TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]]; + scratch_tensor->type = input->type; + scratch_tensor->allocation_type = kTfLiteArenaRw; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor, + scratch_size_array)); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* weights_feature = + &context->tensors[node->inputs->data[kWeightsFeatureTensor]]; + TfLiteTensor* weights_time = + &context->tensors[node->inputs->data[kWeightsTimeTensor]]; + + TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[KOutputTensor]]; + TfLiteTensor* scratch = &context->tensors[node->temporaries->data[0]]; + + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + + const int rank = params->rank; + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + const int num_filters = weights_feature->dims->data[0]; + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + + // Clear the activation slot (entry memory_size - 1 of each filter's state). + // TODO(ghodrat): Add a test which initializes state with invalid values in + // that slot and make sure it passes. 
+ for (int b = 0; b < batch_size; b++) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + for (int c = 0; c < num_filters; c++) { + float* state_ptr = state_ptr_batch + c * memory_size; + state_ptr[memory_size - 1] = 0.0; + } + } + + // Compute conv1d(inputs, weights_feature). + // The result is accumulated into entry memory_size - 1 of each filter's state + // (the slot cleared above). This is achieved by starting at + // state->data.f[memory_size - 1] and having the stride equal to memory_size. + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + weights_feature->data.f, num_filters, input_size, input->data.f, + batch_size, &state->data.f[memory_size - 1], memory_size); + + // Compute matmul(state, weights_time). + // The results go to the scratch tensor rather than back into state: + // num_filters values per batch, written contiguously starting at + // scratch->data.f + b * num_filters (result_stride 1). + for (int b = 0; b < batch_size; b++) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::BatchVectorBatchVectorDotProduct( + weights_time->data.f, state_ptr_batch, memory_size, num_filters, + scratch_ptr_batch, /*result_stride=*/1); + } + + // Initialize output with bias if provided. + if (bias) { + tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size, + output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, batch_size * num_units); + } + + // Reduction sum + for (int b = 0; b < batch_size; b++) { + float* output_ptr_batch = output->data.f + b * num_units; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch, + num_units, rank); + } + + // Apply activation. 
+ for (int b = 0; b < batch_size; b++) { + float* output_ptr_batch = output->data.f + b * num_units; + tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units, + params->activation, output_ptr_batch); + } + + // Shift each filter's state with VectorShiftLeft (shift_value 0.0). + for (int b = 0; b < batch_size; b++) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + for (int f = 0; f < num_filters; f++) { + tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size, + /*shift_value=*/0.0); + state_ptr_batch += memory_size; + } + } + return kTfLiteOk; +} + +} // namespace svdf + +TfLiteRegistration* Register_SVDF() { + static TfLiteRegistration r = {svdf::Init, svdf::Free, svdf::Prepare, + svdf::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4de2ceaf053df31a4bc857fb250db416c071e80f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -0,0 +1,312 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite SVDF op. 
+ +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +static float svdf_input[] = { + 0.12609188, -0.46347019, -0.89598465, + 0.35867718, 0.36897406, 0.73463392, + + 0.14278367, -1.64410412, -0.75222826, + -0.57290924, 0.12729003, 0.7567004, + + 0.49837467, 0.19278903, 0.26584083, + 0.17660543, 0.52949083, -0.77931279, + + -0.11186574, 0.13164264, -0.05349274, + -0.72674477, -0.5683046, 0.55900657, + + -0.68892461, 0.37783599, 0.18263303, + -0.63690937, 0.44483393, -0.71817774, + + -0.81299269, -0.86831826, 1.43940818, + -0.95760226, 1.82078898, 0.71135032, + + -1.45006323, -0.82251364, -1.69082689, + -1.65087092, -1.89238167, 1.54172635, + + 0.03966608, -0.24936394, -0.77526885, + 2.06740379, -1.51439476, 1.43768692, + + 0.11771342, -0.23761693, -0.65898693, + 0.31088525, -1.55601168, -0.87661445, + + -0.89477462, 1.67204106, -0.53235275, + -0.6230064, 0.29819036, 1.06939757, +}; + +static float svdf_golden_output_rank_1[] = { + 0.014899, -0.0517661, -0.143725, -0.00271883, + -0.03004015, 0.09565311, 0.1587342, 0.00784263, + + 0.068281, -0.162217, -0.152268, 0.00323521, + 0.01582633, 0.03858774, -0.03001583, -0.02671271, + + -0.0317821, -0.0333089, 0.0609602, 0.0333759, + -0.01432795, 0.05524484, 0.1101355, -0.02382665, + + -0.00623099, -0.077701, -0.391193, -0.0136691, + -0.02333033, 0.02293761, 0.12338032, 0.04326871, + + 0.201551, -0.164607, -0.179462, -0.0592739, + 0.01064911, -0.17503069, 0.07821996, -0.00224009, + + 0.0886511, -0.0875401, -0.269283, 0.0281379, + -0.02282338, 0.09741908, 0.32973239, 0.12281385, + + -0.201174, -0.586145, -0.628624, -0.0330412, + 0.24780814, -0.39304617, -0.22473189, 0.02589256, + + -0.0839096, -0.299329, 0.108746, 0.109808, + 0.10084175, -0.06416984, 
0.28936723, 0.0026358, + + 0.419114, -0.237824, -0.422627, 0.175115, + -0.2314795, -0.18584411, -0.4228974, -0.12928449, + + 0.36726, -0.522303, -0.456502, -0.175475, + 0.17012937, -0.34447709, 0.38505614, -0.28158101, +}; + +static float svdf_golden_output_rank_2[] = { + -0.09623547, -0.10193135, 0.11083051, -0.0347917, + 0.1141196, 0.12965347, -0.12652366, 0.01007236, + + -0.16396809, -0.21247184, 0.11259045, -0.04156673, + 0.10132131, -0.06143532, -0.00924693, 0.10084561, + + 0.01257364, 0.0506071, -0.19287863, -0.07162561, + -0.02033747, 0.22673416, 0.15487903, 0.02525555, + + -0.1411963, -0.37054959, 0.01774767, 0.05867489, + 0.09607603, -0.0141301, -0.08995658, 0.12867066, + + -0.27142537, -0.16955489, 0.18521598, -0.12528358, + 0.00331409, 0.11167502, 0.02218599, -0.07309391, + + 0.09593632, -0.28361851, -0.0773851, 0.17199151, + -0.00075242, 0.33691186, -0.1536046, 0.16572715, + + -0.27916506, -0.27626723, 0.42615682, 0.3225764, + -0.37472126, -0.55655634, -0.05013514, 0.289112, + + -0.24418658, 0.07540751, -0.1940318, -0.08911639, + 0.00732617, 0.46737891, 0.26449674, 0.24888524, + + -0.17225097, -0.54660404, -0.38795233, 0.08389944, + 0.07736043, -0.28260678, 0.15666828, 1.14949894, + + -0.57454878, -0.64704704, 0.73235172, -0.34616736, + 0.21120001, -0.22927976, 0.02455296, -0.35906726, +}; + +// Derived class of SingleOpModel, which is used to test SVDF TFLite op. 
+class SVDFOpModel : public SingleOpModel { + public: + SVDFOpModel(int batches, int units, int input_size, int memory_size, int rank) + : batches_(batches), + units_(units), + input_size_(input_size), + memory_size_(memory_size), + rank_(rank) { + input_ = AddInput(TensorType_FLOAT32); + weights_feature_ = AddInput(TensorType_FLOAT32); + weights_time_ = AddInput(TensorType_FLOAT32); + bias_ = AddNullInput(); + state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, + CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union()); + BuildInterpreter({ + {batches_, input_size_}, // Input tensor + {units_ * rank, input_size_}, // weights_feature tensor + {units_ * rank, memory_size_}, // weights_time tensor + {units_} // bias tensor + }); + } + + // Populates the weights_feature tensor. + void SetWeightsFeature(std::initializer_list f) { + PopulateTensor(weights_feature_, f); + } + + // Populates the weights_time tensor. + void SetWeightsTime(std::initializer_list f) { + PopulateTensor(weights_time_, f); + } + + // Populates the input tensor. + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + // Resets the state of SVDF op by filling it with 0's. + void ResetState() { + const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + // Extracts the output tensor from the SVDF op. 
+ std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + private: + int input_; + int weights_feature_; + int weights_time_; + int bias_; + int state_; + int output_; + + int batches_; + int units_; + int input_size_; + int memory_size_; + int rank_; +}; + +TEST(SVDFOpTest, BlackBoxTestRank1) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/1); + svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, + 0.22197971, 0.12416199, 0.27901134, 0.27557442, + 0.3905206, -0.36137494, -0.06634006, -0.10640851}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); + + svdf.ResetState(); + const int svdf_num_batches = svdf.num_batches(); + const int svdf_input_size = svdf.input_size(); + const int svdf_num_units = svdf.num_units(); + const int input_sequence_size = + sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF op + // and checking the output with the expected golden values. 
+ for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf.SetInput(0, batch_start, batch_end); + + svdf.Invoke(); + + float* golden_start = + svdf_golden_output_rank_1 + i * svdf_num_units * svdf_num_batches; + float* golden_end = golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +TEST(SVDFOpTest, BlackBoxTestRank2) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/2); + svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, + 0.12416199, 0.15785322, 0.27901134, 0.3905206, + 0.21931258, -0.36137494, -0.10640851, 0.31053296, + -0.36118156, -0.0976817, -0.36916667, 0.22197971, + 0.15294972, 0.38031587, 0.27557442, 0.39635518, + -0.21580373, -0.06634006, -0.02702999, 0.27072677}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, + + -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, + 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, + + -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, + 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, + + -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, + -0.0226903, 0.12878349, -0.28122205, -0.10850525, 
-0.11955214, + + 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, + 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); + + svdf.ResetState(); + const int svdf_num_batches = svdf.num_batches(); + const int svdf_input_size = svdf.input_size(); + const int svdf_num_units = svdf.num_units(); + const int input_sequence_size = + sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF op + // and checking the output with the expected golden values. + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf.SetInput(0, batch_start, batch_end); + + svdf.Invoke(); + + float* golden_start = + svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches; + float* golden_end = golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..f716ba8741fd469e7ee405ac300924b53c5c48e5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/test_util.h" + +#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/core/platform/logging.h" + +namespace tflite { + +using ::testing::FloatNear; +using ::testing::Matcher; + +namespace { +template +std::pair QuantizationParams(float f_min, float f_max) { + // These are required by many quantized operations. + CHECK_LE(f_min, 0); + CHECK_GE(f_max, 0); + T q_min = std::numeric_limits::min(); + T q_max = std::numeric_limits::max(); + float range = static_cast<float>(q_max) - q_min; // Avoids int32 overflow. + float scale = (f_max - f_min) / range; + int32_t zero_point = std::min( + q_max, + std::max(q_min, static_cast(std::round(q_min - f_min / scale)))); + return {scale, zero_point}; +} +} // namespace + +std::vector> ArrayFloatNear(const std::vector& values, + float max_abs_error) { + std::vector> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(FloatNear(v, max_abs_error)); + } + return matchers; +} + +int SingleOpModel::AddTensor(TensorData t) { + int id = tensors_.size(); + + // This is slightly different depending on whether we are adding a + // quantized or a regular tensor. 
+ bool is_quantized = (t.min != 0 || t.max != 0 || t.scale != 0); + + flatbuffers::Offset q_params = 0; + + if (is_quantized) { + if (t.min != 0 || t.max != 0) { + if (t.type == TensorType_UINT8) { + std::tie(t.scale, t.zero_point) = + QuantizationParams(t.min, t.max); + } else if (t.type == TensorType_INT32) { + std::tie(t.scale, t.zero_point) = + QuantizationParams(t.min, t.max); + } else { + LOG(FATAL) << "No support for the requested quantized type"; + } + t.min = 0; + t.max = 0; + } + + q_params = CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, builder_.CreateVector({t.scale}), + builder_.CreateVector({t.zero_point})); + } + + tensors_.push_back(CreateTensor(builder_, builder_.CreateVector({}), + t.type, /*buffer=*/0, + /*name=*/0, q_params)); + + tensor_data_[id] = t; + + return id; +} + +int SingleOpModel::AddInput(const TensorData& t) { + int id = AddTensor(t); + inputs_.push_back(id); + return id; +} + +int SingleOpModel::AddNullInput() { + int id = kOptionalTensor; + inputs_.push_back(id); + return id; +} + +int SingleOpModel::AddOutput(const TensorData& t) { + int id = AddTensor(t); + outputs_.push_back(id); + return id; +} + +void SingleOpModel::SetBuiltinOp(BuiltinOperator type, + BuiltinOptions builtin_options_type, + flatbuffers::Offset builtin_options) { + opcodes_.push_back(CreateOperatorCode(builder_, type, 0)); + operators_.push_back(CreateOperator( + builder_, /*opcode_index=*/0, builder_.CreateVector(inputs_), + builder_.CreateVector(outputs_), builtin_options_type, + builtin_options, + /*custom_options=*/0, CustomOptionsFormat_FLEXBUFFERS)); +} + +void SingleOpModel::SetCustomOp( + const string& name, const std::vector& custom_option, + const std::function& registeration) { + custom_registrations_[name] = registeration; + opcodes_.push_back( + CreateOperatorCodeDirect(builder_, BuiltinOperator_CUSTOM, name.data())); + operators_.push_back(CreateOperator( + builder_, /*opcode_index=*/0, builder_.CreateVector(inputs_), + 
builder_.CreateVector(outputs_), BuiltinOptions_NONE, 0, + builder_.CreateVector(custom_option), + CustomOptionsFormat_FLEXBUFFERS)); +} + +void SingleOpModel::BuildInterpreter( + std::vector> input_shapes) { + auto opcodes = builder_.CreateVector(opcodes_); + auto operators = builder_.CreateVector(operators_); + auto tensors = builder_.CreateVector(tensors_); + auto inputs = builder_.CreateVector(inputs_); + auto outputs = builder_.CreateVector(outputs_); + // Create a single subgraph + std::vector> subgraphs; + auto subgraph = CreateSubGraph(builder_, tensors, inputs, outputs, operators); + subgraphs.push_back(subgraph); + auto subgraphs_flatbuffer = builder_.CreateVector(subgraphs); + + std::vector> buffers_vec; + auto buffers = builder_.CreateVector(buffers_vec); + auto description = builder_.CreateString("programmatic model"); + builder_.Finish(CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, + subgraphs_flatbuffer, description, buffers)); + + auto* model = GetModel(builder_.GetBufferPointer()); + + ops::builtin::BuiltinOpResolver builtins; + for (const auto& reg : custom_registrations_) { + builtins.AddCustom(reg.first.data(), reg.second()); + } + InterpreterBuilder(model, builtins)(&interpreter_); + + CHECK(interpreter_ != nullptr); + + int i = 0; + for (const auto& shape : input_shapes) { + int input_idx = interpreter_->inputs()[i++]; + if (input_idx == kOptionalTensor) continue; + CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk); + } + CHECK(interpreter_->AllocateTensors() == kTfLiteOk) + << "Cannot allocate tensors"; +} + +void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); } + +int32_t SingleOpModel::GetTensorSize(int index) const { + TfLiteTensor* t = interpreter_->tensor(index); + CHECK(t); + int total_size = 1; + for (int i = 0; i < t->dims->size; ++i) { + total_size *= t->dims->data[i]; + } + return total_size; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/test_util.h 
b/tensorflow/contrib/lite/kernels/test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..adcdeddbfc9d3b3313b09cd6310171160e0be645 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -0,0 +1,197 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ + +#include + +#include +#include + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/core/platform/logging.h" + +namespace tflite { + +// A gmock matcher that checks that elements of a float vector match to +// within a given tolerance. 
+std::vector<::testing::Matcher> ArrayFloatNear( + const std::vector& values, float max_abs_error = 1e-5); + +template +inline std::vector Quantize(const std::vector& data, float scale, + int32_t zero_point) { + std::vector q; + for (float f : data) { + q.push_back(std::max( + std::numeric_limits::min(), + std::min(std::numeric_limits::max(), + static_cast(std::round(zero_point + (f / scale)))))); + } + return q; +} + +template +inline std::vector Dequantize(const std::vector& data, float scale, + int32_t zero_point) { + std::vector f; + for (T q : data) { + f.push_back(scale * (q - zero_point)); + } + return f; +} + +// A test model that contains a single operator. All operator inputs and +// outputs are external to the model, so the tests can directly access them. +// Typical usage: +// SingleOpModel m; +// int a = m.AddInput({TensorType_FLOAT32, a_shape}); +// int b = m.AddInput({TensorType_FLOAT32, b_shape}); +// int c = m.AddOutput({TensorType_FLOAT32, {}}); +// m.SetBuiltinOp(...); +// m.BuildInterpreter({GetShape(a), GetShape(b)}); +// m.PopulateTensor(a, {...}); +// m.PopulateTensor(b, {...}); +// m.Invoke(); +// EXPECT_THAT(m.ExtractVector(c), ArrayFloatNear({...})); +// + +// A helper struct to construct test tensors. This is particularly useful for +// quantized tensors, which must have their scale and zero_point defined before +// the actual data is known. This mimics what happens in practice: quantization +// parameters are calculated during training. +struct TensorData { + TensorType type; + std::vector shape; + float min; + float max; + float scale; + int32_t zero_point; +}; + +class SingleOpModel { + public: + SingleOpModel() {} + ~SingleOpModel() {} + + // Copying or assignment is disallowed to simplify ownership semantics. + SingleOpModel(const SingleOpModel&) = delete; + SingleOpModel& operator=(const SingleOpModel&) = delete; + + // Add a TensorType input tensor and return its index. 
+ int AddInput(TensorType type) { return AddInput(TensorData{type}); } + int AddInput(const TensorData& t); + + // Add a null input tensor (optional input) and return kOptionalTensor. + int AddNullInput(); + + // Add a TensorType output tensor and return its index. + int AddOutput(TensorType type) { return AddOutput(TensorData{type}); } + int AddOutput(const TensorData& t); + + template + void QuantizeAndPopulate(int index, std::initializer_list data) { + TfLiteTensor* t = interpreter_->tensor(index); + auto q = Quantize(data, t->params.scale, t->params.zero_point); + PopulateTensor(index, 0, q.data(), q.data() + q.size()); + } + + const std::vector& GetShape(int id) { return tensor_data_.at(id).shape; } + + float GetScale(int id) { return tensor_data_.at(id).scale; } + int32_t GetZeroPoint(int id) { return tensor_data_.at(id).zero_point; } + + // Define the operator in this model. + void SetBuiltinOp(BuiltinOperator type, BuiltinOptions builtin_options_type, + flatbuffers::Offset builtin_options); + void SetCustomOp(const string& name, + const std::vector& custom_option, + const std::function& registeration); + + // Build the interpreter for this model. Also, resize and allocate all + // tensors given the shapes of the inputs. + void BuildInterpreter(std::vector> input_shapes); + + void Invoke(); + + void PopulateStringTensor(int index, const std::vector& content) { + auto tensor = interpreter_->tensor(index); + DynamicBuffer buf; + for (const string& s : content) { + buf.AddString(s.data(), s.length()); + } + buf.WriteToTensor(tensor); + } + + // Populate the tensor given its index. + template + void PopulateTensor(int index, std::initializer_list data) { + T* v = interpreter_->typed_tensor(index); + CHECK(v) << "No tensor with index '" << index << "'."; + for (T f : data) { + *v = f; + ++v; + } + } + + // Partially populate the tensor, starting at the given offset. 
+ template + void PopulateTensor(int index, int offset, T* begin, T* end) { + T* v = interpreter_->typed_tensor(index); + memcpy(v + offset, begin, (end - begin) * sizeof(T)); + } + + // Return a vector with the flattened contents of a tensor. + template + std::vector ExtractVector(int index) { + T* v = interpreter_->typed_tensor(index); + CHECK(v); + return std::vector(v, v + GetTensorSize(index)); + } + + std::vector GetTensorShape(int index) { + std::vector result; + TfLiteTensor* t = interpreter_->tensor(index); + for (int i = 0; i < t->dims->size; ++i) { + result.push_back(t->dims->data[i]); + } + return result; + } + + protected: + int32_t GetTensorSize(int index) const; + + flatbuffers::FlatBufferBuilder builder_; + std::unique_ptr interpreter_; + + private: + int AddTensor(TensorData t); + + std::map tensor_data_; + std::vector inputs_; + std::vector outputs_; + std::vector> tensors_; + std::vector> opcodes_; + std::vector> operators_; + std::map> custom_registrations_; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc new file mode 100644 index 0000000000000000000000000000000000000000..de7a39b62c87bcf09ef577998b98f4fbcdb5b019 --- /dev/null +++ b/tensorflow/contrib/lite/model.cc @@ -0,0 +1,736 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/nnapi_delegate.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { + +namespace { +inline const tflite::Model* VerifyAndGetModel(const void* buf, size_t len) { + ::flatbuffers::Verifier verifier(static_cast(buf), len); + if (VerifyModelBuffer(verifier)) { + return ::tflite::GetModel(buf); + } else { + return nullptr; + } +} +} // namespace + +const char* kEmptyTensorName = ""; + +std::unique_ptr FlatBufferModel::BuildFromFile( + const char* filename, ErrorReporter* error_reporter) { + std::unique_ptr model; + model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter, + /*use_nnapi=*/true)); + if (!model->initialized()) model.reset(); + return model; +} + +std::unique_ptr FlatBufferModel::BuildFromBuffer( + const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) { + std::unique_ptr model; + model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + +std::unique_ptr FlatBufferModel::BuildFromModel( + const tflite::Model* model_spec, ErrorReporter* error_reporter) { + std::unique_ptr model; + model.reset(new FlatBufferModel(model_spec, error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + +FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file, + ErrorReporter* error_reporter, bool use_nnapi) + : error_reporter_(error_reporter ? 
error_reporter + : DefaultErrorReporter()) { + if (mmap_file) { // Allocations get error_reporter_, which is never null. + if (use_nnapi && NNAPIExists()) + allocation_ = new NNAPIAllocation(filename, error_reporter_); + else + allocation_ = new MMAPAllocation(filename, error_reporter_); + } else { + allocation_ = new FileCopyAllocation(filename, error_reporter_); + } + if (!allocation_->valid()) return; + if (!CheckModelIdentifier()) return; + + model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); +} + +bool FlatBufferModel::CheckModelIdentifier() const { + if (!tflite::ModelBufferHasIdentifier(allocation_->base())) { + const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base()); + error_reporter_->Report( + "Model provided has model identifier '%c%c%c%c', should be '%s'\n", + ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier()); + return false; + } + return true; +} + +FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes, + ErrorReporter* error_reporter) + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter_); + if (!allocation_->valid()) return; + + model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); +} + +FlatBufferModel::FlatBufferModel(const Model* model, + ErrorReporter* error_reporter) + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + model_ = model; +} + +FlatBufferModel::~FlatBufferModel() { delete allocation_; } + +InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model, + const OpResolver& op_resolver) + : model_(model.GetModel()), + op_resolver_(op_resolver), + error_reporter_(model.error_reporter()), + allocation_(model.allocation()) {} + +InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model, + const OpResolver& op_resolver, + ErrorReporter* error_reporter) + : model_(model), + op_resolver_(op_resolver), + error_reporter_(error_reporter ? 
error_reporter + : DefaultErrorReporter()) {} + +TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { + TfLiteStatus status = kTfLiteOk; + auto opcodes = model_->operator_codes(); + for (const OperatorCode* opcode : *opcodes) { + TfLiteRegistration* registration = nullptr; + + if (opcode->builtin_code() != BuiltinOperator_CUSTOM) { + auto x = opcode->builtin_code(); + flatbuffer_op_index_to_registration_types_.push_back(x); + registration = op_resolver_.FindOp(x); + if (registration == nullptr) { + error_reporter_->Report("Didn't find op for builtin opcode '%s'\n", + EnumNameBuiltinOperator(x)); + status = kTfLiteError; + } + } else if (!opcode->custom_code()) { + error_reporter_->Report( + "Operator with builtin_code==0 has no custom_code.\n"); + status = kTfLiteError; + } else { + const char* name = opcode->custom_code()->c_str(); + registration = op_resolver_.FindOp(name); + flatbuffer_op_index_to_registration_types_.push_back( + BuiltinOperator_CUSTOM); + if (registration == nullptr) { + error_reporter_->Report("Didn't find custom op for name '%s'\n", name); + status = kTfLiteError; + } + } + flatbuffer_op_index_to_registration_.push_back(registration); + } + return status; +} + +namespace { +template +std::vector FlatBufferIntArrayToVector(T* flat_array) { + std::vector ret(flat_array->Length()); + for (int i = 0; i < flat_array->Length(); i++) { + ret[i] = flat_array->Get(i); + } + return ret; +} + +// Copies the contents of the flatbuffer int vector `flat_vector` into the int +// array `buffer`. Reports an error and copies nothing if `flat_vector` is null +// or holds more than `max_size_of_buffer / sizeof(int)` elements. 
+void FlatBufferIntVectorToArray(int max_size_of_buffer, + const flatbuffers::Vector* flat_vector, + int* buffer, ErrorReporter* error_reporter) { + if (!flat_vector) { + error_reporter->Report("Input array not provided for operation.\n"); + } else { + int num_dimensions = flat_vector->Length(); + if (num_dimensions > max_size_of_buffer / sizeof(int)) { + error_reporter->Report( + "Found too many dimensions in the operation's input array.\n"); + } else { + for (int i = 0; i < num_dimensions; ++i) { + buffer[i] = flat_vector->Get(i); + } + } + } +} + +// Allocate a structure using C malloc, but make sure the structure is a +// POD structure that doesn't require constructors to run. The reason we do +// this, is that Interpreter's C extension part will take ownership and wants +// to use malloc() and free(). +template +T* MallocPOD() { + static_assert(std::is_pod::value, "Builtin data structure must be POD."); + return static_cast(malloc(sizeof(T))); +} + +// Parse the appropriate data out of the op. +// +// This handles builtin data explicitly as there are flatbuffer schemas. +// +// Returns memory that must be freed. +// +// TODO(nupurgarg): Pass in void ** and return TfLiteStatus to ensure program +// crashes if error reporter is called. 
+void* ParseOpData(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter) { + auto parse_padding = [](Padding padding) { + switch (padding) { + case Padding_SAME: + return kTfLitePaddingSame; + case Padding_VALID: + return kTfLitePaddingValid; + } + return kTfLitePaddingUnknown; + }; + auto parse_activation = [](ActivationFunctionType activation) { + switch (activation) { + case ActivationFunctionType_NONE: + return kTfLiteActNone; + case ActivationFunctionType_RELU: + return kTfLiteActRelu; + case ActivationFunctionType_RELU1: + return kTfLiteActRelu1; + case ActivationFunctionType_RELU6: + return kTfLiteActRelu6; + case ActivationFunctionType_TANH: + return kTfLiteActTanh; + case ActivationFunctionType_SIGN_BIT: + return kTfLiteActSignBit; + } + return kTfLiteActNone; + }; + auto parseLSHProjectionType = [](LSHProjectionType type) { + switch (type) { + case LSHProjectionType_SPARSE: + return kTfLiteLshProjectionSparse; + case LSHProjectionType_DENSE: + return kTfLiteLshProjectionDense; + default: + return kTfLiteLshProjectionUnknown; + } + }; + auto parseCombinerType = [](CombinerType type) { + switch (type) { + case CombinerType_MEAN: + return kTfLiteCombinerTypeMean; + case CombinerType_SQRTN: + return kTfLiteCombinerTypeSqrtn; + case CombinerType_SUM: + default: + return kTfLiteCombinerTypeSum; + } + }; + + void* builtin_data = nullptr; + switch (op_type) { + case BuiltinOperator_CALL: + // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are + // ok for now, since there is no call implementation either. 
+ break; + case BuiltinOperator_CUSTOM: + break; + case BuiltinOperator_CONV_2D: { + TfLiteConvParams* params = MallocPOD(); + if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_TANH: + case BuiltinOperator_LOGISTIC: + case BuiltinOperator_RELU: + case BuiltinOperator_RELU1: + case BuiltinOperator_RELU6: + case BuiltinOperator_CONCAT_EMBEDDINGS: + break; + case BuiltinOperator_LSH_PROJECTION: { + TfLiteLSHProjectionParams* params = + MallocPOD(); + if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) { + params->type = parseLSHProjectionType(lshParams->type()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_AVERAGE_POOL_2D: + case BuiltinOperator_MAX_POOL_2D: + case BuiltinOperator_L2_POOL_2D: { + TfLitePoolParams* params = MallocPOD(); + if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) { + params->padding = parse_padding(pool_params->padding()); + params->stride_width = pool_params->stride_w(); + params->stride_height = pool_params->stride_h(); + params->filter_width = pool_params->filter_width(); + params->filter_height = pool_params->filter_height(); + params->activation = + parse_activation(pool_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_DEPTHWISE_CONV_2D: { + TfLiteDepthwiseConvParams* params = + MallocPOD(); + if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->depth_multiplier = 
conv_params->depth_multiplier(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SVDF: { + TfLiteSVDFParams* params = MallocPOD(); + if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) { + params->rank = svdf_params->rank(); + params->activation = + parse_activation(svdf_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_RNN: { + TfLiteRNNParams* params = MallocPOD(); + if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { + params->activation = + parse_activation(rnn_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_EMBEDDING_LOOKUP: + // no-op. + break; + case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { + TfLiteEmbeddingLookupSparseParams* params = + MallocPOD(); + if (auto* embedding_params = + op->builtin_options_as_EmbeddingLookupSparseOptions()) { + params->combiner = parseCombinerType(embedding_params->combiner()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_FULLY_CONNECTED: { + TfLiteFullyConnectedParams* params = + MallocPOD(); + if (auto* fully_connected_params = + op->builtin_options_as_FullyConnectedOptions()) { + params->activation = parse_activation( + fully_connected_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_HASHTABLE_LOOKUP: + // no-op. 
+ break; + case BuiltinOperator_SOFTMAX: { + TfLiteSoftmaxParams* params = MallocPOD(); + if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) { + params->beta = softmax_params->beta(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_CONCATENATION: { + TfLiteConcatenationParams* params = + MallocPOD(); + if (auto* concatenation_params = + op->builtin_options_as_ConcatenationOptions()) { + params->activation = + parse_activation(concatenation_params->fused_activation_function()); + params->axis = concatenation_params->axis(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_MUL: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_MulOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_ADD: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_AddOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_L2_NORMALIZATION: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { + auto* params = MallocPOD(); + if (auto* schema_params = + op->builtin_options_as_LocalResponseNormalizationOptions()) { + params->radius = schema_params->radius(); + params->bias = schema_params->bias(); + params->alpha = schema_params->alpha(); + params->beta = schema_params->beta(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_LSTM: { + TfLiteLSTMParams* params = MallocPOD(); + if (auto* lstm_params = 
op->builtin_options_as_LSTMOptions()) { + params->activation = + parse_activation(lstm_params->fused_activation_function()); + params->cell_clip = lstm_params->cell_clip(); + params->proj_clip = lstm_params->proj_clip(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_RESIZE_BILINEAR: { + auto* params = MallocPOD(); + if (auto* schema_params = + op->builtin_options_as_ResizeBilinearOptions()) { + params->new_height = schema_params->new_height(); + params->new_width = schema_params->new_width(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_PAD: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_PadOptions()) { + auto* before_padding = schema_params->before_padding(); + FlatBufferIntVectorToArray(sizeof(params->before_padding), + before_padding, params->before_padding, + error_reporter); + + auto* after_padding = schema_params->after_padding(); + FlatBufferIntVectorToArray(sizeof(params->after_padding), after_padding, + params->after_padding, error_reporter); + + if (before_padding->Length() != after_padding->Length()) { + error_reporter->Report( + "Before padding and after padding arrays need to contain the " + "same number of dimensions.\n"); + } + params->num_dimensions = after_padding->Length(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_RESHAPE: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { + auto* new_shape = schema_params->new_shape(); + FlatBufferIntVectorToArray(sizeof(params->shape), new_shape, + params->shape, error_reporter); + params->num_dimensions = new_shape->Length(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SKIP_GRAM: { + TfLiteSkipGramParams* params = MallocPOD(); + if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) { + params->ngram_size = skip_gram_params->ngram_size(); + 
params->max_skip_size = skip_gram_params->max_skip_size(); + params->include_all_ngrams = skip_gram_params->include_all_ngrams(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SPACE_TO_DEPTH: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) { + params->block_size = schema_params->block_size(); + } + builtin_data = reinterpret_cast(params); + break; + } + } + return builtin_data; +} + +} // namespace + +TfLiteStatus InterpreterBuilder::ParseNodes( + const flatbuffers::Vector>* operators, + Interpreter* interpreter) { + TfLiteStatus status = kTfLiteOk; + for (int i = 0; i < operators->Length(); ++i) { + const auto* op = operators->Get(i); + int index = op->opcode_index(); + if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) { + error_reporter_->Report("Missing registration for opcode_index %d\n", + index); + status = kTfLiteError; + continue; + } + const TfLiteRegistration* reg = + flatbuffer_op_index_to_registration_[op->opcode_index()]; + if (reg == nullptr) { + error_reporter_->Report("Skipping op for opcode_index %d\n", index); + status = kTfLiteError; + continue; + } + + auto op_type = + flatbuffer_op_index_to_registration_types_[op->opcode_index()]; + if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) { + error_reporter_->Report( + "Found builtin operator %s with custom options.\n", + EnumNameBuiltinOperator(op_type)); + } + if (op->custom_options()) { + interpreter->AddNodeWithParameters( + FlatBufferIntArrayToVector(op->inputs()), + FlatBufferIntArrayToVector(op->outputs()), + reinterpret_cast(op->custom_options()->data()), + op->custom_options()->size(), nullptr, reg); + } else { + interpreter->AddNodeWithParameters( + FlatBufferIntArrayToVector(op->inputs()), + FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, + ParseOpData(op, op_type, error_reporter_), reg); + } + } + + return status; +} + +TfLiteStatus 
InterpreterBuilder::ParseTensors( + const flatbuffers::Vector>* buffers, + const flatbuffers::Vector>* tensors, + Interpreter* interpreter) { + TfLiteStatus status = kTfLiteOk; + + // A little helper to get the names of inputs and outputs. Note that they + // must outlive the interpreter. + auto get_name = [](const tflite::Tensor* t) -> const char* { + auto name = t->name(); + if (name) return name->c_str(); + return kEmptyTensorName; + }; + + for (int i = 0; i < tensors->Length(); ++i) { + const auto* tensor = tensors->Get(i); + std::vector dims = FlatBufferIntArrayToVector(tensor->shape()); + + TfLiteQuantizationParams quantization; + quantization.scale = 0; + quantization.zero_point = 0; + auto* q_params = tensor->quantization(); + if (q_params) { + // Note that the schema could hold per-channel quantization parameters + // but we really only support one value for the whole tensor. + // TODO(aselle): This breaks as well if these are nullptr's. + // TODO(aselle): This assumes non per-channel quantization. + if (q_params->scale()) quantization.scale = q_params->scale()->Get(0); + if (q_params->zero_point()) + quantization.zero_point = q_params->zero_point()->Get(0); + } + + TfLiteType type; + switch (tensor->type()) { + case TensorType_FLOAT32: + type = kTfLiteFloat32; + break; + case TensorType_INT32: + type = kTfLiteInt32; + break; + case TensorType_UINT8: + type = kTfLiteUInt8; + break; + case TensorType_INT64: + type = kTfLiteInt64; + break; + case TensorType_STRING: + type = kTfLiteString; + break; + default: + // tensorType = ArrayType::NONE; + error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n", + EnumNameTensorType(tensor->type()), + tensor->type()); + status = kTfLiteError; + continue; + } + auto get_readonly_data = [&](const char** buffer_data, + size_t* buffer_size) { + // TODO(aselle): Check what happens if we have an unspecified size + // constant. 
+ *buffer_data = nullptr; + if (tensor->buffer() == 0) return kTfLiteOk; + if (tensor->buffer() >= buffers->size()) { + error_reporter_->Report( + "Tensor %d specifies out of range buffer %d (only %d buffers).\n", + i, tensor->buffer(), buffers->size()); + return kTfLiteError; + } + if (auto* buffer = (*buffers)[tensor->buffer()]) { + if (auto* array = buffer->data()) { + if (size_t size = array->size()) { + *buffer_size = size; + *buffer_data = reinterpret_cast(array->data()); + return kTfLiteOk; + } + } + } + return kTfLiteOk; + }; + size_t buffer_size = 0; + const char* buffer_ptr; + TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size)); + + if (buffer_ptr) { + if (interpreter->SetTensorParametersReadOnly( + i, type, get_name(tensor), dims, quantization, buffer_ptr, + buffer_size, allocation_) != kTfLiteOk) { + error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", + i); + status = kTfLiteError; + } + } else { + if (interpreter->SetTensorParametersReadWrite( + i, type, get_name(tensor), dims, quantization) != kTfLiteOk) { + error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", + i); + status = kTfLiteError; + } + } + } + + return status; +} + +TfLiteStatus InterpreterBuilder::operator()( + std::unique_ptr* interpreter) { + if (!interpreter) { + error_reporter_->Report( + "Null output pointer passed to InterpreterBuilder."); + return kTfLiteError; + } + + // Safe exit by deleting partially created interpreter, to reduce verbosity + // on error conditions. 
Use by return cleanup_on_error(); + auto cleanup_and_error = [&interpreter]() { + interpreter->reset(); + return kTfLiteError; + }; + + if (!model_) { + error_reporter_->Report("Null pointer passed in as model."); + return cleanup_and_error(); + } + + if (model_->version() != TFLITE_SCHEMA_VERSION) { + error_reporter_->Report( + "Model provided is schema version %d not equal " + "to supported version %d.\n", + model_->version(), TFLITE_SCHEMA_VERSION); + return cleanup_and_error(); + } + + if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) { + error_reporter_->Report("Registration failed.\n"); + return cleanup_and_error(); + } + + // Flatbuffer model schemas define a list of opcodes independent of the graph. + // We first map those to registrations. This reduces string lookups for custom + // ops since we only do it once per custom op rather than once per custom op + // invocation in the model graph. + // Construct interpreter with correct number of tensors and operators. + auto* subgraphs = model_->subgraphs(); + auto* buffers = model_->buffers(); + if (subgraphs->size() != 1) { + error_reporter_->Report("Only 1 subgraph is currently supported.\n"); + return cleanup_and_error(); + } + const tflite::SubGraph* subgraph = (*subgraphs)[0]; + auto operators = subgraph->operators(); + auto tensors = subgraph->tensors(); + if (!operators || !tensors || !buffers) { + error_reporter_->Report( + "Did not get operators, tensors, or buffers in input flat buffer.\n"); + return cleanup_and_error(); + } + interpreter->reset(new Interpreter(error_reporter_)); + if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) { + return cleanup_and_error(); + } + + // Parse inputs/outputs + (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs())); + (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs())); + + // Finally setup nodes and tensors + if (ParseNodes(operators, interpreter->get()) != kTfLiteOk) + return cleanup_and_error(); + if 
(ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk) + return cleanup_and_error(); + + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h new file mode 100644 index 0000000000000000000000000000000000000000..e0c96f7f0480cd3146f95a22957477809cf0096d --- /dev/null +++ b/tensorflow/contrib/lite/model.h @@ -0,0 +1,176 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Deserialization infrastructure for tflite. Provides functionality +// to go from a serialized tflite model in flatbuffer format to an +// interpreter. +// +// using namespace tflite; +// StderrReporter error_reporter; +// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite", +// &error_reporter); +// MyOpResolver resolver; // You need to subclass OpResolver to provide +// // implementations. +// InterpreterBuilder builder(*model, resolver); +// std::unique_ptr interpreter; +// if(builder(&interpreter) == kTfLiteOk) { +// .. run model inference with interpreter +// } +// +// OpResolver must be defined to provide your kernel implementations to the +// interpreter. This is environment specific and may consist of just the builtin +// ops, or some custom operators you defined to extend tflite. 
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ + +#include +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { + +// An RAII object that represents a read-only tflite model, copied from disk, +// or mmapped. This uses flatbuffers as the serialization format. +class FlatBufferModel { + public: + // Builds a model based on a file. Returns a nullptr in case of failure. + static std::unique_ptr BuildFromFile( + const char* filename, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + // Builds a model based on a pre-loaded flatbuffer. The caller retains + // ownership of the buffer and should keep it alive until the returned object + // is destroyed. Returns a nullptr in case of failure. + static std::unique_ptr BuildFromBuffer( + const char* buffer, size_t buffer_size, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + // Builds a model directly from a flatbuffer pointer. The caller retains + // ownership of the buffer and should keep it alive until the returned object + // is destroyed. Returns a nullptr in case of failure. + static std::unique_ptr BuildFromModel( + const tflite::Model* model_spec, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + // Releases memory or unmaps mmaped meory. + ~FlatBufferModel(); + + // Copying or assignment is disallowed to simplify ownership semantics. 
+ FlatBufferModel(const FlatBufferModel&) = delete; + FlatBufferModel& operator=(const FlatBufferModel&) = delete; + + bool initialized() const { return model_ != nullptr; } + const tflite::Model* operator->() const { return model_; } + const tflite::Model* GetModel() const { return model_; } + ErrorReporter* error_reporter() const { return error_reporter_; } + const Allocation* allocation() const { return allocation_; } + + // Returns true if the model identifier is correct (otherwise false and + // reports an error). + bool CheckModelIdentifier() const; + + private: + // Loads a model from `filename`. If `mmap_file` is true then use mmap, + // otherwise make a copy of the model in a buffer. + // + // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be + // used. + explicit FlatBufferModel( + const char* filename, bool mmap_file = true, + ErrorReporter* error_reporter = DefaultErrorReporter(), + bool use_nnapi = false); + + // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has + // to remain alive and unchanged until the end of this flatbuffermodel's + // lifetime. + // + // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be + // used. + FlatBufferModel(const char* ptr, size_t num_bytes, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + // Loads a model from Model flatbuffer. The `model` has to remain alive and + // unchanged until the end of this flatbuffermodel's lifetime. + FlatBufferModel(const Model* model, ErrorReporter* error_reporter); + + // Flatbuffer traverser pointer. (Model* is a pointer that is within the + // allocated memory of the data allocated by allocation's internals. + const tflite::Model* model_ = nullptr; + ErrorReporter* error_reporter_; + Allocation* allocation_ = nullptr; +}; + +// Abstract interface that returns TfLiteRegistrations given op codes or custom +// op names. 
This is the mechanism that ops being referenced in the flatbuffer +// model are mapped to executable function pointers (TfLiteRegistrations). +class OpResolver { + public: + // Finds the op registration for a builtin operator by enum code. + virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; + // Finds the op registration of a custom operator by op name. + virtual TfLiteRegistration* FindOp(const char* op) const = 0; + virtual ~OpResolver() {} +}; + +// Build an interpreter capable of interpreting `model`. +// +// model: a scoped model whose lifetime must be at least as long as +// the interpreter. In principle multiple interpreters can be made from +// a single model. +// op_resolver: An instance that implements the Resolver interface which maps +// custom op names and builtin op codes to op registrations. +// reportError: a functor that is called to report errors that handles +// printf var arg semantics. The lifetime of the reportError object must +// be greater than or equal to the Interpreter created by operator(). +// +// Returns a kTfLiteOk when successful and sets interpreter to a valid +// Interpreter. Note: the user must ensure the model lifetime is at least as +// long as interpreter's lifetime. +class InterpreterBuilder { + public: + InterpreterBuilder(const FlatBufferModel& model, + const OpResolver& op_resolver); + // Builds an interpreter given only the raw flatbuffer Model object (instead + // of a FlatBufferModel). Mostly used for testing. + // If `error_reporter` is null, then DefaultErrorReporter() is used. 
+ InterpreterBuilder(const ::tflite::Model* model, + const OpResolver& op_resolver, + ErrorReporter* error_reporter = DefaultErrorReporter()); + InterpreterBuilder(const InterpreterBuilder&) = delete; + InterpreterBuilder& operator=(const InterpreterBuilder&) = delete; + TfLiteStatus operator()(std::unique_ptr* interpreter); + + private: + TfLiteStatus BuildLocalIndexToRegistrationMapping(); + TfLiteStatus ParseNodes( + const flatbuffers::Vector>* operators, + Interpreter* interpreter); + TfLiteStatus ParseTensors( + const flatbuffers::Vector>* buffers, + const flatbuffers::Vector>* tensors, + Interpreter* interpreter); + + const ::tflite::Model* model_; + const OpResolver& op_resolver_; + ErrorReporter* error_reporter_; + + std::vector flatbuffer_op_index_to_registration_; + std::vector flatbuffer_op_index_to_registration_types_; + const Allocation* allocation_ = nullptr; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5330c8f594593655b2a8776cf6b399c0d16cdc19 --- /dev/null +++ b/tensorflow/contrib/lite/model_test.cc @@ -0,0 +1,290 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/model.h" + +#include +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/testing/util.h" + +// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object, +// we must declare this in global namespace, so argument-dependent operator +// lookup works. +inline bool operator==(const TfLiteRegistration& a, + const TfLiteRegistration& b) { + return a.invoke == b.invoke && a.init == b.init && a.prepare == b.prepare && + a.free == b.free; +} + +namespace tflite { + +// Provide a dummy operation that does nothing. +namespace { +void* dummy_init(TfLiteContext*, const char*, size_t) { return nullptr; } +void dummy_free(TfLiteContext*, void*) {} +TfLiteStatus dummy_resize(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; } +TfLiteStatus dummy_invoke(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; } +TfLiteRegistration dummy_reg = {dummy_init, dummy_free, dummy_resize, + dummy_invoke}; +} // namespace + +// Provide a trivial resolver that returns a constant value no matter what +// op is asked for. +class TrivialResolver : public OpResolver { + public: + explicit TrivialResolver(TfLiteRegistration* constant_return = nullptr) + : constant_return_(constant_return) {} + // Find the op registration of a custom operator by op name. + TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override { + return constant_return_; + } + // Find the op registration of a custom operator by op name. 
+ TfLiteRegistration* FindOp(const char* op) const override { + return constant_return_; + } + + private: + TfLiteRegistration* constant_return_; +}; + +TEST(BasicFlatBufferModel, TestNonExistantFiles) { + ASSERT_TRUE(!FlatBufferModel::BuildFromFile("/tmp/tflite_model_1234")); +} + +// Make sure a model with nothing in it loads properly. +TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/empty_model.bin"); + ASSERT_TRUE(model); + // Now try to build it into a model. + std::unique_ptr interpreter; + ASSERT_EQ(InterpreterBuilder(*model, TrivialResolver())(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); + ASSERT_NE(InterpreterBuilder(*model, TrivialResolver())(nullptr), kTfLiteOk); +} + +// Make sure currently unsupported # of subgraphs are checked +// TODO(aselle): Replace this test when multiple subgraphs are supported. +TEST(BasicFlatBufferModel, TestZeroAndMultipleSubgraphs) { + auto m1 = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/0_subgraphs.bin"); + ASSERT_TRUE(m1); + std::unique_ptr interpreter1; + ASSERT_NE(InterpreterBuilder(*m1, TrivialResolver())(&interpreter1), + kTfLiteOk); + + auto m2 = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/2_subgraphs.bin"); + ASSERT_TRUE(m2); + std::unique_ptr interpreter2; + ASSERT_NE(InterpreterBuilder(*m2, TrivialResolver())(&interpreter2), + kTfLiteOk); +} + +// Test what happens if we cannot bind any of the ops. +TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/test_model.bin"); + ASSERT_TRUE(model); + // Check that we get an error code and interpreter pointer is reset. 
+ std::unique_ptr interpreter(new Interpreter); + ASSERT_NE(InterpreterBuilder(*model, TrivialResolver(nullptr))(&interpreter), + kTfLiteOk); + ASSERT_EQ(interpreter, nullptr); +} + +// Make sure model is read to interpreter propelrly +TEST(BasicFlatBufferModel, TestModelInInterpreter) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/test_model.bin"); + ASSERT_TRUE(model); + // Check that we get an error code and interpreter pointer is reset. + std::unique_ptr interpreter(new Interpreter); + ASSERT_EQ( + InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(interpreter->tensors_size(), 4); + ASSERT_EQ(interpreter->nodes_size(), 2); + std::vector inputs = {0, 1}; + std::vector outputs = {2, 3}; + ASSERT_EQ(interpreter->inputs(), inputs); + ASSERT_EQ(interpreter->outputs(), outputs); + + EXPECT_EQ(std::string(interpreter->GetInputName(0)), "input0"); + EXPECT_EQ(std::string(interpreter->GetInputName(1)), "input1"); + EXPECT_EQ(std::string(interpreter->GetOutputName(0)), "out1"); + EXPECT_EQ(std::string(interpreter->GetOutputName(1)), "out2"); + + // Make sure all input tensors are correct + TfLiteTensor* i0 = interpreter->tensor(0); + ASSERT_EQ(i0->type, kTfLiteFloat32); + ASSERT_NE(i0->data.raw, nullptr); // mmapped + ASSERT_EQ(i0->allocation_type, kTfLiteMmapRo); + TfLiteTensor* i1 = interpreter->tensor(1); + ASSERT_EQ(i1->type, kTfLiteFloat32); + ASSERT_EQ(i1->data.raw, nullptr); + ASSERT_EQ(i1->allocation_type, kTfLiteArenaRw); + TfLiteTensor* o0 = interpreter->tensor(2); + ASSERT_EQ(o0->type, kTfLiteFloat32); + ASSERT_EQ(o0->data.raw, nullptr); + ASSERT_EQ(o0->allocation_type, kTfLiteArenaRw); + TfLiteTensor* o1 = interpreter->tensor(3); + ASSERT_EQ(o1->type, kTfLiteFloat32); + ASSERT_EQ(o1->data.raw, nullptr); + ASSERT_EQ(o1->allocation_type, kTfLiteArenaRw); + + // Check op 0 which has inputs {0, 1} outputs {2}. 
+ { + const std::pair* node_and_reg0 = + interpreter->node_and_registration(0); + ASSERT_NE(node_and_reg0, nullptr); + const TfLiteNode& node0 = node_and_reg0->first; + const TfLiteRegistration& reg0 = node_and_reg0->second; + TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(2); + desired_inputs->data[0] = 0; + desired_inputs->data[1] = 1; + TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1); + desired_outputs->data[0] = 2; + ASSERT_TRUE(TfLiteIntArrayEqual(node0.inputs, desired_inputs)); + ASSERT_TRUE(TfLiteIntArrayEqual(node0.outputs, desired_outputs)); + TfLiteIntArrayFree(desired_inputs); + TfLiteIntArrayFree(desired_outputs); + ASSERT_EQ(reg0, dummy_reg); + } + + // Check op 1 which has inputs {2} outputs {3}. + { + const std::pair* node_and_reg1 = + interpreter->node_and_registration(1); + ASSERT_NE(node_and_reg1, nullptr); + const TfLiteNode& node1 = node_and_reg1->first; + const TfLiteRegistration& reg1 = node_and_reg1->second; + TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(1); + TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1); + desired_inputs->data[0] = 2; + desired_outputs->data[0] = 3; + ASSERT_TRUE(TfLiteIntArrayEqual(node1.inputs, desired_inputs)); + ASSERT_TRUE(TfLiteIntArrayEqual(node1.outputs, desired_outputs)); + TfLiteIntArrayFree(desired_inputs); + TfLiteIntArrayFree(desired_outputs); + ASSERT_EQ(reg1, dummy_reg); + } +} + +// This tests on a flatbuffer that defines a shape of 2 to be a memory mapped +// buffer. But the buffer is provided to be only 1 element. +TEST(BasicFlatBufferModel, TestBrokenMmap) { + ASSERT_FALSE(FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/test_model_broken.bin")); +} + +TEST(BasicFlatBufferModel, TestNullModel) { + // Check that we get an error code and interpreter pointer is reset. 
+ std::unique_ptr interpreter(new Interpreter); + ASSERT_NE( + InterpreterBuilder(nullptr, TrivialResolver(&dummy_reg))(&interpreter), + kTfLiteOk); + ASSERT_EQ(interpreter.get(), nullptr); +} + +struct TestErrorReporter : public ErrorReporter { + int Report(const char* format, va_list args) override { + calls++; + return 0; + } + int calls = 0; +}; + +// This makes sure the ErrorReporter is marshalled from FlatBufferModel to +// the Interpreter. +TEST(BasicFlatBufferModel, TestCustomErrorReporter) { + TestErrorReporter reporter; + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/empty_model.bin", + &reporter); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + TrivialResolver resolver; + InterpreterBuilder(*model, resolver)(&interpreter); + ASSERT_NE(interpreter->Invoke(), kTfLiteOk); + ASSERT_EQ(reporter.calls, 1); +} + +// This makes sure the ErrorReporter is marshalled from FlatBufferModel to +// the Interpreter. +TEST(BasicFlatBufferModel, TestNullErrorReporter) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/empty_model.bin", nullptr); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + TrivialResolver resolver; + InterpreterBuilder(*model, resolver)(&interpreter); + ASSERT_NE(interpreter->Invoke(), kTfLiteOk); +} + +// Test what happens if we cannot bind any of the ops. +TEST(BasicFlatBufferModel, TestBuildModelFromCorruptedData) { + std::string corrupted_data = "123"; + auto model = FlatBufferModel::BuildFromBuffer(corrupted_data.c_str(), + corrupted_data.length()); + ASSERT_FALSE(model); +} + +// Test that loading model directly from a Model flatbuffer works. 
+TEST(BasicFlatBufferModel, TestBuildFromModel) { + TestErrorReporter reporter; + FileCopyAllocation model_allocation( + "tensorflow/contrib/lite/testdata/test_model.bin", &reporter); + ASSERT_TRUE(model_allocation.valid()); + ::flatbuffers::Verifier verifier( + reinterpret_cast(model_allocation.base()), + model_allocation.bytes()); + ASSERT_TRUE(VerifyModelBuffer(verifier)); + const Model* model_fb = ::tflite::GetModel(model_allocation.base()); + + auto model = FlatBufferModel::BuildFromModel(model_fb); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + ASSERT_EQ( + InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); +} + +// TODO(aselle): Add tests for serialization of builtin op data types. +// These tests will occur with the evaluation tests of individual operators, +// not here. + +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..733c3f4c7fa0605f24a1e6b4c458e34310c079c4 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/BUILD @@ -0,0 +1,100 @@ +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") + +licenses(["notice"]) # Apache 2.0 + +gen_selected_ops( + name = "smartreply_ops", + model = "@tflite_smartreply//:smartreply.tflite", +) + +cc_library( + name = "custom_ops", + srcs = [ + "ops/extract_feature.cc", + "ops/normalize.cc", + "ops/predict.cc", + ":smartreply_ops", + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + 
"@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + "@farmhash_archive//:farmhash", + ], +) + +cc_library( + name = "predictor_lib", + srcs = ["predictor.cc"], + hdrs = ["predictor.h"], + copts = tflite_copts(), + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "extract_feature_op_test", + size = "small", + srcs = ["ops/extract_feature_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@farmhash_archive//:farmhash", + ], +) + +cc_test( + name = "normalize_op_test", + size = "small", + srcs = ["ops/normalize_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "predict_op_test", + size = "small", + srcs = ["ops/predict_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml new file mode 100644 index 
0000000000000000000000000000000000000000..75ed9432c8fcdfd77a64d3c659e6336c977cdda2 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f8767b443a2aa64b666c3b6bfb7db30cc0be62ea --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD @@ -0,0 +1,65 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite:build_def.bzl", + "tflite_copts", + "tflite_jni_binary", +) + +filegroup( + name = "assets", + srcs = [ + "@tflite_smartreply//:model_files", + ], +) + +android_binary( + name = "SmartReplyDemo", + srcs = glob(["java/**/*.java"]), + assets = [":assets"], + assets_dir = "", + custom_package = "com.example.android.smartreply", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = ["manual"], + deps = [ + ":smartreply_runtime", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) + +cc_library( + name = "smartreply_runtime", + srcs = ["libsmartreply_jni.so"], + visibility = ["//visibility:public"], +) + +tflite_jni_binary( + name = "libsmartreply_jni.so", + deps = [ + ":smartreply_jni_lib", + ], +) + +cc_library( + name = "smartreply_jni_lib", + srcs = [ + "smartreply_jni.cc", + ], + copts = tflite_copts(), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/models/smartreply:predictor_lib", + ], + alwayslink = 1, +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD 
b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..3c882ffc43fde577801428151a43b592e8faaed1 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(glob(["*"])) + +filegroup( + name = "assets_files", + srcs = glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt new file mode 100644 index 0000000000000000000000000000000000000000..a0a5b46b5f8d5fd6a0297c8056bb2fb9b6ad9ada --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt @@ -0,0 +1,16 @@ +Ok +Yes +No +👍 +☺ +😟 +❤️ +Lol +Thanks +Got it +Done +Nice +I don't know +What? +Why? +What's up? diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..02fec9ae5e971ad756ae6c2b0149a6aacfa27cad --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java @@ -0,0 +1,99 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.app.Activity; +import android.os.Bundle; +import android.os.Handler; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.EditText; +import android.widget.TextView; + +/** + * The main (and only) activity of this demo app. Displays a text box which updates as messages are + * received. + */ +public class MainActivity extends Activity { + private static final String TAG = "SmartReplyDemo"; + private SmartReplyClient client; + + private Button sendButton; + private TextView messageTextView; + private EditText messageInput; + + private Handler handler; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + Log.v(TAG, "onCreate"); + setContentView(R.layout.main_activity); + + client = new SmartReplyClient(getApplicationContext()); + handler = new Handler(); + + sendButton = (Button) findViewById(R.id.send_button); + sendButton.setOnClickListener( + (View v) -> { + send(messageInput.getText().toString()); + }); + + messageTextView = (TextView) findViewById(R.id.message_text); + messageInput = (EditText) findViewById(R.id.message_input); + } + + @Override + protected void onStart() { + super.onStart(); + Log.v(TAG, "onStart"); + handler.post( + () -> { + client.loadModel(); + }); + } + + @Override + protected void onStop() { + super.onStop(); + Log.v(TAG, "onStop"); + handler.post( + () -> { + client.unloadModel(); + }); + } + + 
private void send(final String message) { + handler.post( + () -> { + messageTextView.append("Input: " + message + "\n"); + + SmartReply[] ans = client.predict(new String[] {message}); + for (SmartReply reply : ans) { + appendMessage("Reply: " + reply.getText()); + } + appendMessage("------"); + }); + } + + private void appendMessage(final String message) { + handler.post( + () -> { + messageTextView.append(message + "\n"); + }); + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java new file mode 100644 index 0000000000000000000000000000000000000000..3357fd17c11f870d1b0998bb26ffa9abf149686b --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.support.annotation.Keep; + +/** + * SmartReply contains predicted message, and confidence. + * + *

NOTE: this class used by JNI, class name and constructor should not be obfuscated. + */ +@Keep +public class SmartReply { + + private final String text; + private final float score; + + @Keep + public SmartReply(String text, float score) { + this.text = text; + this.score = score; + } + + public String getText() { + return text; + } + + public float getScore() { + return score; + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java new file mode 100644 index 0000000000000000000000000000000000000000..d5b1ac0ffbc47283aa0c1bf68c0a85ad6228cdcc --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package com.example.android.smartreply; + +import android.content.Context; +import android.content.res.AssetFileDescriptor; +import android.support.annotation.Keep; +import android.support.annotation.WorkerThread; +import android.util.Log; +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.ArrayList; +import java.util.List; + +/** Interface to load TfLite model and provide predictions. */ +public class SmartReplyClient implements AutoCloseable { + private static final String TAG = "SmartReplyDemo"; + private static final String MODEL_PATH = "smartreply.tflite"; + private static final String BACKOFF_PATH = "backoff_response.txt"; + private static final String JNI_LIB = "smartreply_jni"; + + private final Context context; + private long storage; + private MappedByteBuffer model; + + private volatile boolean isLibraryLoaded; + + public SmartReplyClient(Context context) { + this.context = context; + } + + public boolean isLoaded() { + return storage != 0; + } + + @WorkerThread + public synchronized void loadModel() { + if (!isLibraryLoaded) { + System.loadLibrary(JNI_LIB); + isLibraryLoaded = true; + } + + try { + model = loadModelFile(); + String[] backoff = loadBackoffList(); + storage = loadJNI(model, backoff); + } catch (IOException e) { + Log.e(TAG, "Fail to load model", e); + return; + } + } + + @WorkerThread + public synchronized SmartReply[] predict(String[] input) { + if (storage != 0) { + return predictJNI(storage, input); + } else { + return new SmartReply[] {}; + } + } + + @WorkerThread + public synchronized void unloadModel() { + close(); + } + + @Override + public synchronized void close() { + if (storage != 0) { + unloadJNI(storage); + storage = 0; + } + } + + private MappedByteBuffer loadModelFile() 
throws IOException { + AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + try { + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } finally { + inputStream.close(); + } + } + + private String[] loadBackoffList() throws IOException { + List labelList = new ArrayList(); + BufferedReader reader = + new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH))); + String line; + while ((line = reader.readLine()) != null) { + if (!line.isEmpty()) { + labelList.add(line); + } + } + reader.close(); + String[] ans = new String[labelList.size()]; + labelList.toArray(ans); + return ans; + } + + @Keep + private native long loadJNI(MappedByteBuffer buffer, String[] backoff); + + @Keep + private native SmartReply[] predictJNI(long storage, String[] text); + + @Keep + private native void unloadJNI(long storage); +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml new file mode 100644 index 0000000000000000000000000000000000000000..23b4cadc007a4457d33b8c8fecf9b1e7b7436320 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml @@ -0,0 +1,44 @@ + + + + + + + + + + +