Fix support for seq2seq with mixed precision
When the type of the input tensor `x` is not the same as the type of the hidden states cast is required. This mixed precision case occurs when using the seq2seq layer with a data type of float16 or bfloat16. PiperOrigin-RevId: 204364209
Loading
Please sign in to comment