Commit 7bf0d2af authored by Guangda Lai's avatar Guangda Lai Committed by TensorFlower Gardener
Browse files

Refactor TF-TRT C++ API in convert_nodes.cc, and add TrtNodeValidator class

(which will be later used to do compile-time validation for each TF node to
determine whether it's supported by TRT, in cl/217561555).

Most of the changes are not functional changes but are parameter plumbings,
they come from the 3rd item listed below.

Implementation details:

- Added TrtNodeValidator to validate nodes at compile-time to determine whether
  they're TRT compatible. Validation logic is added for ConvertTranspose and
  ConvertReshape, but the real validation will happen in a separate CL.
- Split APIs to manage the storage of TRT_ShapedWeights to a separate helper
  class TrtWeightStore. Both TrtNodeValidator and Converter need that.
- Change signature of the op converters, by putting all arguments into single
  OpConverterParams struct (***this is where most of the changes come from***).
  Also we added TrtWeightStore separately in the struct, and added
  validation_only for used by TrtNodeValidator.

PiperOrigin-RevId: 218923091
parent f4e696d5
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment