The Adjoint Method for Back-propagation through Linear Systems
Consider a simple problem where we have an output vector \(z\) of shape \(N\) related to another vector \(y\) of shape \(N\) by the linear system
\[A z = y.\]
We could also write this as
\[z = A^{-1} y.\]
In this case the matrix \(A\) and the right-hand side vector \(y\) are both functions of some input parameters \(x\).
Situations like this commonly appear in physics simulations where the bulk of the calculation involves solving a large linear system, but then the actual output is only a scalar function of this solution.
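To make the setup concrete, here is a minimal sketch in Python/NumPy. The particular forms of \(A(x)\) and \(y(x)\), and the helper name `assemble_system`, are illustrative assumptions rather than anything from a specific simulation:

```python
import numpy as np

def assemble_system(x, N=50):
    # A(x): a well-conditioned, parameter-dependent matrix (illustrative)
    A = (2.0 + x[0]) * np.eye(N) + np.diag(np.full(N - 1, x[1]), k=1)
    # y(x): a parameter-dependent right-hand side (illustrative)
    y = np.sin(x[2] * np.arange(N))
    return A, y

x = np.array([0.5, -0.1, 0.3])
A, y = assemble_system(x)
z = np.linalg.solve(A, y)  # the expensive step: solve A z = y
g = np.ones_like(z) @ z    # a cheap scalar output, here g(z) = w^T z with w = 1
```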
To calculate the derivative of \(z\) with respect to our inputs \(x\), we start by differentiating our original equation,
\[(\partial_x A)z + A(\partial_x z) = \partial_x y,\]
and after some rearranging,
\[A (\partial_x z) = \partial_x y - (\partial_x A)z,\]
we can see that \(\partial_x z\) is the solution of the same linear system with a different right-hand side. This is a powerful result, since most linear systems are efficiently solved by methods such as LU factorization, where the same factorization can be reused for multiple right-hand sides. However, this approach still scales linearly with the number of input parameters: we would have to solve the system once for each element of \(x\) that we wanted the derivative with respect to. An alternative is a method known as the adjoint method.
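A sketch of this forward approach, reusing one LU factorization across every right-hand side; `assemble_system` is the hypothetical helper from above, and \(\partial_x A\), \(\partial_x y\) are approximated by central finite differences purely for illustration:

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve

def forward_derivatives(assemble_system, x, eps=1e-6):
    A, y = assemble_system(x)
    lu, piv = lu_factor(A)      # factor A once
    z = lu_solve((lu, piv), y)  # forward solution z = A^{-1} y
    dz = []
    for i in range(len(x)):     # ...but still one solve per parameter
        e = np.zeros_like(x)
        e[i] = eps
        A_p, y_p = assemble_system(x + e)
        A_m, y_m = assemble_system(x - e)
        dA = (A_p - A_m) / (2 * eps)  # \partial A / \partial x_i
        dy = (y_p - y_m) / (2 * eps)  # \partial y / \partial x_i
        # A (\partial_{x_i} z) = \partial_{x_i} y - (\partial_{x_i} A) z,
        # solved with the factorization we already have
        dz.append(lu_solve((lu, piv), dy - dA @ z))
    return z, dz
```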
For simplicity, let's assume that all we care about is a scalar function \(g(z) = w^T z\), i.e., a weighted sum of all of our output points with some given weight factors. If we assume the weights are independent of our inputs, the quantity that we want is
\[\partial_x g(z) = w^T (\partial_x z).\]
Let's rearrange our derivative equation a bit and left-multiply by \(w^T\),
\[w^T(\partial_x z) = w^T A^{-1}\left[\partial_x y - (\partial_x A)z\right].\]
What is \(w^T A^{-1}\)? It is a (row) vector of size \(N\); let's call it \(s^T\) for now,
\[s^T = w^T A^{-1}.\]
Let's move \(A\) back to the other side,
\[s^T A = w^T,\]
and take the transpose of this equation,
\[A^T s = w.\]
This vector \(s\) is the solution of what we call the adjoint system; in this case, since the system was assumed to be real, the adjoint is simply the transpose. What this series of equations means is that we only need to solve one additional system to obtain the adjoint solution vector \(s\). Substituting \(s^T = w^T A^{-1}\) back into our expression gives
\[\partial_x g(z) = s^T\left[\partial_x y - (\partial_x A)z\right],\]
so the derivative with respect to any number of input quantities can be calculated through a relatively cheap dot product.
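Here is the full adjoint computation as a sketch under the same assumptions (the hypothetical `assemble_system` from above, finite-difference \(\partial_x A\) and \(\partial_x y\)). The key point is that `lu_solve` with `trans=1` solves the transposed system \(A^T s = w\) using the factorization from the forward solve, so only two solves are ever performed no matter how many parameters there are:

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve

def adjoint_gradient(assemble_system, x, w, eps=1e-6):
    A, y = assemble_system(x)
    lu, piv = lu_factor(A)
    z = lu_solve((lu, piv), y)           # forward solve: A z = y
    s = lu_solve((lu, piv), w, trans=1)  # adjoint solve: A^T s = w
    grad = np.zeros_like(x)
    for i in range(len(x)):
        e = np.zeros_like(x)
        e[i] = eps
        A_p, y_p = assemble_system(x + e)
        A_m, y_m = assemble_system(x - e)
        dA = (A_p - A_m) / (2 * eps)
        dy = (y_p - y_m) / (2 * eps)
        # \partial_{x_i} g = s^T [\partial_{x_i} y - (\partial_{x_i} A) z]:
        # just a dot product, no additional linear solve
        grad[i] = s @ (dy - dA @ z)
    return w @ z, grad
```

The per-parameter loop now involves only matrix-vector products, which is what makes this approach attractive when \(x\) has many entries.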