[XLA:GPU] Enhance column reduction implementation.
With KernelMappingScheme, a thread processes n elements, where n=tile_size_x/num_thread_x. Previously, n>1 was only supported for transpose and row reduction. This change extends the kernel mapping scheme and code generation to allow each thread to compute n partial results for column reduction, as follows:

. Add dilated_x to KernelMappingScheme, to indicate whether the multiple elements processed by the same thread are contiguous or dilated. dilated_x=true is what we currently use for transpose, while dilated_x=false is used to support the vectorization of column reduction.
. Extend the stack storage that holds the partial result address and the current output linear index address for each output tensor from a scalar to an array of n elements.
. Add curr_iter_index_x to kernelCodegenInfo, to indicate which output element the compiler is currently generating code for. This information is used to locate the partial result address and the output linear index address for that element.
. Modify the code generation to use n=2 for column reduction.

PiperOrigin-RevId: 228213135
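For illustration only, the difference between the contiguous and dilated mappings can be sketched as below. This is a minimal standalone sketch, not the code in this change: ElementsForThread is a hypothetical helper, and the index formulas assume that "dilated" means a stride of num_threads_x between a thread's elements while "contiguous" means n adjacent elements per thread.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of XLA): returns the x offsets inside one
// tile that the thread with index `thread_x` processes.
// Assumption: dilated_x=true strides a thread's elements by num_threads_x,
// while dilated_x=false gives each thread n adjacent elements.
std::vector<int64_t> ElementsForThread(int64_t tile_size_x,
                                       int64_t num_threads_x,
                                       int64_t thread_x, bool dilated_x) {
  assert(num_threads_x > 0 && tile_size_x % num_threads_x == 0);
  assert(thread_x >= 0 && thread_x < num_threads_x);
  // n elements per thread, as in n = tile_size_x / num_thread_x.
  const int64_t n = tile_size_x / num_threads_x;
  std::vector<int64_t> xs;
  xs.reserve(n);
  for (int64_t i = 0; i < n; ++i) {
    xs.push_back(dilated_x ? thread_x + i * num_threads_x  // strided
                           : thread_x * n + i);            // adjacent
  }
  return xs;
}
```

Under these assumptions, with tile_size_x=8 and num_threads_x=4 (so n=2), thread 1 handles x offsets {1, 5} when dilated_x=true and {2, 3} when dilated_x=false. Each of the n offsets would then have its own slot in the per-thread arrays of partial results and output indices.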