Average the loss across `replicate_model_fn`'s towers.
This avoids the need for users to add `loss = loss / num_of_towers` code and is in more in line with the current best practices. I verified this by running cnn_mnist. PiperOrigin-RevId: 178963334
Loading
Please sign in to comment