I was implementing differentiation rules of the LU decomposition in ChainRules.jl and needed rules that supported non-square matrices, so I worked them out.
For a matrix $A \in \mathbb{C}^{m \times n}$ that is square ($m = n$), the LU decomposition is useful for solving systems of equations, inverting $A$, and computing the determinant.
This post assumes familiarity with the terminology of forward- and reverse-mode differentiation rules. For a succinct review, I recommend reading the ChainRules guide on deriving array rules.
The LU decomposition of $A$ is
$$PA = LU,$$
where, with $q = \min(m, n)$, $L \in \mathbb{C}^{m \times q}$ is a unit lower triangular matrix (that is, a lower triangular matrix whose diagonal entries are all ones), $U \in \mathbb{C}^{q \times n}$ is an upper triangular matrix, and $P \in \mathbb{R}^{m \times m}$ is a permutation matrix whose action on a matrix $X$ (i.e. $PX$) reorders the rows of $X$. As a permutation matrix, $P P^\mathrm{T} = P^\mathrm{T} P = I$.
If we have a system of the form $AX = B$ and would like to solve for $X$, we can use the LU decomposition to write $LUX = PB$. Then $X = U^{-1} (L^{-1} (P B))$. Note that the row-swapping action $PB$ can be computed in-place. We can also easily compute the left-division by the triangular matrices in-place, using forward substitution for $L$ and back substitution for $U$.
Setting $B$ in the above equation to the identity matrix gives the inverse of $A$, that is, $X = A^{-1}$.
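To make the solve concrete, here is a minimal NumPy sketch (not the LAPACK-backed routine a real implementation would call). The function names `lu_decompose`, `forward_substitution`, `back_substitution`, and `lu_solve` are made up for this example, and the permutation is assumed to be stored as an index vector `p` so that `A[p]` plays the role of $PA$.

```python
import numpy as np

def lu_decompose(A):
    """Partial-pivoting LU. Returns (p, L, U) with A[p] == L @ U, i.e. P A = L U."""
    A = np.array(A, dtype=complex)  # work on a copy
    m, n = A.shape
    q = min(m, n)
    p = np.arange(m)  # row permutation stored as an index vector
    for k in range(q):
        i = k + int(np.argmax(np.abs(A[k:, k])))  # row of the largest pivot candidate
        if i != k:
            A[[k, i]] = A[[i, k]]
            p[[k, i]] = p[[i, k]]
        A[k + 1:, k] /= A[k, k]  # multipliers, stored below the diagonal
        A[k + 1:, k + 1:] -= np.outer(A[k + 1:, k], A[k, k + 1:])
    L = np.tril(A[:, :q], -1) + np.eye(m, q)
    U = np.triu(A[:q, :])
    return p, L, U

def forward_substitution(L, B):
    """Solve L Y = B for unit lower triangular L."""
    Y = np.array(B, dtype=complex)
    for i in range(L.shape[0]):
        Y[i] -= L[i, :i] @ Y[:i]
    return Y

def back_substitution(U, Y):
    """Solve U X = Y for upper triangular U."""
    X = np.array(Y, dtype=complex)
    for i in reversed(range(U.shape[0])):
        X[i] = (X[i] - U[i, i + 1:] @ X[i + 1:]) / U[i, i]
    return X

def lu_solve(A, B):
    """Solve A X = B for square A as X = U^{-1} (L^{-1} (P B))."""
    p, L, U = lu_decompose(A)
    return back_substitution(U, forward_substitution(L, B[p]))

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
B = rng.standard_normal((4, 2))
X = lu_solve(A, B)               # solution of A X = B
A_inv = lu_solve(A, np.eye(4))   # inverse, by choosing B = I
```

The same `lu_decompose` also accepts wide and tall inputs, in which case `L` is $m \times q$ and `U` is $q \times n$.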
The determinant of $A$ is $\det(A) = \det(P) \det(L) \det(U)$. The determinant of a triangular matrix is the product of its diagonal entries, so $\det(L) = 1$, and $\det(U) = \prod_{i=1}^n U_{ii}$. $\det(P) = (-1)^s$, where $s$ is the number of row swaps encoded by the permutation matrix. So $\det(A) = (-1)^s \prod_{i=1}^n U_{ii}$, which is very cheap to compute from the decomposition.
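As a small worked example (the matrix is chosen only for illustration), take a real $2 \times 2$ matrix for which partial pivoting swaps the two rows:

$$A = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}, \quad P = \begin{bmatrix} 0 & 1 \\ 1 & 0 \end{bmatrix}, \quad L = \begin{bmatrix} 1 & 0 \\ \tfrac{1}{3} & 1 \end{bmatrix}, \quad U = \begin{bmatrix} 3 & 4 \\ 0 & \tfrac{2}{3} \end{bmatrix}, \quad PA = \begin{bmatrix} 3 & 4 \\ 1 & 2 \end{bmatrix} = LU.$$

There is one row swap, so $s = 1$ and $\det(A) = (-1)^1 \cdot 3 \cdot \tfrac{2}{3} = -2$, which agrees with the direct computation $1 \cdot 4 - 2 \cdot 3 = -2$.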
The LU decomposition can still be computed when $A$ is wide ($m < n$) or tall ($m > n$). However, none of the above applications make sense in this case, and I don't know what it's useful for.
Often the LU decomposition is computed using LAPACK subroutines, which cannot be automatically differentiated through. Hence, it is necessary to implement custom automatic differentiation rules. de Hoog et al. [1] derived a pushforward (AKA forward-mode differentiation rule or Jacobian-vector product) for the LU decomposition for the square case, but I couldn't find a rule for the wide or tall cases. This is a problem, because I wanted to implement a generic rule for ChainRules.jl that would work for all dense matrices in Julia, where square and non-square dense matrices are all implemented using the same type Matrix (more specifically, as the union of types StridedMatrix). It is not ideal to write a custom rule that will work for only a subset of matrices of a given type.
We could always pad $A$ with zeros to make it square. However, if $m \ll n$ or $m \gg n$, then this is wasteful. JAX seems to use this approach, though it's possible that internally it does something fancier and doesn't explicitly allocate or operate on the padding.
Thankfully, it's not too hard to work out the pushforwards and pullbacks for the non-square case by splitting the matrices into blocks and working out the rules for the individual blocks. So in this post, we'll review the pushforward for square $A$, also working out its pullback, and then we'll do the same for wide and tall $A$.
A pushforward for the LU decomposition for square $A$ is already known [1]. For completeness, I've included its derivation, as well as one for the corresponding pullback.
We start by differentiating the defining equation:
$$P \dot{A} = \dot{L} U + L \dot{U}.$$
Left-multiplying both sides by $L^{-1}$ and right-multiplying by $U^{-1}$, we get
$$L^{-1} P \dot{A} U^{-1} = L^{-1} \dot{L} + \dot{U} U^{-1}.$$
$\dot{L}$ and $\dot{U}$ must be at least as sparse as $L$ and $U$, respectively. Hence, $\dot{U}$ is upper triangular, and because the diagonal of $L$ is constrained to be unit, $\dot{L}$ will be lower triangular with a diagonal of zeros (strictly lower triangular). Note also that the inverse of a lower/upper triangular matrix is still lower/upper triangular. Likewise, the product of two lower/upper triangular matrices is still lower/upper triangular.
Hence the right-hand side is the sum of a strictly lower triangular matrix $L^{-1} \dot{L}$ and an upper triangular matrix $\dot{U} U^{-1}$. Let's introduce the triangularizing operators $\operatorname{triu}(X)$, which extracts the upper triangle of the matrix $X$, and $\operatorname{tril}_-(X)$, which extracts its strict lower triangle (so that $X = \operatorname{tril}_-(X) + \operatorname{triu}(X)$).
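Introducing an intermediate $\dot{F}$ for the left-hand side and matching triangles, so that $L^{-1} \dot{L} = \operatorname{tril}_-(\dot{F})$ and $\dot{U} U^{-1} = \operatorname{triu}(\dot{F})$, we can then solve for $\dot{L}$ and $\dot{U}$:
$$\begin{aligned}
\dot{F} &= L^{-1} P \dot{A} U^{-1} \\
\dot{L} &= L \operatorname{tril}_-(\dot{F}) \\
\dot{U} &= \operatorname{triu}(\dot{F}) U
\end{aligned}$$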
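The square pushforward can be sketched in NumPy as follows. This is an illustration rather than the ChainRules.jl implementation: `lu_pushforward_square` is a hypothetical name, the permutation is assumed to be an index vector `p` with `A[p]` standing for $PA$, and the general solver `np.linalg.solve` stands in for the in-place triangular solves a real implementation would use. The factors are built directly as random triangular matrices, so no factorization routine is needed to exercise the formulas.

```python
import numpy as np

def lu_pushforward_square(L, U, p, A_dot):
    """Tangents (L_dot, U_dot) of the LU factors of square A, where A[p] == L @ U."""
    F_dot = np.linalg.solve(L, A_dot[p])         # L^{-1} P A_dot
    F_dot = np.linalg.solve(U.T, F_dot.T).T      # (L^{-1} P A_dot) U^{-1}
    L_dot = L @ np.tril(F_dot, -1)               # L tril_-(F_dot)
    U_dot = np.triu(F_dot) @ U                   # triu(F_dot) U
    return L_dot, U_dot

rng = np.random.default_rng(0)

def randc(*shape):
    """Random complex matrix (helper for this example only)."""
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

n = 4
L = np.tril(randc(n, n), -1) + np.eye(n)         # unit lower triangular
U = np.triu(randc(n, n)) + 3 * np.eye(n)         # upper triangular, diagonal shifted away from 0
p = rng.permutation(n)                           # row permutation as an index vector
A_dot = randc(n, n)

L_dot, U_dot = lu_pushforward_square(L, U, p, A_dot)

# Residual of the differentiated defining equation P A_dot = L_dot U + L U_dot
residual = np.abs(A_dot[p] - (L_dot @ U + L @ U_dot)).max()
```

The tangents are characterized by two properties: they satisfy the differentiated defining equation, and they have the sparsity patterns discussed above (`L_dot` strictly lower triangular, `U_dot` upper triangular).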
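The corresponding pullback is
$$\begin{aligned}
\overline{F} &= \operatorname{tril}_-(L^\mathrm{H} \overline{L}) + \operatorname{triu}(\overline{U} U^\mathrm{H}) \\
\overline{A} &= P^\mathrm{T} L^{-\mathrm{H}} \overline{F} U^{-\mathrm{H}}
\end{aligned}$$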
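To find the pullback, we use the identity of reverse-mode differentiation and properties of the Frobenius inner product as described in the ChainRules guide on deriving array rules.
Here the identity takes the form
$$\operatorname{Re}\langle \overline{L}, \dot{L} \rangle + \operatorname{Re}\langle \overline{U}, \dot{U} \rangle = \operatorname{Re}\langle \overline{A}, \dot{A} \rangle.$$
We want to solve for $\overline{A}$, and we do so by first plugging $\dot{U}$ and $\dot{L}$ into the left-hand side of this identity, manipulating it to look like the right-hand side, and then solving for $\overline{A}$.
$$\begin{aligned}
\langle \overline{L}, \dot{L} \rangle &= \langle \overline{L}, L \operatorname{tril}_-(\dot{F}) \rangle = \langle L^\mathrm{H} \overline{L}, \operatorname{tril}_-(\dot{F}) \rangle \\
\langle \overline{U}, \dot{U} \rangle &= \langle \overline{U}, \operatorname{triu}(\dot{F}) U \rangle = \langle \overline{U} U^\mathrm{H}, \operatorname{triu}(\dot{F}) \rangle
\end{aligned}$$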
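Because the Frobenius inner product is the sum of all elements of the element-wise product of the second argument and the complex conjugate of the first argument, for upper triangular $U$ we have
$$\langle X, U \rangle = \sum_{ij} X_{ij}^* U_{ij} = \sum_{ij} \operatorname{triu}(X)_{ij}^* U_{ij} = \langle \operatorname{triu}(X), U \rangle.$$
The same is true for lower-triangular matrices (or, analogously, for any sparsity pattern). Therefore,
$$\begin{aligned}
\langle \overline{L}, \dot{L} \rangle &= \langle \operatorname{tril}_-(L^\mathrm{H} \overline{L}), \dot{F} \rangle \\
\langle \overline{U}, \dot{U} \rangle &= \langle \operatorname{triu}(\overline{U} U^\mathrm{H}), \dot{F} \rangle \\
\langle \overline{L}, \dot{L} \rangle + \langle \overline{U}, \dot{U} \rangle &= \langle \operatorname{tril}_-(L^\mathrm{H} \overline{L}) + \operatorname{triu}(\overline{U} U^\mathrm{H}), \dot{F} \rangle \doteq \langle \overline{F}, \dot{F} \rangle,
\end{aligned}$$
where we have introduced an intermediate $\overline{F}$.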
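The inner-product manipulations used in this step are easy to check numerically. The snippet below is only a sanity check under the conventions of this post; `frobenius` and `randc` are names invented for the example.

```python
import numpy as np

def frobenius(X, Y):
    """Frobenius inner product <X, Y> = sum_ij conj(X_ij) * Y_ij."""
    return np.vdot(X, Y)  # vdot flattens its inputs and conjugates the first one

rng = np.random.default_rng(0)

def randc(*shape):
    """Random complex matrix (helper for this example only)."""
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

n = 4
X = randc(n, n)                   # dense first argument
U = np.triu(randc(n, n))          # upper triangular second argument
S = np.tril(randc(n, n), -1)      # strictly lower triangular second argument

# The sparsity pattern of the second argument can be imposed on the first
upper_ok = np.isclose(frobenius(X, U), frobenius(np.triu(X), U))
lower_ok = np.isclose(frobenius(X, S), frobenius(np.tril(X, -1), S))

# A left factor moves across the inner product as its conjugate transpose
adjoint_ok = np.isclose(frobenius(X, S @ U), frobenius(S.conj().T @ X, U))
```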
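Continuing by plugging in $\dot{F}$, we find
$$\langle \overline{F}, \dot{F} \rangle = \langle \overline{F}, L^{-1} P \dot{A} U^{-1} \rangle = \langle P^\mathrm{T} L^{-\mathrm{H}} \overline{F} U^{-\mathrm{H}}, \dot{A} \rangle.$$
So the pullback of the LU decomposition is written
$$\begin{aligned}
\overline{F} &= \operatorname{tril}_-(L^\mathrm{H} \overline{L}) + \operatorname{triu}(\overline{U} U^\mathrm{H}) \\
\overline{A} &= P^\mathrm{T} L^{-\mathrm{H}} \overline{F} U^{-\mathrm{H}}
\end{aligned}$$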
Note that these expressions use the same elementary operations as solving a system of equations using the LU decomposition, as noted above. The pushforwards and pullbacks can then be computed in-place with no additional allocations.
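Here is a NumPy sketch of the square pullback, again as an illustration and not the ChainRules.jl code. `lu_pullback_square` is a hypothetical name, `p` is an index vector with `A[p]` standing for $PA$, and `np.linalg.solve` replaces the triangular solves. To test it without relying on the pushforward formula, the example picks an arbitrary strictly lower triangular tangent for $L$ and upper triangular tangent for $U$, builds the matching tangent of $A$ from the differentiated defining equation, and compares both sides of the reverse-mode identity.

```python
import numpy as np

def lu_pullback_square(L, U, p, L_bar, U_bar):
    """Cotangent A_bar of square A, where A[p] == L @ U."""
    LH, UH = L.conj().T, U.conj().T
    F_bar = np.tril(LH @ L_bar, -1) + np.triu(U_bar @ UH)
    X = np.linalg.solve(LH, F_bar)                 # L^{-H} F_bar
    X = np.linalg.solve(U, X.conj().T).conj().T    # (L^{-H} F_bar) U^{-H}
    A_bar = np.empty_like(X)
    A_bar[p] = X                                   # apply P^T by undoing the row permutation
    return A_bar

rng = np.random.default_rng(0)

def randc(*shape):
    """Random complex matrix (helper for this example only)."""
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

n = 4
L = np.tril(randc(n, n), -1) + np.eye(n)
U = np.triu(randc(n, n)) + 3 * np.eye(n)
p = rng.permutation(n)
L_bar, U_bar = randc(n, n), randc(n, n)            # arbitrary cotangents

# Tangents with the required sparsity, and the A_dot they correspond to
L_dot = np.tril(randc(n, n), -1)
U_dot = np.triu(randc(n, n))
A_dot = np.empty((n, n), dtype=complex)
A_dot[p] = L_dot @ U + L @ U_dot                   # P A_dot = L_dot U + L U_dot

A_bar = lu_pullback_square(L, U, p, L_bar, U_bar)
lhs = np.real(np.vdot(L_bar, L_dot) + np.vdot(U_bar, U_dot))
rhs = np.real(np.vdot(A_bar, A_dot))
```

If the pullback is right, `lhs` and `rhs` agree up to rounding for any choice of tangents and cotangents.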
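We can write wide $A$ in blocks $A = \begin{bmatrix} A_1 & A_2 \end{bmatrix}$, where $A_1 \in \mathbb{C}^{m \times m}$ and $A_2 \in \mathbb{C}^{m \times (n - m)}$. It will turn out to be very convenient that $A_1$ is square. The LU decomposition in terms of these blocks is written
$$P \begin{bmatrix} A_1 & A_2 \end{bmatrix} = L \begin{bmatrix} U_1 & U_2 \end{bmatrix},$$
where $U_1$ is square upper triangular. This is a system of two equations that we will address separately.
$$\begin{aligned}
P A_1 &= L U_1 \\
P A_2 &= L U_2
\end{aligned}$$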
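Introducing an intermediate $\dot{H} = \begin{bmatrix} \dot{H}_1 & \dot{H}_2 \end{bmatrix}$ with the same block structure as $U$, the complete pushforward is
$$\begin{aligned}
\dot{H} &= L^{-1} P \dot{A} \\
\dot{F} &= \dot{H}_1 U_1^{-1} \\
\dot{U}_1 &= \operatorname{triu}(\dot{F}) U_1 \\
\dot{U}_2 &= \dot{H}_2 - \operatorname{tril}_-(\dot{F}) U_2 \\
\dot{L} &= L \operatorname{tril}_-(\dot{F})
\end{aligned}$$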
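Note that the first equation is identical in form to the square LU decomposition. So we can reuse that solution for the pushforward to get
$$\begin{aligned}
\dot{F} &= L^{-1} P \dot{A}_1 U_1^{-1} \\
\dot{L} &= L \operatorname{tril}_-(\dot{F}) \\
\dot{U}_1 &= \operatorname{triu}(\dot{F}) U_1
\end{aligned}$$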
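Now, let's differentiate the second equation
$$P \dot{A}_2 = \dot{L} U_2 + L \dot{U}_2$$
and solve for $\dot{U}_2$:
$$\dot{U}_2 = L^{-1} P \dot{A}_2 - L^{-1} \dot{L} U_2.$$
Plugging in our previous solution for $\dot{L}$, we find
$$\dot{U}_2 = L^{-1} P \dot{A}_2 - \operatorname{tril}_-(\dot{F}) U_2,$$
where $L^{-1} P \dot{A}_2$ is exactly the block $\dot{H}_2$ of $\dot{H} = L^{-1} P \dot{A}$.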
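The wide pushforward can be sketched in NumPy in the same way. As before this is only an illustration: `lu_pushforward_wide` is a hypothetical name, `p` is an index vector with `A[p]` standing for $PA$, and `np.linalg.solve` stands in for triangular solves.

```python
import numpy as np

def lu_pushforward_wide(L, U, p, A_dot):
    """Tangents (L_dot, U_dot) for wide A (m < n), where A[p] == L @ U."""
    m = L.shape[0]
    U1, U2 = U[:, :m], U[:, m:]
    H_dot = np.linalg.solve(L, A_dot[p])                  # L^{-1} P A_dot
    H1_dot, H2_dot = H_dot[:, :m], H_dot[:, m:]
    F_dot = np.linalg.solve(U1.T, H1_dot.T).T             # H1_dot U1^{-1}
    L_dot = L @ np.tril(F_dot, -1)
    U1_dot = np.triu(F_dot) @ U1
    U2_dot = H2_dot - np.tril(F_dot, -1) @ U2
    return L_dot, np.hstack([U1_dot, U2_dot])

rng = np.random.default_rng(0)

def randc(*shape):
    """Random complex matrix (helper for this example only)."""
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

m, n = 3, 5
L = np.tril(randc(m, m), -1) + np.eye(m)          # m x m unit lower triangular
U = np.triu(randc(m, n)) + 3 * np.eye(m, n)       # m x n upper trapezoidal
p = rng.permutation(m)
A_dot = randc(m, n)

L_dot, U_dot = lu_pushforward_wide(L, U, p, A_dot)

# Residual of P A_dot = L_dot U + L U_dot over both blocks at once
residual = np.abs(A_dot[p] - (L_dot @ U + L @ U_dot)).max()
```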
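Introducing an intermediate $\overline{H} = \begin{bmatrix} \overline{H}_1 & \overline{H}_2 \end{bmatrix}$ with the same block structure as $U$, the corresponding pullback is
$$\begin{aligned}
\overline{H}_1 &= \left(\operatorname{tril}_-(L^\mathrm{H} \overline{L} - \overline{U}_2 U_2^\mathrm{H}) + \operatorname{triu}(\overline{U}_1 U_1^\mathrm{H})\right) U_1^{-\mathrm{H}} \\
\overline{H}_2 &= \overline{U}_2 \\
\overline{A} &= P^\mathrm{T} L^{-\mathrm{H}} \overline{H}
\end{aligned}$$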
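Here the reverse-mode identity is
$$\operatorname{Re}\langle \overline{L}, \dot{L} \rangle + \operatorname{Re}\langle \overline{U}_1, \dot{U}_1 \rangle + \operatorname{Re}\langle \overline{U}_2, \dot{U}_2 \rangle = \operatorname{Re}\langle \overline{A}, \dot{A} \rangle.$$
We plug in $\dot{L}$, $\dot{U}_1$, and $\dot{U}_2$ to find
$$\begin{aligned}
&\operatorname{Re}\langle \overline{L}, L \operatorname{tril}_-(\dot{F}) \rangle + \operatorname{Re}\langle \overline{U}_1, \operatorname{triu}(\dot{F}) U_1 \rangle + \operatorname{Re}\langle \overline{U}_2, \dot{H}_2 - \operatorname{tril}_-(\dot{F}) U_2 \rangle \\
&\quad = \operatorname{Re}\langle \operatorname{tril}_-(L^\mathrm{H} \overline{L}) - \operatorname{tril}_-(\overline{U}_2 U_2^\mathrm{H}) + \operatorname{triu}(\overline{U}_1 U_1^\mathrm{H}), \dot{F} \rangle + \operatorname{Re}\langle \overline{U}_2, \dot{H}_2 \rangle \\
&\quad = \operatorname{Re}\langle \left(\operatorname{tril}_-(L^\mathrm{H} \overline{L}) - \operatorname{tril}_-(\overline{U}_2 U_2^\mathrm{H}) + \operatorname{triu}(\overline{U}_1 U_1^\mathrm{H})\right) U_1^{-\mathrm{H}}, \dot{H}_1 \rangle + \operatorname{Re}\langle \overline{U}_2, \dot{H}_2 \rangle
\end{aligned}$$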
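Let's introduce the intermediates
$$\begin{aligned}
\overline{H}_1 &= \left(\operatorname{tril}_-(L^\mathrm{H} \overline{L} - \overline{U}_2 U_2^\mathrm{H}) + \operatorname{triu}(\overline{U}_1 U_1^\mathrm{H})\right) U_1^{-\mathrm{H}} \\
\overline{H}_2 &= \overline{U}_2,
\end{aligned}$$
which, like $\dot{H}$, we organize into the block matrix $\overline{H} = \begin{bmatrix} \overline{H}_1 & \overline{H}_2 \end{bmatrix}$. This block structure lets us rewrite the above identity in terms of $\overline{H}$ and $\dot{H}$:
$$\operatorname{Re}\langle \overline{H}_1, \dot{H}_1 \rangle + \operatorname{Re}\langle \overline{H}_2, \dot{H}_2 \rangle = \operatorname{Re}\langle \overline{H}, \dot{H} \rangle.$$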
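Now we plug in $\dot{H}$:
$$\operatorname{Re}\langle \overline{H}, \dot{H} \rangle = \operatorname{Re}\langle \overline{H}, L^{-1} P \dot{A} \rangle = \operatorname{Re}\langle P^\mathrm{T} L^{-\mathrm{H}} \overline{H}, \dot{A} \rangle.$$
We have arrived at the desired form and can solve for $\overline{A}$:
$$\overline{A} = P^\mathrm{T} L^{-\mathrm{H}} \overline{H}.$$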
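A NumPy sketch of the wide pullback follows, with the same caveats as the other examples: `lu_pullback_wide` is a hypothetical name, `p` is an index vector with `A[p]` standing for $PA$, and `np.linalg.solve` replaces triangular solves. The check builds a tangent of $A$ from arbitrary tangents of the factors via the differentiated defining equation and compares both sides of the reverse-mode identity.

```python
import numpy as np

def lu_pullback_wide(L, U, p, L_bar, U_bar):
    """Cotangent A_bar for wide A (m < n), where A[p] == L @ U."""
    m = L.shape[0]
    U1, U2 = U[:, :m], U[:, m:]
    U1_bar, U2_bar = U_bar[:, :m], U_bar[:, m:]
    LH = L.conj().T
    G = (np.tril(LH @ L_bar - U2_bar @ U2.conj().T, -1)
         + np.triu(U1_bar @ U1.conj().T))
    H1_bar = np.linalg.solve(U1, G.conj().T).conj().T     # G U1^{-H}
    H_bar = np.hstack([H1_bar, U2_bar])                   # H2_bar = U2_bar
    X = np.linalg.solve(LH, H_bar)                        # L^{-H} H_bar
    A_bar = np.empty_like(X)
    A_bar[p] = X                                          # apply P^T
    return A_bar

rng = np.random.default_rng(0)

def randc(*shape):
    """Random complex matrix (helper for this example only)."""
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

m, n = 3, 5
L = np.tril(randc(m, m), -1) + np.eye(m)
U = np.triu(randc(m, n)) + 3 * np.eye(m, n)
p = rng.permutation(m)
L_bar, U_bar = randc(m, m), randc(m, n)           # arbitrary cotangents

# Tangents with the required sparsity, and the A_dot they correspond to
L_dot = np.tril(randc(m, m), -1)
U_dot = np.triu(randc(m, n))
A_dot = np.empty((m, n), dtype=complex)
A_dot[p] = L_dot @ U + L @ U_dot                  # P A_dot = L_dot U + L U_dot

A_bar = lu_pullback_wide(L, U, p, L_bar, U_bar)
lhs = np.real(np.vdot(L_bar, L_dot) + np.vdot(U_bar, U_dot))
rhs = np.real(np.vdot(A_bar, A_dot))
```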
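The tall case is very similar to the wide case, except now we have the block structure
$$\begin{bmatrix} P_1 \\ P_2 \end{bmatrix} A = \begin{bmatrix} L_1 \\ L_2 \end{bmatrix} U,$$
where now $L_1$ is square unit lower triangular, and $U$ is square upper triangular. This gives us the system of equations
$$\begin{aligned}
P_1 A &= L_1 U \\
P_2 A &= L_2 U.
\end{aligned}$$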
The first equation is again identical to the square case, so we can use it to solve for $\dot{L}_1$ and $\dot{U}$. Likewise, the same approach we used to solve for $\dot{U}_2$ in the wide case can be applied here to solve for $\dot{L}_2$.
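Introducing an intermediate $\dot{H} = \begin{bmatrix} \dot{H}_1 \\ \dot{H}_2 \end{bmatrix}$ with the same block structure as $L$, the complete pushforward is
$$\begin{aligned}
\dot{H} &= P \dot{A} U^{-1} \\
\dot{F} &= L_1^{-1} \dot{H}_1 \\
\dot{L}_1 &= L_1 \operatorname{tril}_-(\dot{F}) \\
\dot{L}_2 &= \dot{H}_2 - L_2 \operatorname{triu}(\dot{F}) \\
\dot{U} &= \operatorname{triu}(\dot{F}) U
\end{aligned}$$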
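The tall pushforward mirrors the wide one with the roles of the factors exchanged. The NumPy sketch below is an illustration under the same assumptions as before: `lu_pushforward_tall` is a hypothetical name, `p` is an index vector with `A[p]` standing for $PA$, and `np.linalg.solve` replaces triangular solves.

```python
import numpy as np

def lu_pushforward_tall(L, U, p, A_dot):
    """Tangents (L_dot, U_dot) for tall A (m > n), where A[p] == L @ U."""
    n = U.shape[0]
    L1, L2 = L[:n], L[n:]
    H_dot = np.linalg.solve(U.T, A_dot[p].T).T            # P A_dot U^{-1}
    H1_dot, H2_dot = H_dot[:n], H_dot[n:]
    F_dot = np.linalg.solve(L1, H1_dot)                   # L1^{-1} H1_dot
    L1_dot = L1 @ np.tril(F_dot, -1)
    L2_dot = H2_dot - L2 @ np.triu(F_dot)
    U_dot = np.triu(F_dot) @ U
    return np.vstack([L1_dot, L2_dot]), U_dot

rng = np.random.default_rng(0)

def randc(*shape):
    """Random complex matrix (helper for this example only)."""
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

m, n = 5, 3
L = np.tril(randc(m, n), -1) + np.eye(m, n)       # m x n, unit diagonal in the top block
U = np.triu(randc(n, n)) + 3 * np.eye(n)          # n x n upper triangular
p = rng.permutation(m)
A_dot = randc(m, n)

L_dot, U_dot = lu_pushforward_tall(L, U, p, A_dot)

# Residual of P A_dot = L_dot U + L U_dot over both row blocks at once
residual = np.abs(A_dot[p] - (L_dot @ U + L @ U_dot)).max()
```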
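Introducing an intermediate $\overline{H} = \begin{bmatrix} \overline{H}_1 \\ \overline{H}_2 \end{bmatrix}$ with the same block structure as $L$, the corresponding pullback is
$$\begin{aligned}
\overline{H}_1 &= L_1^{-\mathrm{H}} \left(\operatorname{tril}_-(L_1^\mathrm{H} \overline{L}_1) + \operatorname{triu}(\overline{U} U^\mathrm{H} - L_2^\mathrm{H} \overline{L}_2)\right) \\
\overline{H}_2 &= \overline{L}_2 \\
\overline{A} &= P^\mathrm{T} \overline{H} U^{-\mathrm{H}}
\end{aligned}$$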
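Finally, a NumPy sketch of the tall pullback, checked against the reverse-mode identity in the same way as the other cases. `lu_pullback_tall` is a hypothetical name, `p` is an index vector with `A[p]` standing for $PA$, and `np.linalg.solve` replaces triangular solves.

```python
import numpy as np

def lu_pullback_tall(L, U, p, L_bar, U_bar):
    """Cotangent A_bar for tall A (m > n), where A[p] == L @ U."""
    n = U.shape[0]
    L1, L2 = L[:n], L[n:]
    L1_bar, L2_bar = L_bar[:n], L_bar[n:]
    L1H = L1.conj().T
    G = (np.tril(L1H @ L1_bar, -1)
         + np.triu(U_bar @ U.conj().T - L2.conj().T @ L2_bar))
    H1_bar = np.linalg.solve(L1H, G)                      # L1^{-H} G
    H_bar = np.vstack([H1_bar, L2_bar])                   # H2_bar = L2_bar
    X = np.linalg.solve(U, H_bar.conj().T).conj().T       # H_bar U^{-H}
    A_bar = np.empty_like(X)
    A_bar[p] = X                                          # apply P^T
    return A_bar

rng = np.random.default_rng(0)

def randc(*shape):
    """Random complex matrix (helper for this example only)."""
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

m, n = 5, 3
L = np.tril(randc(m, n), -1) + np.eye(m, n)
U = np.triu(randc(n, n)) + 3 * np.eye(n)
p = rng.permutation(m)
L_bar, U_bar = randc(m, n), randc(n, n)           # arbitrary cotangents

# Tangents with the required sparsity, and the A_dot they correspond to
L_dot = np.tril(randc(m, n), -1)
U_dot = np.triu(randc(n, n))
A_dot = np.empty((m, n), dtype=complex)
A_dot[p] = L_dot @ U + L @ U_dot                  # P A_dot = L_dot U + L U_dot

A_bar = lu_pullback_tall(L, U, p, L_bar, U_bar)
lhs = np.real(np.vdot(L_bar, L_dot) + np.vdot(U_bar, U_dot))
rhs = np.real(np.vdot(A_bar, A_dot))
```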
The product of this derivation is this pull request to ChainRules.jl, which includes tests of the rules using FiniteDifferences.jl.
The techniques employed here are general and can be used for differentiation rules for other factorizations of non-square matrices. A recent paper used a similar approach to derive the pushforwards and pullbacks of the QR and LQ decompositions [2].
[1] de Hoog F.R., Anderssen R.S., and Lukas M.A. (2011) Differentiation of matrix functionals using triangular factorization. Mathematics of Computation, 80 (275). p. 1585. doi: 10.1090/S0025-5718-2011-02451-8.
[2] Roberts D.A.O. and Roberts L.R. (2020) QR and LQ Decomposition Matrix Backpropagation Algorithms for Square, Wide, and Deep – Real or Complex – Matrices and Their Software Implementation. arXiv: 2009.10071.