diff --git a/src/internal_rules/linalg.jl b/src/internal_rules/linalg.jl
index 616189bd46..dbb27db819 100644
--- a/src/internal_rules/linalg.jl
+++ b/src/internal_rules/linalg.jl
@@ -163,6 +163,73 @@ function EnzymeRules.reverse(
     return (nothing, nothing)
 end
 
+# Forward-mode rule for `A \ B` with a triangular `A` and a dense `B`.
+#
+# With `X = A \ B`, differentiating `A * X = B` gives `dX = A \ (dB - dA * X)`.
+# Depending on what `config` requests, this returns `nothing`, the primal `X`,
+# the shadow(s) `dX`, or a `Duplicated` / `BatchDuplicated` pairing of both.
+function EnzymeRules.forward(
+        config::EnzymeRules.FwdConfig,
+        func::Const{typeof(\)},
+        RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}},
+        A::Annotation{<:LinearAlgebra.AbstractTriangular},
+        B::Annotation{<:Array},
+    )
+    if !(EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config))
+        return nothing
+    end
+
+    retval = func.val(A.val, B.val)
+
+    if EnzymeRules.needs_shadow(config)
+        N = EnzymeRules.width(config)
+        dretvals = ntuple(Val(N)) do i
+            Base.@_inline_meta
+            # Allocate the shadow with the element type of the result, not of `B`:
+            # `A \ B` may promote (e.g. complex `A`, real `B`), and both the in-place
+            # updates below and `Duplicated(retval, dB)` need matching types.
+            dB = if B isa Const
+                zero(retval)
+            else
+                copyto!(similar(retval), N == 1 ? B.dval : B.dval[i])
+            end
+            if !(A isa Const)
+                dA = N == 1 ? A.dval : A.dval[i]
+                if A.val isa Union{UnitUpperTriangular, UnitLowerTriangular}
+                    # Multiplying by the unit-triangular wrapper itself would use an
+                    # implicit diagonal of ones. Use a copy of the parent storage instead;
+                    # NOTE(review): `_zero_unused_elements!` is defined elsewhere and is
+                    # presumably what clears the diagonal and the unused triangle.
+                    mul!(dB, _zero_unused_elements!(copy(parent(dA)), A.val), retval, -1, 1)
+                else
+                    mul!(dB, dA, retval, -1, 1)
+                end
+            end
+            # Solve in place so that `dB` ends up holding `A \ (dB - dA * X)`.
+            ldiv!(A.val, dB)
+            return dB
+        end
+
+        if EnzymeRules.needs_primal(config)
+            if N == 1
+                return Duplicated(retval, dretvals[1])
+            else
+                return BatchDuplicated(retval, dretvals)
+            end
+        else
+            if N == 1
+                return dretvals[1]
+            else
+                return dretvals
+            end
+        end
+    elseif EnzymeRules.needs_primal(config)
+        return retval
+    else
+        return nothing
+    end
+end
+
 const EnzymeTriangulars = Union{
     UpperTriangular{<:Complex},
     LowerTriangular{<:Complex},
@@ -503,4 +570,3 @@ end
 
 # partial derivative of the determinant is the matrix of cofactors
 EnzymeRules.@easy_rule(LinearAlgebra.det(A::AbstractMatrix), (cofactor(A),))
-
diff --git a/test/rules/internal_rules/linear_algebra_rules.jl b/test/rules/internal_rules/linear_algebra_rules.jl
index d438f8b476..6752738b7a 100644
--- a/test/rules/internal_rules/linear_algebra_rules.jl
+++ b/test/rules/internal_rules/linear_algebra_rules.jl
@@ -74,6
+74,50 @@ using Test
     )
 end
 
+# Thin wrappers so that each triangular flavour of `\` is reached through a plain function.
+tri_upper_vec(M, v) = UpperTriangular(M) \ v
+tri_upper_mat(M, R) = UpperTriangular(M) \ R
+tri_lower_vec(M, v) = LowerTriangular(M) \ v
+tri_lower_mat(M, R) = LowerTriangular(M) \ R
+tri_unit_upper_vec(M, v) = UnitUpperTriangular(M) \ v
+tri_unit_upper_mat(M, R) = UnitUpperTriangular(M) \ R
+tri_unit_lower_vec(M, v) = UnitLowerTriangular(M) \ v
+tri_unit_lower_mat(M, R) = UnitLowerTriangular(M) \ R
+tri_transpose_upper_vec(M, v) = transpose(UpperTriangular(M)) \ v
+
+@testset "Forward triangular \\" begin
+    A = [3.0 0.4; 0.6 2.7]
+    b = [0.7, -0.2]
+    B = [0.6 -0.1; -0.4 0.7]
+
+    # Each wrapper paired with the right-hand side it is checked against.
+    solvers = (
+        tri_upper_vec => b,
+        tri_upper_mat => B,
+        tri_lower_vec => b,
+        tri_lower_mat => B,
+        tri_unit_upper_vec => b,
+        tri_unit_upper_mat => B,
+        tri_unit_lower_vec => b,
+        tri_unit_lower_mat => B,
+        tri_transpose_upper_vec => b,
+    )
+    for (solve, rhs) in solvers
+        test_forward(solve, Duplicated, (copy(A), Duplicated), (copy(rhs), Duplicated))
+    end
+
+    # Remaining return / argument activity combinations, in their original order.
+    for (solve, ret, act_A, rhs, act_rhs) in (
+            (tri_upper_vec, DuplicatedNoNeed, Duplicated, b, Duplicated),
+            (tri_upper_vec, BatchDuplicated, BatchDuplicated, b, BatchDuplicated),
+            (tri_unit_lower_mat, BatchDuplicatedNoNeed, BatchDuplicated, B, BatchDuplicated),
+            (tri_upper_vec, Duplicated, Const, b, Duplicated),
+            (tri_upper_vec, Duplicated, Duplicated, b, Const),
+        )
+        test_forward(solve, ret, (copy(A), act_A), (copy(rhs), act_rhs))
+    end
+end
+
 function tr_solv(A, B, uplo, trans, diag, idx)
     B = copy(B)
     LAPACK.trtrs!(uplo, trans, diag, A, B)