[XLA:GPU] Don't crash when the root instruction of a computation is a...
[XLA:GPU] Don't crash when the root instruction of a computation is a multi-output fusion node, and avoid some pointer chasing with tuples. Previously, the kernels we generated would have one argument per *top-level* buffer of the input/output. This was fine for inputs. But it doesn't work for outputs: Imagine you're a node that returns a tuple -- e.g. multi-output fusion -- if all you get is a pointer to the top-level buffer of your output (which should contain pointers to the lower-level buffers at some point, but at the moment is just empty), how are you supposed to figure out where to write your output? (This usually worked because most of the time your output would live inside of the big XLA temp buffer, and kernels always get a pointer to that.) Now we pass all the buffers, top-level and otherwise, to our kernel. In addition, we're now willing to dereference statically tuples that live entirely in XLA's temp buffer. Pointers in input tuples must still be dereferenced dynamically, because the caller has the option of giving us these values or not when invoking XLA. This change makes some parts of BufferAssignment/BufferAllocations more truthful. Previously, if you passed a tuple-shaped input to XLA, we'd say in BufferAllocations that the pointer for some subshape of the param was the *top-level tuple pointer*. XLA then knew that this was a lie and would dereference it accordingly. Now we have an explicit notion of a BufferAllocation pointing to a subshape of an input parameter. PiperOrigin-RevId: 185614060
Loading
Please sign in to comment