Replicate `Estimator.model_fn` across available GPUs.
def replicate_model_fn(model_fn, optimizer_fn, devices=None):
"""Replicate `Estimator.model_fn` over GPUs.
...
I tested that it seems to give the right result on cnn_mnist.py on 1 CPU, 1 real GPU, 4 allow_soft_placement=True GPUs.
Some measurements on CNN MNIST across steps 19300-20000:
1) no replicate_model_fn call:
global_step/sec: 156.254
global_step/sec: 155.074
global_step/sec: 155.74
global_step/sec: 153.636
global_step/sec: 157.218
global_step/sec: 159.644
2) replicate across one hardware GPU:
global_step/sec: 158.171
global_step/sec: 165.618
global_step/sec: 162.773
global_step/sec: 159.204
global_step/sec: 162.289
global_step/sec: 167.173
3) replicate across 4 software GPUs on one hardware GPU (soft placement):
global_step/sec: 75.47
global_step/sec: 76.16
global_step/sec: 75.18
Loss numbers didn't change across the three configurations.
PiperOrigin-RevId: 174704385
Loading
Please sign in to comment