Make SVD twice differentiable
Right now, SVD gradient computes inf values on the diag of a certain matrix and then sets them to zero. This is fine on the forward pass, but trying to differentiate through that yields NaNs. Instead of computing 1 / (matrix with 0 on the diag) and then setting the diag of that to zero, I compute 1 / (eye + matrix with 0 on the diag) and then (again) set the diag of the result to zero. PiperOrigin-RevId: 236089486
Loading
Please sign in to comment