In TPUEstimator, allow input_fn to return features and labels of any arbitrary...
In TPUEstimator, allow input_fn to return features and labels of any arbitrary nested structure rather than just tensor or dict or tensors. The current restriction seems to be historical. The code in _ModelFnWrapper uses nest flattens already, and then in _call_model_fn, 'features' and 'labels' are passed directly to the user code model_fn, so any nested structure should be supported. Before entering ModelFnWrapper, the _InputPipeline and _Inputs classes also both handle nested structure as well. PiperOrigin-RevId: 217010800
Loading
Please sign in to comment