Support multiple loss and multiple optimizers in replicate_model_fn.
Instead of supplying `optimizer_fn`, the user is now expected to wrap their optimizer in GatheringOptimizer. The latter will gather gradients, reduce and apply them. There can be multiple instances of GatheringOptimizer inside the model. PiperOrigin-RevId: 179899422
Loading
Please sign in to comment