Add `batch_scatter_update`, analogous to `batch_gather`.
This operation computes: ref[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j] That is, it assumes that `ref`, `indices` and `updates` have a series of leading dimensions that are the same for all of them, and the updates are performed on the last dimension of indices. PiperOrigin-RevId: 209566652
Loading
Please sign in to comment