Commit 87d558f2 authored by Scott Zhu, committed by TensorFlower Gardener

Fix random flakiness in AttentionWrapperTest.

Previously, the output projection dense layer was initialized to all ones, which caused the data passed to sampler.sample() to be identical within the batch. That is why the mean() of expected_sample_id was 0: the sampler always picked the first index, since the values at every index were the same. Because of an unknown dependency difference between the OSS and internal builds, the data sometimes loses some precision, which caused sample_id to change randomly. This change updates the output dense layer to use random weights, which leads to a much more diverse but still deterministic result.
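
The failure mode described above can be reproduced in a minimal pure-Python sketch. This is not the code from AttentionWrapperTest: `dense`, `argmax`, and `sample_ids` are hypothetical helpers, a greedy argmax stands in for sampler.sample(), and the shapes and seed are arbitrary assumptions chosen for illustration.

```python
import random


def dense(inputs, kernel):
    """Project inputs (batch x in_dim) through kernel (in_dim x out_dim)."""
    out_dim = len(kernel[0])
    return [
        [sum(x * kernel[i][j] for i, x in enumerate(row)) for j in range(out_dim)]
        for row in inputs
    ]


def argmax(row):
    # max() returns the first maximal element, so exact ties resolve to index 0.
    return max(range(len(row)), key=lambda j: row[j])


def sample_ids(seed, use_ones):
    """Hypothetical stand-in for the test: project, then greedily pick an id."""
    rng = random.Random(seed)  # fixed seed keeps the result deterministic
    inputs = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(5)]
    if use_ones:
        kernel = [[1.0] * 3 for _ in range(4)]
    else:
        kernel = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(4)]
    return [argmax(row) for row in dense(inputs, kernel)]


# All-ones kernel: every output unit is the same sum, so each row is an exact tie.
ones_ids = sample_ids(seed=0, use_ones=True)

# A tiny precision difference in one logit is enough to flip the tie.
tied_row = dense([[0.25, -0.5, 0.75, 0.1]], [[1.0] * 3 for _ in range(4)])[0]
tied_row[2] += 1e-7
flipped_id = argmax(tied_row)

# Random (seeded) kernel: the logits differ, and the ids repeat across runs.
random_ids = sample_ids(seed=0, use_ones=False)
```

With the all-ones kernel every row is an exact tie, so the chosen id is always 0 until any rounding difference breaks the tie. With seeded random weights the logits differ across indices, and the ids are reproducible from run to run.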

PiperOrigin-RevId: 237534170
parent 29a6b5ec