Fix fp16 BatchNormalization when fused=False.
Also fix BatchNormalization fp16 test. Before, the test was actually running in fp32, because the Keras would cast the input to the model to the input layer's dtype, which defaults to fp32. PiperOrigin-RevId: 236133280
Loading
Please sign in to comment