Fix random flakiness in AttentionWrapperTest.
Previously the output projection dense layer was initialized to all ones, which caused the data passed to sampler.sample() to be identical within a batch. That is why the mean() of expected_sample_id was 0: the sampler always picked the first index because the values at every index were the same. Due to an unknown dependency difference between the OSS and internal builds, the data sometimes loses some precision, which causes sample_id to change randomly. This change updates the output dense layer to use random weights, which leads to a much more diverse but still deterministic result.

PiperOrigin-RevId: 237534170
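
The tie-breaking behaviour described above can be sketched in plain NumPy. This is only an illustration and not the test's actual code: the shapes, seeds, and helper names (`ones_projection`, `random_projection`) are assumptions, and argmax stands in for a greedy sampler. Multiplying by an all-ones kernel is modelled as broadcasting each row sum to every output index.

```python
import numpy as np

BATCH, DEPTH, VOCAB = 4, 6, 5  # hypothetical sizes, not taken from the test

# Stand-in for the cell output that feeds the output projection layer.
cell_output = np.random.RandomState(0).standard_normal((BATCH, DEPTH))


def ones_projection(x, vocab):
    # An all-ones kernel maps each row to its sum, repeated at every index.
    return np.tile(x.sum(axis=1, keepdims=True), (1, vocab))


def random_projection(x, vocab, seed):
    # A seeded random kernel: different per index, identical on every run.
    kernel = np.random.RandomState(seed).standard_normal((x.shape[1], vocab))
    return x @ kernel


# All logits in a row are tied, so argmax always returns the first index.
ids_ones = ones_projection(cell_output, VOCAB).argmax(axis=-1)

# A tiny precision-sized perturbation is enough to break the tie arbitrarily.
noise = 1e-7 * np.random.RandomState(2).standard_normal((BATCH, VOCAB))
ids_ones_noisy = (ones_projection(cell_output, VOCAB) + noise).argmax(axis=-1)

# With random weights the logits differ per index, and the result is repeatable.
ids_random = random_projection(cell_output, VOCAB, seed=1).argmax(axis=-1)

print("ones kernel:        ", ids_ones)
print("ones kernel + noise:", ids_ones_noisy)
print("random kernel:      ", ids_random)
```

With tied logits the chosen index depends entirely on rounding, whereas seeded random weights are expected to separate the logits by far more than a rounding error, so the expected sample ids stay stable across builds.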