Enabling partitioned variables to work with TPU.
When partitioned variables are used in a TPU training loop, concat gradient operations get generated for which XLA requires the concat dimension argument to be a constant (or foldable to a constant). However since such constant is defined outside of the train while context an Enter node is generated in order to pass it. The fix consists in detecting such case, and to duplicate the (scalar) constant inside the while context, so that XLA can succesfully process the resulting graph. PiperOrigin-RevId: 184273245
Loading
Please sign in to comment