Commit cef42c64 authored by Rui Zhao's avatar Rui Zhao Committed by TensorFlower Gardener
Browse files

Register int types for one_hot GPU ops.

Update beam_search_decoder internal int32 types to int64 in order to compute
length penalty on GPU.

This speedup the NMT beam search decoding (beam_width=10) with length penalty by ~6.8 times.

PiperOrigin-RevId: 169720804
parent d9bd87b1
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment