Commit b1f9e2c8 authored by RJ Ryan's avatar RJ Ryan Committed by TensorFlower Gardener
Browse files

Add an axis parameter to tf.gather. Fixes GitHub issue #11223.

This brings tf.gather closer to compatibility with numpy.take.

To emulate gathering over an axis generally requires inefficient workarounds, e.g. transpose/gather/transpose. This technique is gaining popularity (hundreds of uses inside and outside of Google), so it is worth supporting efficiently.

For an `[a_0, ..., a_i, ..., a_n]` tensor, gathering `N` elements from axis `i` requires `(a_0*...*a_i-1) * N` copies of `(a_i+1 * ... * a_n)` elements each. The CPU kernel does this with memcpy which is far more efficient than transpose/gather/transpose since it requires no intermediate allocations and copies. The GPU kernel does the same number of copies but in parallel across multiple hardware threads.

Since this is a backwards incompatible change, this adds a "GatherV2" op with an axis input, and simultaneously supports backwards compatibility with "Gather" ops by defaulting to axis 0 if a 3rd input is not present.

PiperOrigin-RevId: 161541416
parent 18a5510e
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment