Add batch_dims argument to tf.gather.
This provides a generalization of the behavior provided by tf.batch_gather, but is backwards-compatible with the existing tf.gather API. In particular: * tf.gather(..., batch_dims=0) is equivalent to tf.gather(...) * tf.gather(..., batch_dims=indices.shape.ndims-1) is equivalent to tf.batch_gather(...) In addition, the new batch_dims parameter can be used to specify a number of batch dimensions between these two extremes. E.g., in the following example, we use batch_dims=1 to indicate that only the outermost dimension is a batch dimension: >>> tf.batch_gather( ... params=[["a", "b", "c", "d"], ["e", "f", "g", "h"]], ... indices=[[2, 1], [0, 3]], ... batch_dims=1) [["c", "b"], ["e", "h"]] PiperOrigin-RevId: 226426563
Loading
Please sign in to comment