Commit 62678265 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TensorFlower Gardener
Browse files

Add batch_dims argument to tf.gather.

This provides a generalization of the behavior provided by tf.batch_gather, but is backwards-compatible with the existing tf.gather API.  In particular:

* tf.gather(..., batch_dims=0) is equivalent to tf.gather(...)
* tf.gather(..., batch_dims=indices.shape.ndims-1) is equivalent to tf.batch_gather(...)

In addition, the new batch_dims parameter can be used to specify a number of batch dimensions between these two extremes.  E.g., in the following example, we use batch_dims=1 to indicate that only the outermost dimension is a batch dimension:

  >>> tf.batch_gather(
  ...     params=[["a", "b", "c", "d"], ["e", "f", "g", "h"]],
  ...     indices=[[2, 1], [0, 3]],
  ...     batch_dims=1)
  [["c", "b"], ["e", "h"]]

PiperOrigin-RevId: 226426563
parent 83dbe37b
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment