Differentiating the LU decomposition

February 03, 2021

I was implementing differentiation rules of the LU decomposition in ChainRules.jl and needed rules that supported non-square matrices, so I worked them out.

Introduction

Consider a matrix $A \in \mathbb{C}^{m \times n}$. When $A$ is square ($m = n$), its LU decomposition is useful for solving systems of equations, inverting $A$, and computing the determinant.

This post assumes familiarity with terminology of forward- and reverse-mode differentiation rules. For a succinct review, I recommend reading the ChainRules guide on deriving array rules.

Definition

The LU decomposition of $A$ is

$$P A = L U,$$

where $q = \min(m, n)$; $L \in \mathbb{C}^{m \times q}$ is a unit lower triangular matrix, that is, a lower triangular matrix whose diagonal entries are all ones; $U \in \mathbb{C}^{q \times n}$ is an upper triangular matrix; and $P \in \mathbb{R}^{m \times m}$ is a permutation matrix whose action on a matrix $X$ (i.e. $P X$) reorders the rows of $X$. As a permutation matrix, $P$ satisfies $P P^\mathrm{T} = P^\mathrm{T} P = I$.
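The following minimal Python sketch uses NumPy and SciPy to make these conventions concrete. It is not the Julia code from ChainRules.jl. It assumes `scipy.linalg.lu`, which returns factors satisfying $A = P_s L U$, so the permutation in this post's convention is $P = P_s^\mathrm{T}$. The helper `lu_factors` is a hypothetical name introduced only for illustration.

```python
import numpy as np
from scipy.linalg import lu


def lu_factors(A):
    """Return (P, L, U) with P @ A == L @ U (this post's convention)."""
    Ps, L, U = lu(A)  # SciPy's convention: A == Ps @ L @ U
    return Ps.T, L, U


rng = np.random.default_rng(0)
for m, n in [(4, 4), (3, 5), (5, 3)]:  # square, wide, tall
    A = rng.standard_normal((m, n))
    P, L, U = lu_factors(A)
    q = min(m, n)
    assert P.shape == (m, m) and L.shape == (m, q) and U.shape == (q, n)
    assert np.allclose(L, np.tril(L)) and np.allclose(np.diag(L), 1.0)  # unit lower
    assert np.allclose(U, np.triu(U))                                   # upper
    assert np.allclose(P @ P.T, np.eye(m))                              # permutation
    assert np.allclose(P @ A, L @ U)
```

The loop runs over a square, a wide, and a tall matrix. For each one it checks the shapes $L \in \mathbb{C}^{m \times q}$ and $U \in \mathbb{C}^{q \times n}$ stated above.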

Uses

If we have a system of the form $A X = B$ and would like to solve for $X$, we can use the LU decomposition to write $L U X = P B$. Then $X = U^{-1} (L^{-1} (P B))$. Note that the row-swapping action $P B$ can be computed in-place. We can also easily compute the left-division by the triangular matrices in-place, using forward substitution for $L$ and back substitution for $U$.

By setting $B$ in the above equation to the identity matrix, we can compute the inverse of $A$, that is, $X = A^{-1}$.
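Here is a sketch of the solve and the inverse described above. It again assumes SciPy's `lu` convention ($A = P_s L U$, so $P = P_s^\mathrm{T}$), and `lu_solve_sketch` is a hypothetical name. For clarity it forms $P B$ with a matrix product rather than swapping rows in place. In practice, SciPy's own `lu_factor`/`lu_solve` pair does this work.

```python
import numpy as np
from scipy.linalg import lu, solve_triangular


def lu_solve_sketch(A, B):
    """Solve A X = B using P A = L U, i.e. X = U^{-1} (L^{-1} (P B))."""
    Ps, L, U = lu(A)  # SciPy's convention: A == Ps @ L @ U
    P = Ps.T          # this post's convention: P @ A == L @ U
    Y = solve_triangular(L, P @ B, lower=True, unit_diagonal=True)  # forward substitution
    return solve_triangular(U, Y, lower=False)                       # back substitution


rng = np.random.default_rng(1)
A = rng.standard_normal((4, 4))
B = rng.standard_normal((4, 2))
X = lu_solve_sketch(A, B)
assert np.allclose(A @ X, B)

# With B = I, the same routine returns the inverse of A.
A_inv = lu_solve_sketch(A, np.eye(4))
assert np.allclose(A_inv @ A, np.eye(4))
```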

The determinant of $A$ is $\det(A) = \det(P) \det(L) \det(U)$, since $A = P^\mathrm{T} L U$ and $\det(P^\mathrm{T}) = \det(P)$. The determinant of a triangular matrix is the product of its diagonal entries, so $\det(L) = 1$ and $\det(U) = \prod_{i=1}^n U_{ii}$. Also, $\det(P) = (-1)^s$, where $s$ is the number of row swaps encoded by the permutation matrix. So $\det(A) = (-1)^s \prod_{i=1}^n U_{ii}$, which is very cheap to compute from the decomposition.
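The determinant formula can be sketched as follows, assuming SciPy's `lu`. The helper names `permutation_sign` and `det_via_lu` are hypothetical. The sign is found by counting the swaps needed to sort the permutation, which gives $(-1)^s$.

```python
import numpy as np
from scipy.linalg import lu


def permutation_sign(perm):
    """Return (-1)^s, where s is the number of swaps needed to sort perm."""
    perm = [int(k) for k in perm]
    sign = 1
    for i in range(len(perm)):
        while perm[i] != i:  # move the entry at position i to its home
            j = perm[i]
            perm[i], perm[j] = perm[j], perm[i]
            sign = -sign
    return sign


def det_via_lu(A):
    """det(A) = det(P) * det(L) * det(U) = (-1)^s * prod(diag(U))."""
    Ps, L, U = lu(A)              # A == Ps @ L @ U, and det(Ps) == det(Ps.T)
    perm = np.argmax(Ps, axis=1)  # column holding the 1 in each row of Ps
    return permutation_sign(perm) * np.prod(np.diag(U))


rng = np.random.default_rng(2)
A = rng.standard_normal((5, 5))
assert np.isclose(det_via_lu(A), np.linalg.det(A))
```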

The LU decomposition can still be computed when $A$ is wide ($m < n$) or tall ($m > n$). However, none of the above applications make sense in this case, and I don't know what it's useful for.

Motivation

Often the LU decomposition is computed using LAPACK subroutines, which cannot be automatically differentiated through. Hence, it is necessary to implement custom automatic differentiation rules. [1] derived a pushforward (AKA forward-mode differentiation rule or Jacobian-vector-product) for the LU decomposition for the square case, but I couldn't find a rule for the wide or tall cases. This is a problem, because I wanted to implement a generic rule for ChainRules.jl that would work for all dense matrices in Julia, where square and non-square dense matrices are all implemented using the same type Matrix (more specifically, as the union of types StridedMatrix). It is not ideal to write a custom rule that will work for only a subset of matrices of a given type.

We could always pad $A$ with zeros to make it square. However, if $m \ll n$ or $m \gg n$, then this is wasteful. JAX seems to use this approach, though it's possible that internally it does something fancier and doesn't explicitly allocate or operate on the padding.

Thankfully, it's not too hard to work out the pushforwards and pullbacks for the non-square case by splitting the matrices into blocks and working out the rules for the individual blocks. So in this post, we'll review the pushforward for square $A$, also working out its pullback, and then we'll do the same for wide and tall $A$.

Square $A$

A pushforward for the LU decomposition for square $A$ is already known [1]. For completeness, I've included its derivation, as well as one for the corresponding pullback.

Pushforward

We start by differentiating the defining equation:

$$P \dot{A} = \dot{L} U + L \dot{U}$$

Left-multiplying both sides by $L^{-1}$ and right-multiplying by $U^{-1}$, we get

$$L^{-1} P \dot{A} U^{-1} = L^{-1} \dot{L} + \dot{U} U^{-1}$$

$\dot{L}$ and $\dot{U}$ must be at least as sparse as $L$ and $U$, respectively. Hence, $\dot{U}$ is upper triangular. Because the diagonal of $L$ is constrained to be unit, $\dot{L}$ will be lower triangular with a diagonal of zeros (strictly lower triangular). Note also that the inverse of a lower/upper triangular matrix is still lower/upper triangular. Likewise, the product of two lower/upper triangular matrices is still lower/upper triangular. Since $L^{-1}$ is unit lower triangular, $L^{-1} \dot{L}$ also keeps the zero diagonal of $\dot{L}$.

Hence the right-hand side is the sum of a strictly lower triangular matrix $L^{-1} \dot{L}$ and an upper triangular matrix $\dot{U} U^{-1}$. Let's introduce two triangularizing operators. The first, $\operatorname{triu}(X)$, extracts the upper triangle of the matrix $X$, including the diagonal. The second, $\operatorname{tril}_-(X)$, extracts its strict lower triangle, so that $X = \operatorname{tril}_-(X) + \operatorname{triu}(X)$. Applying them to both sides separates the two terms: $L^{-1} \dot{L} = \operatorname{tril}_-(L^{-1} P \dot{A} U^{-1})$ and $\dot{U} U^{-1} = \operatorname{triu}(L^{-1} P \dot{A} U^{-1})$.
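The two triangularizing operators map directly onto NumPy's `np.triu` and `np.tril` with an offset of `-1`. The names `triu` and `tril_minus` below are hypothetical wrappers. The final check illustrates numerically why $L^{-1} \dot{L}$ is strictly lower triangular when $L$ is unit lower triangular and $\dot{L}$ is strictly lower triangular.

```python
import numpy as np


def triu(X):
    """Upper triangle of X, diagonal included."""
    return np.triu(X)


def tril_minus(X):
    """Strict lower triangle of X, diagonal excluded."""
    return np.tril(X, k=-1)


rng = np.random.default_rng(3)
X = rng.standard_normal((4, 4))
assert np.allclose(tril_minus(X) + triu(X), X)  # X = tril_-(X) + triu(X)

# L unit lower triangular and Ldot strictly lower triangular
# imply that L^{-1} Ldot is strictly lower triangular too.
L = np.tril(rng.standard_normal((4, 4)), k=-1) + np.eye(4)
Ldot = tril_minus(rng.standard_normal((4, 4)))
M = np.linalg.solve(L, Ldot)  # L^{-1} Ldot
assert np.allclose(M, tril_minus(M))
```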

Introducing an intermediate $\dot{F}$, we can then solve for $\dot{L}$ and $\dot{U}$:

$$\begin{aligned} \dot{F} &= L^{-1} P \dot{A} U^{-1}\\ \dot{L} &= L \operatorname{tril}_-(\dot{F})\\ \dot{U} &= \operatorname{triu}(\dot{F}) U \end{aligned}$$
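Here is a minimal Python sketch of this pushforward, checked against central finite differences. It assumes SciPy's `lu` convention ($A = P_s L U$) and uses triangular solves for $L^{-1}$ and $U^{-1}$. It also assumes the pivot order does not change under the tiny perturbation. That holds for generic random matrices, but not near pivoting ties. The function names are hypothetical.

```python
import numpy as np
from scipy.linalg import lu, solve_triangular


def lu_factors(A):
    Ps, L, U = lu(A)  # SciPy's convention: A == Ps @ L @ U
    return Ps.T, L, U  # this post's convention: P @ A == L @ U


def lu_pushforward_square(P, L, U, Adot):
    """Tangents (Ldot, Udot) of the LU factors of square A along Adot."""
    Y = solve_triangular(L, P @ Adot, lower=True, unit_diagonal=True)  # L^{-1} P Adot
    Fdot = solve_triangular(U.T, Y.T, lower=True).T                   # Y U^{-1}
    Ldot = L @ np.tril(Fdot, k=-1)  # L tril_-(Fdot)
    Udot = np.triu(Fdot) @ U        # triu(Fdot) U
    return Ldot, Udot


rng = np.random.default_rng(4)
A = rng.standard_normal((4, 4))
Adot = rng.standard_normal((4, 4))
P, L, U = lu_factors(A)
Ldot, Udot = lu_pushforward_square(P, L, U, Adot)

# Central finite differences (assumes the pivot order is unchanged).
eps = 1e-6
_, L_plus, U_plus = lu_factors(A + eps * Adot)
_, L_minus, U_minus = lu_factors(A - eps * Adot)
assert np.allclose(Ldot, (L_plus - L_minus) / (2 * eps), atol=1e-6)
assert np.allclose(Udot, (U_plus - U_minus) / (2 * eps), atol=1e-6)
```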

Pullback

The corresponding pullback is

$$\begin{aligned} \overline{F} &= \operatorname{tril}_-(L^\mathrm{H} \overline{L}) + \operatorname{triu}(\overline{U} U^\mathrm{H})\\ \overline{A} &= P^\mathrm{T} L^{-\mathrm{H}} \overline{F} U^{-\mathrm{H}} \end{aligned}$$

To find the pullback, we use the identity of reverse-mode differentiation and properties of the Frobenius inner product as described in the ChainRules guide on deriving array rules.

Here the identity takes the form

$$\operatorname{Re}\left\langle \overline{L}, \dot{L} \right\rangle + \operatorname{Re}\left\langle \overline{U}, \dot{U} \right\rangle = \operatorname{Re}\left\langle \overline{A}, \dot{A} \right\rangle.$$

We want to solve for $\overline{A}$. To do so, we first plug $\dot{U}$ and $\dot{L}$ into the left-hand side of this identity, then manipulate it to look like the right-hand side, and finally solve for $\overline{A}$.

$$\begin{aligned} \left\langle \overline{L}, \dot{L} \right\rangle &= \left\langle \overline{L}, L \operatorname{tril}_-(\dot{F}) \right\rangle = \left\langle L^\mathrm{H} \overline{L}, \operatorname{tril}_-(\dot{F}) \right\rangle\\ \left\langle \overline{U}, \dot{U} \right\rangle &= \left\langle \overline{U}, \operatorname{triu}(\dot{F}) U \right\rangle = \left\langle \overline{U} U^\mathrm{H}, \operatorname{triu}(\dot{F}) \right\rangle \end{aligned}$$

The Frobenius inner product is the sum of all elements of the element-wise product of the second argument and the complex conjugate of the first argument. Therefore, for upper triangular $U$, we have

$$\left\langle X, U \right\rangle = \sum_{ij} X_{ij}^* U_{ij} = \sum_{ij} \operatorname{triu}(X)_{ij}^* U_{ij} = \left\langle \operatorname{triu}(X), U \right\rangle.$$

The same is true for lower-triangular matrices (or, analogously, for any sparsity pattern). Therefore,

$$\begin{aligned} \left\langle \overline{L}, \dot{L} \right\rangle &= \left\langle \operatorname{tril}_-(L^\mathrm{H} \overline{L}), \dot{F} \right\rangle\\ \left\langle \overline{U}, \dot{U} \right\rangle &= \left\langle \operatorname{triu}(\overline{U} U^\mathrm{H}), \dot{F} \right\rangle\\ \left\langle \overline{L}, \dot{L} \right\rangle + \left\langle \overline{U}, \dot{U} \right\rangle &= \left\langle \operatorname{tril}_-(L^\mathrm{H} \overline{L}) + \operatorname{triu}(\overline{U} U^\mathrm{H}), \dot{F} \right\rangle \doteq \left\langle \overline{F}, \dot{F} \right\rangle, \end{aligned}$$

where we have introduced an intermediate $\overline{F}$.

Continuing by plugging in $\dot{F}$, we find

$$\left\langle \overline{F}, \dot{F} \right\rangle = \left\langle \overline{F}, L^{-1} P \dot{A} U^{-1} \right\rangle = \left\langle P^\mathrm{T} L^{-\mathrm{H}} \overline{F} U^{-\mathrm{H}}, \dot{A} \right\rangle$$

So the pullback of the LU decomposition is written

$$\begin{aligned} \overline{F} &= \operatorname{tril}_-(L^\mathrm{H} \overline{L}) + \operatorname{triu}(\overline{U} U^\mathrm{H})\\ \overline{A} &= P^\mathrm{T} L^{-\mathrm{H}} \overline{F} U^{-\mathrm{H}} \end{aligned}$$


Note that these expressions use the same elementary operations as solving a system of equations using the LU decomposition, as noted above. The pushforwards and pullbacks can then be computed in-place with no additional allocations.
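The reverse-mode identity itself gives a way to check the pullback against the pushforward. For random complex tangents and cotangents, $\operatorname{Re}\langle \overline{L}, \dot{L} \rangle + \operatorname{Re}\langle \overline{U}, \dot{U} \rangle$ should equal $\operatorname{Re}\langle \overline{A}, \dot{A} \rangle$. The sketch below assumes SciPy's `lu` convention. It computes the Frobenius inner product with `np.vdot`, which conjugates its first argument. The function names are hypothetical, and the restated pushforward uses dense inverses for brevity.

```python
import numpy as np
from scipy.linalg import lu, solve_triangular


def lu_factors(A):
    Ps, L, U = lu(A)  # SciPy's convention: A == Ps @ L @ U
    return Ps.T, L, U  # this post's convention: P @ A == L @ U


def lu_pushforward_square(P, L, U, Adot):
    Fdot = np.linalg.solve(L, P @ Adot) @ np.linalg.inv(U)  # L^{-1} P Adot U^{-1}
    return L @ np.tril(Fdot, k=-1), np.triu(Fdot) @ U


def lu_pullback_square(P, L, U, Lbar, Ubar):
    """Cotangent Abar of square A, given cotangents Lbar and Ubar."""
    LH, UH = L.conj().T, U.conj().T
    Fbar = np.tril(LH @ Lbar, k=-1) + np.triu(Ubar @ UH)
    Y = solve_triangular(LH, Fbar, lower=False, unit_diagonal=True)  # L^{-H} Fbar
    Z = solve_triangular(U, Y.conj().T, lower=False).conj().T       # Y U^{-H}
    return P.T @ Z                                                   # P^T L^{-H} Fbar U^{-H}


def re_inner(X, Y):
    """Re<X, Y>, where <X, Y> = sum(conj(X) * Y) is the Frobenius inner product."""
    return np.real(np.vdot(X, Y))


rng = np.random.default_rng(5)


def crandn(*shape):
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)


A, Adot, Lbar, Ubar = (crandn(4, 4) for _ in range(4))
P, L, U = lu_factors(A)
Ldot, Udot = lu_pushforward_square(P, L, U, Adot)
Abar = lu_pullback_square(P, L, U, Lbar, Ubar)
lhs = re_inner(Lbar, Ldot) + re_inner(Ubar, Udot)
assert np.isclose(lhs, re_inner(Abar, Adot))
```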

Wide $A$

We can write wide $A$ in blocks $A = \begin{bmatrix}A_1 & A_2 \end{bmatrix}$, where $A_1 \in \mathbb{C}^{m \times m}$ and $A_2 \in \mathbb{C}^{m \times (n - m)}$. It will turn out to be very convenient that $A_1$ is square. The LU decomposition in terms of these blocks is written

$$P\begin{bmatrix}A_1 & A_2\end{bmatrix} = L \begin{bmatrix}U_1 & U_2\end{bmatrix},$$

where $U_1$ is square upper triangular. This is a system of two equations that we will address separately:

$$\begin{aligned} P A_1 &= L U_1\\ P A_2 &= L U_2 \end{aligned}$$

Pushforward

Introducing an intermediate $\dot{H} = \begin{bmatrix} \dot{H}_1 & \dot{H}_2 \end{bmatrix}$ with the same block structure as $U$, the complete pushforward is

$$\begin{aligned} \dot{H} &= L^{-1} P \dot{A}\\ \dot{F} &= \dot{H}_1 U_1^{-1}\\ \dot{U}_1 &= \operatorname{triu}(\dot{F}) U_1\\ \dot{U}_2 &= \dot{H}_2 - \operatorname{tril}_-(\dot{F}) U_2\\ \dot{L} &= L \operatorname{tril}_-(\dot{F}) \end{aligned}$$

Note that the first equation of the system, $P A_1 = L U_1$, is identical in form to the square LU decomposition. So we can reuse that solution for the pushforward to get

$$\begin{aligned} \dot{F} &= L^{-1} P \dot{A}_1 U_1^{-1}\\ \dot{L} &= L \operatorname{tril}_-(\dot{F})\\ \dot{U}_1 &= \operatorname{triu}(\dot{F}) U_1 \end{aligned}$$

Now, let's differentiate the second equation

$$P \dot{A}_2 = \dot{L} U_2 + L \dot{U}_2$$

and solve for $\dot{U}_2$:

$$\dot{U}_2 = L^{-1} P \dot{A}_2 - L^{-1} \dot{L} U_2.$$

Plugging in our previous solution for $\dot{L}$, we find

$$\dot{U}_2 = L^{-1} P \dot{A}_2 - \operatorname{tril}_-(\dot{F}) U_2.$$

Since $L^{-1} P \dot{A}_2 = \dot{H}_2$, this matches the expression for $\dot{U}_2$ in the complete pushforward above.
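The wide-case pushforward can be sketched in the same way. The sketch splits $U$ into $U_1$ (the first $m$ columns) and $U_2$. It assumes SciPy's `lu`, which also accepts non-square input, and it uses dense `np.linalg.solve`/`np.linalg.inv` for readability rather than in-place triangular solves. The function names are hypothetical. The finite-difference check assumes the pivots are stable under the perturbation.

```python
import numpy as np
from scipy.linalg import lu


def lu_factors(A):
    Ps, L, U = lu(A)  # SciPy's convention: A == Ps @ L @ U
    return Ps.T, L, U  # this post's convention: P @ A == L @ U


def lu_pushforward_wide(P, L, U, Adot):
    """Tangents (Ldot, Udot) for wide A (m < n), with U = [U1 U2] and U1 square."""
    m = L.shape[0]
    U1, U2 = U[:, :m], U[:, m:]
    Hdot = np.linalg.solve(L, P @ Adot)   # Hdot = L^{-1} P Adot
    H1dot, H2dot = Hdot[:, :m], Hdot[:, m:]
    Fdot = H1dot @ np.linalg.inv(U1)      # Fdot = H1dot U1^{-1}
    U1dot = np.triu(Fdot) @ U1
    U2dot = H2dot - np.tril(Fdot, k=-1) @ U2
    Ldot = L @ np.tril(Fdot, k=-1)
    return Ldot, np.hstack([U1dot, U2dot])


rng = np.random.default_rng(6)
A = rng.standard_normal((3, 5))
Adot = rng.standard_normal((3, 5))
P, L, U = lu_factors(A)
Ldot, Udot = lu_pushforward_wide(P, L, U, Adot)

eps = 1e-6  # central differences; assumes the pivots do not change
_, L_plus, U_plus = lu_factors(A + eps * Adot)
_, L_minus, U_minus = lu_factors(A - eps * Adot)
assert np.allclose(Ldot, (L_plus - L_minus) / (2 * eps), atol=1e-6)
assert np.allclose(Udot, (U_plus - U_minus) / (2 * eps), atol=1e-6)
```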


Pullback

Introducing an intermediate $\overline{H} = \begin{bmatrix} \overline{H}_1 & \overline{H}_2 \end{bmatrix}$ with the same block structure as $U$, the corresponding pullback is

$$\begin{aligned} \overline{H}_1 &= \left(\operatorname{tril}_-(L^\mathrm{H} \overline{L} - \overline{U}_2 U_2^\mathrm{H}) + \operatorname{triu}(\overline{U}_1 U_1^\mathrm{H})\right) U_1^{-\mathrm{H}}\\ \overline{H}_2 &= \overline{U}_2\\ \overline{A} &= P^\mathrm{T} L^{-\mathrm{H}} \overline{H} \end{aligned}$$

Here the reverse-mode identity is

$$\operatorname{Re}\left\langle \overline{L}, \dot{L} \right\rangle + \operatorname{Re}\left\langle \overline{U}_1, \dot{U}_1 \right\rangle + \operatorname{Re}\left\langle \overline{U}_2, \dot{U}_2 \right\rangle = \operatorname{Re}\left\langle \overline{A}, \dot{A} \right\rangle.$$

We plug in $\dot{L}$, $\dot{U}_1$, and $\dot{U}_2$ to find

$$\begin{aligned} & \operatorname{Re}\left\langle \overline{L}, L \operatorname{tril}_-(\dot{F}) \right\rangle + \operatorname{Re}\left\langle \overline{U}_1, \operatorname{triu}(\dot{F}) U_1 \right\rangle + \operatorname{Re}\left\langle \overline{U}_2, \dot{H}_2 - \operatorname{tril}_-(\dot{F}) U_2 \right\rangle\\ &= \operatorname{Re}\left\langle \operatorname{tril}_-(L^\mathrm{H} \overline{L}) - \operatorname{tril}_-(\overline{U}_2 U_2^\mathrm{H}) + \operatorname{triu}(\overline{U}_1 U_1^\mathrm{H}), \dot{F} \right\rangle + \operatorname{Re}\left\langle \overline{U}_2, \dot{H}_2 \right\rangle\\ &= \operatorname{Re}\left\langle \left(\operatorname{tril}_-(L^\mathrm{H} \overline{L}) - \operatorname{tril}_-(\overline{U}_2 U_2^\mathrm{H}) + \operatorname{triu}(\overline{U}_1 U_1^\mathrm{H})\right) U_1^{-\mathrm{H}}, \dot{H}_1 \right\rangle + \operatorname{Re}\left\langle \overline{U}_2, \dot{H}_2 \right\rangle \end{aligned}$$

Let's introduce the intermediates

$$\begin{aligned} \overline{H}_1 &= \left(\operatorname{tril}_-(L^\mathrm{H} \overline{L} - \overline{U}_2 U_2^\mathrm{H}) + \operatorname{triu}(\overline{U}_1 U_1^\mathrm{H})\right) U_1^{-\mathrm{H}}\\ \overline{H}_2 &= \overline{U}_2, \end{aligned}$$

which, like $\dot{H}$, we organize into the block matrix $\overline{H} = \begin{bmatrix} \overline{H}_1 & \overline{H}_2 \end{bmatrix}$. This block structure lets us rewrite the above identity in terms of $\overline{H}$ and $\dot{H}$:

$$\operatorname{Re}\left\langle \overline{H}_1, \dot{H}_1 \right\rangle + \operatorname{Re}\left\langle \overline{H}_2, \dot{H}_2 \right\rangle = \operatorname{Re}\left\langle \overline{H}, \dot{H} \right\rangle$$

Now we plug in $\dot{H}$:

$$\operatorname{Re}\left\langle \overline{H}, \dot{H} \right\rangle = \operatorname{Re}\left\langle \overline{H}, L^{-1} P \dot{A} \right\rangle = \operatorname{Re}\left\langle P^\mathrm{T} L^{-\mathrm{H}} \overline{H}, \dot{A} \right\rangle$$

We have arrived at the desired form and can solve for $\overline{A}$:

$$\overline{A} = P^\mathrm{T} L^{-\mathrm{H}} \overline{H}.$$
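As in the square case, a dot-product test against the pushforward is a cheap way to check the wide pullback. The sketch below restates the pushforward so that it runs on its own. It assumes SciPy's `lu` convention, and the function names are hypothetical.

```python
import numpy as np
from scipy.linalg import lu


def lu_factors(A):
    Ps, L, U = lu(A)  # SciPy's convention: A == Ps @ L @ U
    return Ps.T, L, U  # this post's convention: P @ A == L @ U


def lu_pushforward_wide(P, L, U, Adot):
    m = L.shape[0]
    U1, U2 = U[:, :m], U[:, m:]
    Hdot = np.linalg.solve(L, P @ Adot)                  # L^{-1} P Adot
    Fdot = Hdot[:, :m] @ np.linalg.inv(U1)               # H1dot U1^{-1}
    U2dot = Hdot[:, m:] - np.tril(Fdot, k=-1) @ U2
    return L @ np.tril(Fdot, k=-1), np.hstack([np.triu(Fdot) @ U1, U2dot])


def lu_pullback_wide(P, L, U, Lbar, Ubar):
    """Cotangent Abar of wide A (m < n), given Lbar and Ubar = [Ubar1 Ubar2]."""
    m = L.shape[0]
    U1, U2 = U[:, :m], U[:, m:]
    Ubar1, Ubar2 = Ubar[:, :m], Ubar[:, m:]
    LH = L.conj().T
    G = np.tril(LH @ Lbar - Ubar2 @ U2.conj().T, k=-1) + np.triu(Ubar1 @ U1.conj().T)
    H1bar = G @ np.linalg.inv(U1).conj().T   # G U1^{-H}
    Hbar = np.hstack([H1bar, Ubar2])         # H2bar = Ubar2
    return P.T @ np.linalg.solve(LH, Hbar)   # P^T L^{-H} Hbar


def re_inner(X, Y):
    return np.real(np.vdot(X, Y))  # real part of the Frobenius inner product


rng = np.random.default_rng(7)


def crandn(*shape):
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)


A, Adot, Ubar = crandn(3, 5), crandn(3, 5), crandn(3, 5)
Lbar = crandn(3, 3)
P, L, U = lu_factors(A)
Ldot, Udot = lu_pushforward_wide(P, L, U, Adot)
Abar = lu_pullback_wide(P, L, U, Lbar, Ubar)
lhs = re_inner(Lbar, Ldot) + re_inner(Ubar, Udot)
assert np.isclose(lhs, re_inner(Abar, Adot))
```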


Tall $A$

The tall case is very similar to the wide case, except now we have the block structure

$$\begin{bmatrix}P_1 \\ P_2 \end{bmatrix} A = \begin{bmatrix}L_1 \\ L_2 \end{bmatrix} U,$$

where now $L_1 \in \mathbb{C}^{n \times n}$ is square unit lower triangular, $L_2 \in \mathbb{C}^{(m - n) \times n}$, and $U \in \mathbb{C}^{n \times n}$ is square upper triangular. This gives us the system of equations

$$\begin{aligned} P_1 A &= L_1 U\\ P_2 A &= L_2 U. \end{aligned}$$

Pushforward

The first equation is again identical in form to the square case, so we can use it to solve for $\dot{L}_1$ and $\dot{U}$. Likewise, the approach we used to solve for $\dot{U}_2$ in the wide case can be applied here to solve for $\dot{L}_2$.

Introducing an intermediate $\dot{H} = \begin{bmatrix} \dot{H}_1 \\ \dot{H}_2 \end{bmatrix}$ with the same block structure as $L$, the complete pushforward is

$$\begin{aligned} \dot{H} &= P \dot{A} U^{-1}\\ \dot{F} &= L_1^{-1} \dot{H}_1\\ \dot{L}_1 &= L_1 \operatorname{tril}_-(\dot{F})\\ \dot{L}_2 &= \dot{H}_2 - L_2 \operatorname{triu}(\dot{F})\\ \dot{U} &= \operatorname{triu}(\dot{F}) U \end{aligned}$$
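Here is a corresponding sketch of the tall-case pushforward. It splits $L$ into $L_1$ (the first $n$ rows) and $L_2$. As before, it assumes SciPy's `lu` convention and pivots that stay stable under the finite-difference perturbation. Dense inverses are used for readability, and the names are hypothetical.

```python
import numpy as np
from scipy.linalg import lu


def lu_factors(A):
    Ps, L, U = lu(A)  # SciPy's convention: A == Ps @ L @ U
    return Ps.T, L, U  # this post's convention: P @ A == L @ U


def lu_pushforward_tall(P, L, U, Adot):
    """Tangents (Ldot, Udot) for tall A (m > n), with L = [L1; L2] and L1 square."""
    n = U.shape[0]
    L1, L2 = L[:n, :], L[n:, :]
    Hdot = P @ Adot @ np.linalg.inv(U)    # Hdot = P Adot U^{-1}
    H1dot, H2dot = Hdot[:n, :], Hdot[n:, :]
    Fdot = np.linalg.solve(L1, H1dot)     # Fdot = L1^{-1} H1dot
    L1dot = L1 @ np.tril(Fdot, k=-1)
    L2dot = H2dot - L2 @ np.triu(Fdot)
    Udot = np.triu(Fdot) @ U
    return np.vstack([L1dot, L2dot]), Udot


rng = np.random.default_rng(8)
A = rng.standard_normal((5, 3))
Adot = rng.standard_normal((5, 3))
P, L, U = lu_factors(A)
Ldot, Udot = lu_pushforward_tall(P, L, U, Adot)

eps = 1e-6  # central differences; assumes the pivots do not change
_, L_plus, U_plus = lu_factors(A + eps * Adot)
_, L_minus, U_minus = lu_factors(A - eps * Adot)
assert np.allclose(Ldot, (L_plus - L_minus) / (2 * eps), atol=1e-6)
assert np.allclose(Udot, (U_plus - U_minus) / (2 * eps), atol=1e-6)
```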

Pullback

Introducing an intermediate $\overline{H} = \begin{bmatrix} \overline{H}_1 \\ \overline{H}_2 \end{bmatrix}$ with the same block structure as $L$, the corresponding pullback is

$$\begin{aligned} \overline{H}_1 &= L_1^{-\mathrm{H}} \left(\operatorname{tril}_-(L_1^\mathrm{H} \overline{L}_1) + \operatorname{triu}(\overline{U} U^\mathrm{H} - L_2^\mathrm{H} \overline{L}_2)\right)\\ \overline{H}_2 &= \overline{L}_2\\ \overline{A} &= P^\mathrm{T} \overline{H} U^{-\mathrm{H}} \end{aligned}$$
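The tall pullback can be checked with the same reverse-mode identity. This sketch assumes SciPy's `lu` convention and restates the tall pushforward so that the block runs on its own. The function names are hypothetical.

```python
import numpy as np
from scipy.linalg import lu


def lu_factors(A):
    Ps, L, U = lu(A)  # SciPy's convention: A == Ps @ L @ U
    return Ps.T, L, U  # this post's convention: P @ A == L @ U


def lu_pushforward_tall(P, L, U, Adot):
    n = U.shape[0]
    L1, L2 = L[:n, :], L[n:, :]
    Hdot = P @ Adot @ np.linalg.inv(U)          # P Adot U^{-1}
    Fdot = np.linalg.solve(L1, Hdot[:n, :])     # L1^{-1} H1dot
    L2dot = Hdot[n:, :] - L2 @ np.triu(Fdot)
    return np.vstack([L1 @ np.tril(Fdot, k=-1), L2dot]), np.triu(Fdot) @ U


def lu_pullback_tall(P, L, U, Lbar, Ubar):
    """Cotangent Abar of tall A (m > n), given Lbar = [Lbar1; Lbar2] and Ubar."""
    n = U.shape[0]
    L1, L2 = L[:n, :], L[n:, :]
    Lbar1, Lbar2 = Lbar[:n, :], Lbar[n:, :]
    G = (np.tril(L1.conj().T @ Lbar1, k=-1)
         + np.triu(Ubar @ U.conj().T - L2.conj().T @ Lbar2))
    H1bar = np.linalg.solve(L1.conj().T, G)         # L1^{-H} G
    Hbar = np.vstack([H1bar, Lbar2])                # H2bar = Lbar2
    return P.T @ Hbar @ np.linalg.inv(U).conj().T   # P^T Hbar U^{-H}


def re_inner(X, Y):
    return np.real(np.vdot(X, Y))  # real part of the Frobenius inner product


rng = np.random.default_rng(9)


def crandn(*shape):
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)


A, Adot, Lbar = crandn(5, 3), crandn(5, 3), crandn(5, 3)
Ubar = crandn(3, 3)
P, L, U = lu_factors(A)
Ldot, Udot = lu_pushforward_tall(P, L, U, Adot)
Abar = lu_pullback_tall(P, L, U, Lbar, Ubar)
lhs = re_inner(Lbar, Ldot) + re_inner(Ubar, Udot)
assert np.isclose(lhs, re_inner(Abar, Adot))
```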

Implementation

The product of this derivation is this pull request to ChainRules.jl, which includes tests of the rules using FiniteDifferences.jl.

Conclusion

The techniques employed here are general and can be used for differentiation rules for other factorizations of non-square matrices. A recent paper used a similar approach to derive the pushforwards and pullbacks of the $QR$ and $LQ$ decompositions [2].

References

[1] de Hoog F.R., Anderssen R.S., and Lukas M.A. (2011) Differentiation of matrix functionals using triangular factorization. Mathematics of Computation, 80 (275). p. 1585. doi: 10.1090/S0025-5718-2011-02451-8.
[2] Roberts D.A.O. and Roberts L.R. (2020) QR and LQ Decomposition Matrix Backpropagation Algorithms for Square, Wide, and Deep – Real or Complex – Matrices and Their Software Implementation. arXiv: 2009.10071.

If you have questions or suggestions, feel free to open an issue.
If you found this useful, please share and follow @sethaxen on Twitter.