Fix TF/XLA lowering for FusedBatchnorm training on GPU
The FusedBatchnorm TensorFlow operation has two outputs, reserve_space_{1|2},
that contain unspecified results; in practice on CPUs the they contain the mean
and the variance and on GPUs they contain the mean the *inverse sqrt of the
variance* (what cudnn returns). This is unfortunate for XLA because without a
spec describing what these outputs are, we can't soundly replace these outputs
with XLA computed values.
For now we rely on the in-practice values returned in reserve_space_{1|2} and
add a test case checking that these values are indeed correct. In the future we
may consider some of the cleaner fixes outlined in the tracking bug.
PiperOrigin-RevId: 232050515
Loading
Please sign in to comment