[XLA] Add unit_diagonal option to TriangularSolve.
BLAS implementations of TRSM usually have a "unit_diagonal" option that allows users to specify that the elements on the diagonal should be ignored. This is frequently useful if both lower and upper triangular matrices are packed into the same matrix and the diagonal only belongs to one of them. For example, this is true in the case of LU decomposition in JAX: https://github.com/google/jax/blob/master/jax/numpy/linalg.py#L153 This change pushes that logic into the TriangularSolve() operator, which will allow us to remove it from JAX. This change is also in preparation for adding a first class TriangularSolve HLO that calls into BLAS implementations where applicable. We would like the HLO operator to support the full feature set of any underlying BLAS implementations. PiperOrigin-RevId: 232571346
Loading
Please sign in to comment