Fix support for batch_normalization with mixed precision
When the type of the input tensor `x` is not the same as the type of the parameters `mean`, `variance`, `offset`, and `scale`, a cast is required. This mixed precision case occurs when using the BatchNormalization layer with a data type of float16 or bfloat16. PiperOrigin-RevId: 195157279
Loading
Please sign in to comment