Register int types for one_hot GPU ops.
Update beam_search_decoder internal int32 types to int64 so that the length penalty can be computed on the GPU. This speeds up NMT beam search decoding (beam_width=10) with length penalty by ~6.8 times.

PiperOrigin-RevId: 169720804
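
For context, a minimal sketch of the kind of length penalty this change concerns. The commit message does not give the formula, so the sketch assumes the GNMT-style penalty ((5 + length) / 6) ** penalty_factor. The names `length_penalty`, `penalized_score`, and `penalty_factor` are hypothetical and written in plain Python, not the decoder's actual code.

```python
def length_penalty(sequence_length, penalty_factor):
    """Assumed GNMT-style penalty: ((5 + length) / 6) ** penalty_factor."""
    # A factor of 0 disables the penalty entirely.
    if penalty_factor == 0.0:
        return 1.0
    # The integer length is cast to float before the power is taken.
    return ((5.0 + float(sequence_length)) / 6.0) ** penalty_factor


def penalized_score(log_prob, sequence_length, penalty_factor):
    """Divide a beam's summed log-probability by its length penalty."""
    return log_prob / length_penalty(sequence_length, penalty_factor)
```

With these assumptions, `penalized_score(-4.0, 7, 1.0)` divides the log-probability by (5 + 7) / 6 = 2. Longer hypotheses get a larger divisor, which offsets their more negative summed log-probabilities.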