Fix masking of beam ids in gather_tree_from_array
The `sequence_length` argument passed to the function holds the lengths of the **reordered** predictions, but it was incorrectly used to mask beam ids *before* reordering. Instead, we can reorder beam ids without caring about out-of-range steps and then keep only the reordered ids that are in bounds. The added test covers a beam trajectory that previously produced an out-of-range error because `gather_tree` returned `end_token` (here `beam_width + 1`) for some steps.
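To illustrate the idea, here is a minimal NumPy sketch for a single batch entry. It is not the TensorFlow code from this change: `backtrace_beam_ids` is a hypothetical, simplified stand-in for `gather_tree`, and `reorder_beam_ids`, the shapes, and the sample values are assumptions made for the example.

```python
import numpy as np


def backtrace_beam_ids(parent_ids, max_len, fill_value):
    """Hypothetical, simplified stand-in for gather_tree on one batch entry.

    parent_ids has shape [max_time, beam_width]. Steps at or beyond max_len
    keep fill_value, which plays the role of end_token.
    """
    max_time, beam_width = parent_ids.shape
    out = np.full((max_time, beam_width), fill_value, dtype=np.int64)
    for beam in range(beam_width):
        current = beam
        # Walk backwards from the last valid step, following parent pointers.
        for t in range(max_len - 1, -1, -1):
            out[t, beam] = current
            current = parent_ids[t, current]
    return out


def reorder_beam_ids(parent_ids, sequence_length):
    """Reorder first, then mask with the lengths of the reordered beams."""
    max_time, beam_width = parent_ids.shape
    beam_ids = np.tile(np.arange(beam_width), (max_time, 1))

    # Reorder without masking; out-of-range steps hold beam_width + 1.
    sorted_ids = backtrace_beam_ids(
        parent_ids, int(sequence_length.max()), fill_value=beam_width + 1)

    # Keep reordered ids only where the step is in bounds for that beam;
    # elsewhere fall back to the identity beam id, which is always valid.
    in_bounds = np.arange(max_time)[:, None] < sequence_length[None, :]
    return np.where(in_bounds, sorted_ids, beam_ids)


# Assumed sample data: max_time = 4, beam_width = 2.
parent_ids = np.array([[0, 0], [0, 0], [1, 0], [0, 0]])
sequence_length = np.array([3, 2])  # lengths of the reordered beams
sorted_beam_ids = reorder_beam_ids(parent_ids, sequence_length)
```

In this sketch every entry of `sorted_beam_ids` is a valid index into the beam dimension, so it can be used to gather from a `[max_time, beam_width, ...]` array without the out-of-range error described above.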