Commit 5ed41f6d authored by Guillaume Klein's avatar Guillaume Klein
Browse files

Fix masking of beam ids in gather_tree_from_array

The `sequence_length` argument that is passed to the function is the
lengths of the **reordered** predictions and was incorrectly used to
mask beam ids *before* reordering. Instead, we can reorder beam ids
without caring about out of range steps and only select the reodered
ids that are in bounds.

The added test covers a beam trajectory that previously produced an
out of range error because `gather_tree` returned `end_token` (here
`beam_width + 1`) for some steps.
parent 10aacf16
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment