[XLA] Add a TriangularSolve HLO.
Previously TriangularSolve() existed in the XLA client library, but was lowered into other HLO ops immediately, preventing us from using well-tuned existing BLAS implementations.

This change adds a first-class HLO for TriangularSolve. The API of TriangularSolve is chosen to match the BLAS TRSM API closely.

On the CPU and interpreter backends, the TriangularSolve HLO is immediately expanded to a Call operator that runs the same computation the existing client library would have built. With some cunning, we are able to use XlaBuilder inside a lowering pass, allowing us to keep using the much simpler XlaBuilder API to express the triangular solve computation.

Adds a generic OpExpander pass superclass, and refactors GatherExpander to use it. OpExpander is used as the superclass of the new TriangularSolveExpander.

On GPU, adds a direct implementation of TriangularSolve in terms of the cuBLAS TRSM implementation.

PiperOrigin-RevId: 232987494
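
For readers unfamiliar with the operation, the following is a minimal reference sketch of what a triangular solve computes. It is not the XLA API or the code in this change: the function name `triangular_solve` and its `lower` and `unit_diagonal` parameters are hypothetical and only loosely mirror some of the BLAS TRSM options. The sketch assumes the left-side, non-transposed case with a scaling factor of 1, and uses plain Python lists.

```python
def triangular_solve(a, b, lower=True, unit_diagonal=False):
    """Solve a @ x = b for x by substitution.

    a is an n x n triangular matrix, b is an n x m right-hand side,
    both given as lists of rows. Hypothetical helper for illustration only.
    """
    n = len(a)
    m = len(b[0])
    x = [[0.0] * m for _ in range(n)]
    # Forward substitution walks rows top-down (lower triangular);
    # back substitution walks rows bottom-up (upper triangular).
    rows = range(n) if lower else range(n - 1, -1, -1)
    for i in rows:
        # Columns whose unknowns have already been solved.
        solved = range(i) if lower else range(i + 1, n)
        for j in range(m):
            acc = b[i][j] - sum(a[i][k] * x[k][j] for k in solved)
            # With unit_diagonal, the diagonal is treated as 1 and never read.
            x[i][j] = acc if unit_diagonal else acc / a[i][i]
    return x


if __name__ == "__main__":
    lower_a = [[2.0, 0.0], [3.0, 4.0]]
    rhs = [[2.0], [11.0]]
    print(triangular_solve(lower_a, rhs, lower=True))
```

Each row depends only on rows already solved, which is why a tuned TRSM routine can do much better than a generic lowering into other ops.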