Fix dataset resampling bug introduced by a bug in datasets itself.
Fixes github issue #16606. The core issue is that in the case of certain random Tensors, the following two lines aren't the same: ``` rand_0s_and_1s_ds = ... gather_ds = rand_0s_and_1s_ds.map(lambda i: tf.gather([0, 1], i)) tup_ds = tf.data.Dataset.zip(gather_ds, rand_0s_and_1s_ds) ``` ``` rand_0s_and_1s_ds = ... tup_ds = rand_0s_and_1s_ds.map(lambda i: (tf.gather([0, 1], i), i)) ``` Note that this does NOT fix the underlying issue of drawing multiple sampes from the underlying distribution. Tested: With the new test, `bazel test :resample_test` fails before and succeeds after.
Loading
Please sign in to comment