Add an axis parameter to tf.gather. Fixes GitHub issue #11223.
This brings tf.gather closer to compatibility with numpy.take. To emulate gathering over an axis generally requires inefficient workarounds, e.g. transpose/gather/transpose. This technique is gaining popularity (hundreds of uses inside and outside of Google), so it is worth supporting efficiently. For an `[a_0, ..., a_i, ..., a_n]` tensor, gathering `N` elements from axis `i` requires `(a_0*...*a_i-1) * N` copies of `(a_i+1 * ... * a_n)` elements each. The CPU kernel does this with memcpy which is far more efficient than transpose/gather/transpose since it requires no intermediate allocations and copies. The GPU kernel does the same number of copies but in parallel across multiple hardware threads. Since this is a backwards incompatible change, this adds a "GatherV2" op with an axis input, and simultaneously supports backwards compatibility with "Gather" ops by defaulting to axis 0 if a 3rd input is not present. PiperOrigin-RevId: 161541416
Loading
Please sign in to comment