Add functionality to fold batch norm (supporting both fused and unfused batch...
Add functionality to fold batch norm (supporting both fused and unfused batch norm) to support quantized training. The weights are always now scaled by gamma/sigma, where sigma is the moving standard deviation for stability prior to quantization. For improved performance, the moving means and variances are frozen and the training graph modified accordingly. An additional parameter freeze_batch_norm_delay is added to foldbatchnorm function to set the delay at which training switches from regular batch norm to frozen mean and variances. Remove placement options within FoldBatchNorm as this causes folded training to place all ops on a single GPU. The modification now significantly speeds up distributed training. The tests for folding batch norms are also updated to reflect the additional topological changes to the graph. PiperOrigin-RevId: 184211434
Loading
Please sign in to comment