Commit 67f0b8f9 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TensorFlower Gardener
Browse files

Enabling partitioned variables to work with TPU.

When partitioned variables are used in a TPU training loop,
concat gradient operations get generated for which XLA requires
the concat dimension argument to be a constant (or foldable to a constant).
However since such constant is defined outside of the train while context
an Enter node is generated in order to pass it.
The fix consists in detecting such case, and to duplicate the (scalar) constant
inside the while context, so that XLA can succesfully process the resulting
graph.

PiperOrigin-RevId: 184273245
parent ffa37312
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment