Introduction

Broadly speaking, the adjoint method is a method to efficiently calculate gradients of scalar functions in constrained optimization problems with many parameters.

Fitting the parameters of an evolutive system, optimal control and inverse problems all rely on the same mathematical backbone to calculate the gradients of a cost function with respect to the parameters. Depending on the specific structure of the cost function and of the system to solve (continuous/discrete, evolutive…), different ad-hoc implementations have been developed in different disciplines, each with its own naming and algorithms. Control applications sometimes call it the costate, in Deep Learning it is called Backpropagation, and sometimes it is absorbed into Sensitivity Analysis…

So it might be hard to understand at first how all these different methods, across these different disciplines, share the same underlying mathematical formulation. Besides, the explanations I found rarely offered any intuitive insight and were usually oriented to specific problems, so even if the differences between formulations are just cosmetic, it might be hard to identify them as the same mathematical entity.

Having a background in physics and applied math, I first encountered the Adjoint Method in the context of Sensitivity Analysis, Inverse problems, Optimal control… as a trick to efficiently calculate gradients. Later, when I delved a little into Deep Learning, with things such as Backpropagation, Reverse mode differentiation…, it took me more time than I’d like to admit to tie it all together with my previous knowledge. So I’d like to ease that journey for people who follow a similar path, by motivating and formulating the adjoint method from the standpoint of Applied Mathematics, and later detailing the specific applications in different domains, including Deep Learning, so that they can also be used as a cheat sheet for quick reference.

A sandbox Colab environment is linked here to play around with some of the ideas explained in the post; the code is here.

Motivation

Let’s say we have a system that depends on some parameters \(p\) and some input \(x\), producing some output \(u\).

Evaluating that model requires solving the so-called forward problem. For example, if our system is the ODE describing the motion of a damped pendulum, the forward pass consists of plugging some forcing/impulse term into the RHS and integrating the equations to evaluate the output \(u(t,p)\). If our system is a feed-forward neural net, the forward pass is computed by feeding the net some data, propagating it, and obtaining some output \(u(data,p)\).

Many times we need to travel in the opposite direction: we must solve the inverse problem that answers the question “which inputs/parameters generate this output?”. This usually involves solving a constrained optimization problem and propagating information in the backward direction, which constitutes the backward problem.

Notable examples that adhere to this description:

Feedforward NN (regression task):
\begin{array}{ll} \mbox{J}: & \underset{p}{\mbox{minimize}} \sum\left(u_{Data}-\hat{u}(x_{input},p)\right)^{2} \\ \mbox{G}: & \begin{cases} \hat{u} = \sigma(W_{k+1}u_{k}) \\ u_{k} = \sigma(W_{k}u_{k-1}) \\ \vdots \\ u_{2} = \sigma(W_{2}u_{1})\\ u_{1} = \sigma(W_{1}u_{input}) \end{cases} \end{array}

Optimal control, LQR controller (finding the optimal control \(p(t)\) that stabilizes a linear system):
\begin{array}{ll} \mbox{J}: & \int_{0}^{t_{f}}\frac{1}{2}\left(u^{*}Qu+p^{*}Rp\right)d\tau \\ \mbox{G}: & \dot{u} = Au+Bp \end{array}

Seismic tomography (finding the source terms \(b\) and propagation parameters \(A\), such as density):
\begin{array}{ll} \mbox{J}: & \sum_{sensor\ data} \|u_{j,reading,\tau}-u_{j}(t=\tau)\|^{2} \\ \mbox{G}: & \ddot{u} = A(m)u+b(m), \quad \mbox{I.C., B.C.} \end{array}

Let’s illustrate this with a specific example. Let’s assume we have a shallow tank filled with some steady liquid substance, and we are interested in obtaining a specific spatial temperature profile.

To control the temperature \(u(x,y)\), we can inject some reactive flow \(p\) at different positions. It reacts very quickly, and its effect is equivalent to a net “temperature injection rate/reactive_mass”, \(C_{r}\).

The temperature profile will be modelled by some Reaction-Diffusion equation with some source term depending on our parameters \(p_{ij}\). We’ll consider a linear model for simplicity.

\begin{array}{l} \underset{p}{argmin}\,J(p) = \sum_{D}\left(u_{ijD}-u(p,x_{i},y_{j})\right)^{2} \\ subject \quad to \quad g(u(x,y),p) = D\nabla^{2}u + \sigma u - C_{r}\sum^{N_{p}} p_{ij}\,\delta(x-x_{i},y-y_{j}) = 0\\ B.C. \end{array}

Let’s work with the discretized version of the system, to avoid the unnecessary complications of functional analysis for now.

\begin{array}{l} \underset{p}{argmin}\,J(p) = \sum_{D}\left(u_{ijD}-u(p,x_{i},y_{j})\right)^{2} \\ subject \quad to \quad Au = Cp \\ B.C. \end{array}

The \(u_{ijD}\) term represents the setpoint temperature at the different \(ij\) positions.

Here \(C\) is an \(N_{u}\times N_{p}\) matrix with \(0\) at the spatial entries where there is no flow injection, and \(C_{r}\) where we have some injection \(p_{ij}\).

The forward problem, then, consists of solving a linear system for the \(N_{u}\) nodes of our mesh.

To solve the minimization problem, we’ll have to calculate the gradients of the cost function \(J(p)\) with respect to the parameters, \(d_{p} J\).

Applying the chain rule we get \(d_{p} J = \partial_{u} J\partial_{p}u + \partial_{p}J\). Here the term that enforces the constraint, \(\partial_{p}u\), is the nasty bit: because we have to evaluate the derivative of the temperature at every node \(u_{km}\) with respect to every \(p_{ij}\), \(\partial_{p}u\) has \(N_{u} \times N_{p}\) elements, where \(N_{p}\) is the number of controls. Not cool.

Trying to find \(d_{p}J\) naively by finite differences would require solving \(g(u,p) = 0\) \(N_{p}\) additional times, once per perturbed parameter. Not cool.

Trying to compute \(\partial_{p}u\) numerically from \(A\partial_{p_{ij}}u = C\delta_{ij}\), where \(\delta_{ij}\) is the unit vector selecting \(p_{ij}\), implies solving \(N_{p}\) systems of similar complexity to the forward problem. Not cool either.
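To make that cost concrete, here is a minimal NumPy sketch of both naive routes. It assumes a hypothetical 1D stand-in for the tank (a second-difference Laplacian with Dirichlet boundaries, a reaction term \(\sigma u\) and three injection nodes), the cost \(J = \sum (u - u_{D})^{2}\) over all nodes, and made-up values for \(D\), \(\sigma\) and \(C_{r}\). The finite-difference route needs \(2N_{p}\) extra forward solves (central differences), while the sensitivity route solves \(A\,\partial_{p}u = C\) for all \(N_{p}\) right-hand sides.

```python
import numpy as np

def build_problem(n_u=50, inject_at=(10, 25, 40), D=1.0, sigma=-0.5, C_r=1.0):
    """Hypothetical 1D stand-in for the tank: discretized D u'' + sigma u = C p."""
    h = 1.0 / (n_u + 1)
    lap = (-2.0 * np.eye(n_u) + np.eye(n_u, k=1) + np.eye(n_u, k=-1)) / h**2
    A = D * lap + sigma * np.eye(n_u)                     # Dirichlet B.C. built into lap
    C = np.zeros((n_u, len(inject_at)))
    C[list(inject_at), np.arange(len(inject_at))] = C_r   # C_r where p_j is injected, 0 elsewhere
    return A, C

def cost(A, C, p, u_target):
    u = np.linalg.solve(A, C @ p)                         # forward problem: A u = C p
    return np.sum((u - u_target) ** 2)

def grad_finite_differences(A, C, p, u_target, eps=1e-6):
    grad = np.zeros_like(p)
    for j in range(len(p)):                               # 2 * N_p extra forward solves
        dp = np.zeros_like(p)
        dp[j] = eps
        grad[j] = (cost(A, C, p + dp, u_target) - cost(A, C, p - dp, u_target)) / (2 * eps)
    return grad

def grad_sensitivities(A, C, p, u_target):
    u = np.linalg.solve(A, C @ p)
    du_dp = np.linalg.solve(A, C)                         # N_u x N_p matrix: N_p extra solves
    return 2.0 * (u - u_target) @ du_dp                   # d_p J = ∂_u J ∂_p u

A, C = build_problem()
p = np.array([1.0, -0.5, 2.0])
u_target = np.full(A.shape[0], 0.01)
print(grad_finite_differences(A, C, p, u_target), grad_sensitivities(A, C, p, u_target))
```

Both routes scale with \(N_{p}\); with thousands of injection points, each gradient evaluation would cost thousands of linear solves.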

To get around this, the key realization is that our objective function is just a scalar function \(J(u(p))\), and at every \(u(p)\) we only care about the direction in \(u\) space to which \(J\) is most sensitive, namely \(\partial_{u}J\). The following example makes this obvious.

Let’s imagine that in the last problem we only care about controlling the temperature at a specific node \(u(x_{k},y_{m})\):

\begin{array}{l} \underset{p}{argmin}\,J(p) = \left(u_{kmD}-u(p,x_{k},y_{m})\right)^{2} \\ subject \quad to \quad Au = Cp \\ B.C. \end{array}

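Here it is clear that we don’t need the whole \(\partial_{p}u\) matrix, only its row \(\partial_{p}u_{km} = e_{km}^{T}A^{-1}C\), where \(e_{km}\) is the unit vector selecting node \(km\). That row can be obtained by solving a single additional system, \(A^{T}w = e_{km}\), and then computing \(\partial_{p}u_{km} = w^{T}C\).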

But if we think about it, there is nothing special about the \(u_{km}\) direction. If there are more \(u\)’s in the cost function, the direction we are interested in will have non-zero components along several \(u\)’s, but it is still a single direction. We can picture \(\partial_{u}J\) as a vector pointing in a specific direction in \(u\) space. If \(J(u)=J(u_{km})\), only the \(km\) entry is non-zero, \((0,0,0,\partial_{u_{km}}J,0,\dots, 0)\). Similarly, if the dependence is given by some cost matrix \(Q\) over all nodes, \(J(u) = u^{T}Qu\), then \(\partial_{u}J\) at a given \(p\) is just a less sparse vector in \(u\) space.

So how can we work with this single direction instead of the whole \(\partial_{p}u\) beast? That’s where the Adjoint method comes into play. Let’s recall some definitions.

The adjoint \(B\) of a linear operator \(A\) is another operator that satisfies

\begin{equation} \langle Bv,w\rangle = \langle v,Aw\rangle \end{equation}

Actually, we are so used to applying this in the context of matrix algebra that it is easy to miss what is behind it.

When dealing with vector spaces with the usual inner product \(\langle v,w\rangle = v^{T}w\) (which gives Pythagoras for \(v = w\)), the expression \(\langle v,Aw\rangle = v^{T}Aw\) can be evaluated by first calculating \(Aw\), or by moving \(A\) to the left, \((A^{T}v)^{T}w = \langle A^{T}v,w\rangle\). Here we are using the adjoint of \(A\), which for a real matrix is just its transpose. Rather than this matrix-specific application, it is better to keep the expression \(\langle Bv,w\rangle = \langle v,Aw\rangle\) in mind, because it is much more general: if we are dealing with operators on function spaces, for instance, we can think of the adjoint operator as the “generalization” of the transpose matrix in linear algebra [2].
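A quick numerical check of the definition, as a minimal sketch with random matrices. For the standard inner product, the adjoint of a (possibly rectangular) matrix is its transpose. To illustrate why the general definition matters, the sketch also uses a weighted inner product \(\langle v,w\rangle_{M} = v^{T}Mw\) with a symmetric positive definite \(M\) (an assumption chosen for illustration), for which the adjoint is \(M^{-1}A^{T}M\) rather than \(A^{T}\).

```python
import numpy as np

rng = np.random.default_rng(0)

# Standard inner product <v, w> = v^T w: the adjoint of A (R^3 -> R^4) is A^T (R^4 -> R^3).
A = rng.normal(size=(4, 3))
v, w = rng.normal(size=4), rng.normal(size=3)
lhs_std = (A.T @ v) @ w              # <A^T v, w>
rhs_std = v @ (A @ w)                # <v, A w>

# Weighted inner product <x, y>_M = x^T M y, with M symmetric positive definite.
S = rng.normal(size=(3, 3))
M = S @ S.T + 3.0 * np.eye(3)
A_sq = rng.normal(size=(3, 3))
x, y = rng.normal(size=3), rng.normal(size=3)
B = np.linalg.solve(M, A_sq.T @ M)   # adjoint w.r.t. <.,.>_M: B = M^{-1} A^T M
lhs_w = (B @ x) @ M @ y              # <B x, y>_M
rhs_w = x @ M @ (A_sq @ y)           # <x, A y>_M

print(np.isclose(lhs_std, rhs_std), np.isclose(lhs_w, rhs_w))  # True True
```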

The increment \(dJ\) due to a perturbation \(\Delta p\) of the parameters is given by the inner product

\begin{array}{l} dJ(p) = \langle\partial_{u}J, \partial_{p}u\,\Delta p\rangle = \langle\partial_{u}J, -\partial_{u}g^{-1}\partial_{p}g\,\Delta p\rangle \end{array}

where we have expressed \(\partial_{p}u\) as the solution of the linear system we get by differentiating the constraint \(g(u,p)\) with respect to the parameters \(p\) (I know, the whole point is avoiding the calculation of \(\partial_{p}u\); nonetheless, it is an insightful auxiliary step).

\begin{equation} \partial_{u}g\partial_{p}u +\partial_{p}g= 0 \end{equation}

By definition, \(d_{p}J\) is the vector on the left of the expression \(dJ(p) = \langle d_{p}J, \Delta p\rangle\), so we must rewrite \(dJ\) in the form \(\langle G(p),\Delta p\rangle\). We achieve this by moving the operator \(-\partial_{u}g^{-1}\partial_{p}g\) to the left side of the inner product:

\begin{array}{l} dJ(p) = \langle\partial_{u}J, -\partial_{u}g^{-1}\partial_{p}g\,\Delta p\rangle = \langle( -\partial_{u}g^{-1}\partial_{p}g)^{T} \partial_{u}J, \Delta p\rangle \end{array}

The vector on the left is our quantity of interest, \(d_{p}J\). Expressed like that, it doesn’t look like we’ve made any progress, but there is in fact a linear system in disguise there, with only \(N_{u}\) unknowns. We can avoid solving \(\partial_{u}g\partial_{p}u = -\partial_{p}g\), with \(N_{u} \times N_{p}\) unknowns, and unfold the much smaller system of \(N_{u}\) unknowns by setting the term \(-\partial_{u}J\,\partial_{u}g^{-1}\) (with \(\partial_{u}J\) written as a row vector) to some unknown row vector \(\lambda^{T}\) of dimension \(N_{u}\):

\begin{array}{l} -\partial_{u}J\partial_{u}g^{-1} = \lambda^{T} \\ -\partial_{u}J^{T} = \partial_{u}g^{T}\lambda \end{array}

Which is the adjoint system.

Once we have found the adjoint \(\lambda\), \(d_{p}J\) follows by substituting \(\lambda\) for \(-(\partial_{u}g^{-1})^{T}\partial_{u}J\) in the expression for \(dJ(p)\):

\begin{equation} \langle\partial_{p}g^{T}\lambda, \Delta p\rangle = \langle d_{p}J,\Delta p\rangle \rightarrow d_{p}J = \lambda^{T}\partial_{p}g \end{equation}
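Applied to the hypothetical 1D stand-in for the tank (redefined here so the block runs on its own, with made-up parameter values), the adjoint recipe is a minimal sketch of three lines: with \(g(u,p) = Au - Cp\) we have \(\partial_{u}g = A\) and \(\partial_{p}g = -C\), so a single extra solve with \(A^{T}\) replaces the \(N_{p}\) sensitivity solves. The comparison against \(\partial_{u}J\,\partial_{p}u\) is only there as a check.

```python
import numpy as np

def build_problem(n_u=50, inject_at=(10, 25, 40), D=1.0, sigma=-0.5, C_r=1.0):
    """Hypothetical 1D stand-in for the tank: discretized D u'' + sigma u = C p."""
    h = 1.0 / (n_u + 1)
    lap = (-2.0 * np.eye(n_u) + np.eye(n_u, k=1) + np.eye(n_u, k=-1)) / h**2
    A = D * lap + sigma * np.eye(n_u)
    C = np.zeros((n_u, len(inject_at)))
    C[list(inject_at), np.arange(len(inject_at))] = C_r
    return A, C

def adjoint_gradient(A, C, p, u_target):
    u = np.linalg.solve(A, C @ p)          # forward problem g(u, p) = A u - C p = 0
    dJ_du = 2.0 * (u - u_target)           # ∂_u J for J = sum (u - u_D)^2
    lam = np.linalg.solve(A.T, -dJ_du)     # adjoint system: ∂_u g^T λ = -∂_u J^T
    return lam @ (-C)                      # d_p J = λ^T ∂_p g, with ∂_p g = -C

A, C = build_problem()
p = np.array([1.0, -0.5, 2.0])
u_target = np.full(A.shape[0], 0.01)
grad_adj = adjoint_gradient(A, C, p, u_target)

# Check against the expensive route d_p J = ∂_u J ∂_p u (N_p extra solves).
u = np.linalg.solve(A, C @ p)
grad_ref = 2.0 * (u - u_target) @ np.linalg.solve(A, C)
print(np.allclose(grad_adj, grad_ref))  # True
```

The cost of the adjoint solve does not depend on \(N_{p}\): adding more injection points only adds columns to \(C\) in the final product.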

Tying it all together

Let’s summarize the problem setting and the application of the adjoint method.

We have some system (dynamical system, neural network, spatial PDE…) with state \(u(p,x)\) , whose response to an external input is expressed implicitly as a function \(g(u,p)= 0\).

We want to minimize some scalar function of the state of the system, \(J(u(p),p)\), with respect to \(p\). When the number of parameters is large, the naive calculation of \(\partial_{p}u\) is intractable, so we write down an additional auxiliary system, the adjoint system, which helps us calculate \(d_{p}J\) by only taking into account the direction of change of \(u\) to which \(J\) is most sensitive, \(\partial_{u}J\).

Let’s assume for simplicity that \(u(p)\) and \(p\) are arrays of dimensions \(N_{u}\) and \(N_{p}\), and that \(J\) doesn’t depend explicitly on \(p\) (if it does, we just need to add the \(\partial_{p}J\) term to the final expression).

\begin{array}{l} \underset{p}{argmin} \quad J(p) \\ subject \quad to \quad g(u,p) = 0\\ \end{array}

Let’s write down the adjoint equation we derived earlier:

\begin{array}{l} \underset{p}{argmin}\,J(p) \\ subject \quad to \quad g(u,p) = 0 \\ with \quad \partial_{u}g^{T}(u,p)\lambda = -\partial_{u}J(u(p),p)^{T} \end{array}

where \(\lambda\) is an array of the same dimension as \(u\).

Once we solve this system, \(d_{p}J\) is just \(\lambda^{T}\partial_{p}g\).

To solve the adjoint system we first have to solve the forward system. Once we have the value of \(u\) for some \(p\), we can evaluate \(\partial_{u}g^{T}(u,p)\) and \(\partial_{u}J\), and with these we can solve the adjoint system.
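The three steps (forward solve, adjoint solve, \(\lambda^{T}\partial_{p}g\)) can be sketched for a nonlinear constraint. Everything problem-specific here is a hypothetical choice for illustration: the state equation \(g(u,p) = Ku + u^{3} - Cp\) (elementwise cube), the matrices \(K\) and \(C\), the Newton solver and the cost \(J = \sum (u - u_{D})^{2}\). The adjoint gradient is checked against central finite differences.

```python
import numpy as np

def make_problem(n_u=20, n_p=4):
    # Hypothetical SPD "stiffness" K and injection matrix C
    K = 2.0 * np.eye(n_u) - np.eye(n_u, k=1) - np.eye(n_u, k=-1)
    C = np.zeros((n_u, n_p))
    C[np.linspace(2, n_u - 3, n_p).astype(int), np.arange(n_p)] = 1.0
    return K, C

def g(u, p, K, C):
    return K @ u + u**3 - C @ p              # state equation g(u, p) = 0

def dg_du(u, K):
    return K + np.diag(3.0 * u**2)           # ∂_u g (SPD, hence non-singular)

def solve_forward(p, K, C, tol=1e-12, max_iter=100):
    u = np.zeros(K.shape[0])
    for _ in range(max_iter):                # Newton iterations on g(u, p) = 0
        du = np.linalg.solve(dg_du(u, K), -g(u, p, K, C))
        u = u + du
        if np.linalg.norm(du) < tol:
            return u
    raise RuntimeError("Newton did not converge")

def cost(p, K, C, u_target):
    return np.sum((solve_forward(p, K, C) - u_target) ** 2)

def adjoint_gradient(p, K, C, u_target):
    u = solve_forward(p, K, C)                    # 1) forward problem
    dJ_du = 2.0 * (u - u_target)                  # 2) ∂_u J at the solution
    lam = np.linalg.solve(dg_du(u, K).T, -dJ_du)  # 3) adjoint: ∂_u g^T λ = -∂_u J^T
    return lam @ (-C)                             # 4) d_p J = λ^T ∂_p g, with ∂_p g = -C

K, C = make_problem()
p = np.array([0.2, -0.4, 0.1, 0.3])
u_target = np.full(K.shape[0], 0.1)
eps = 1e-6
fd = np.array([(cost(p + eps * e, K, C, u_target) - cost(p - eps * e, K, C, u_target)) / (2 * eps)
               for e in np.eye(len(p))])
print(np.allclose(adjoint_gradient(p, K, C, u_target), fd, atol=1e-6))  # True
```

Note that the adjoint system stays linear even though the forward problem is not: it only needs the Jacobian \(\partial_{u}g\) at the converged state.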

Once the insight behind the method is understood, the specific derivations for the different types of states \(u\) and parameters \(p\) (the discrete recurrent case, the continuous case…) are mostly mechanical. They are worked out in the appendix using the more formal approach of Lagrange multipliers, so as not to clutter the main flow of the post.

In the next section I’ll write down the specific adjoint systems we get for systems of different structure, covering the most usual cases, with notable examples of each one. These are:

  • \(u\) and \(p\) are vectors that obey some general non-linear equation \(g(u,p)= 0\) (this is the case covered so far, e.g. a discretized spatial PDE).
  • \(u\) is the state of a recurrent discrete system, where \(u_{k} = f(u_{k-1},p)\) (e.g. feedforward NN, discretized ODE).
  • \(u(t,p)\) is a function of time, and \(J\) is evaluated on a continuous trajectory (e.g. optimal control).
  • \(u(t,p)\) is a function of time, and \(J\) is evaluated on discrete datapoints (e.g. ODEnet, fitting ODE parameters).

Adjoint cheat sheet

The general problem we will be solving is finding the parameters \(P\) that minimize a function/functional \(J\) of the state \(X\), subject to a state equation \(G\)

\begin{equation} J:X \times P \rightarrow \Re \end{equation}

\begin{equation} G:X \times P \rightarrow X \end{equation}

Where X and P might be scalars, vectors, time dependent functions…

For the continuous case where \(u\) is a function of time, the following notation is used

\begin{array}{l} t_{0} = 0 \rightarrow \mbox{initial time} \\ T_{f} = T \rightarrow \mbox{final time}\\ \tau = T-t \end{array}

Here \(\tau\) is the reversed time variable that we use to parametrize the adjoint ODE.

Discrete state

Vector state, \(u \in \Re^{N_{u}}\) (e.g. discretized spatial PDE):
  • Cost: \(J(u(p),p)\)
  • Forward system: \(g(u,p) = 0\)
  • Adjoint/backward system: \(\partial_{u}g^{T}(u,p)\lambda = -\partial_{u}J(u(p),p)^{T}\)
  • Gradient: \(d_{p}J = \partial_{p}J+\lambda^{T}\partial_{p}g\)

Recurrent state, \(u \in \Re^{N_{t}}\times\Re^{N_{u_{k}}}\) (e.g. recurrent NN, discretized ODEs):
  • Cost: \(J(u_{1},\ldots,u_{T},p)\)
  • Forward system:
\begin{array}{l} g(u_{T},u_{T-1},p) = 0 \\ g(u_{T-1},u_{T-2},p) = 0 \\ \vdots \\ g(u_{2},u_{1},p) = 0 \\ g(u_{1},u_{0},p) = 0 \\ g(u_{0},p) = 0 \end{array}
  • Adjoint/backward system, solved from \(k = T\) down to \(k = 0\):
\begin{array}{l} \partial_{u_{T}}J+\lambda_{T}^{T}\partial_{u_{T}}g_{T} = 0 \\ \partial_{u_{k}}J+\lambda_{k+1}^{T}\partial_{u_{k}}g_{k+1}+\lambda_{k}^{T}\partial_{u_{k}}g_{k} = 0, \quad k = T-1,\ldots,1 \\ \lambda_{1}^{T}\partial_{u_{0}}g_{1}+\lambda_{0}^{T}\partial_{u_{0}}g_{0} = 0 \end{array}
  • Gradient: \(d_{p}J = \partial_{p}J+\lambda_{T}^{T}\partial_{p}g_{T}+\lambda_{T-1}^{T}\partial_{p}g_{T-1}+\ldots+\lambda_{1}^{T}\partial_{p}g_{1}+\lambda_{0}^{T}\partial_{p}g_{0}\)

Continuous state

Running cost, \(u(t,p)\) (e.g. LQR controller, Kalman Filter):
  • Cost: \(J(u(t),p) = \int_{0}^{T} f(u(t),p)\,dt\)
  • Forward system: \(g(\dot{u}(t),u(t),p) = 0\), \(g_{0}(u_{0},p) = 0\)
  • Adjoint/backward system:
\begin{array}{l} \dot{\lambda}(\tau)^{T}\partial_{\dot{u}}g - \lambda(\tau)^{T}\frac{d}{dt}\left(\partial_{\dot{u}}g\right) + \lambda(\tau)^{T}\partial_{u}g = -\partial_{u}f \\ \lambda(\tau = 0) = 0 \\ \mu=\lambda^{T}\partial_{\dot{u}}g\,(\partial_{u}g_{0})^{-1}\Bigr|_{t = 0} \end{array}
  • Gradient: \(d_{p}J = \int_{0}^{T} \left(\partial_{p}f+\lambda^{T}(\tau)\partial_{p}g\right)dt+\mu\,\partial_{p}g_{0}\Bigr|_{t = 0}\)

Cost on discrete datapoints, \(u(t,p)\) (e.g. fitting ODE params, ODEnet):
  • Cost: \(J(u,p) = \sum_{i} f(u_{data_{i}},u(t_{i}),p) = \int_{0}^{T} \sum_{i} f(u_{data_{i}},u(t),p)\,\delta(t-t_{i})\,dt\)
  • Forward system: \(g(\dot{u},u,p) = 0\), \(g_{0}(u_{0},p) = 0\)
  • Adjoint/backward system, with \(\lambda\) jumping at each data time \(\tau_{i} = T-t_{i}\) (jump condition shown for \(\partial_{\dot{u}}g = \mathbb 1\)):
\begin{array}{l} \dot{\lambda}(\tau)^{T}\partial_{\dot{u}}g - \lambda(\tau)^{T}\frac{d}{dt}\left(\partial_{\dot{u}}g\right) + \lambda(\tau)^{T}\partial_{u}g = 0 \\ \lambda(\tau_{i}^{+}) = \lambda(\tau_{i}^{-}) - \partial_{u}f|_{\tau_{i}}, \quad \mbox{starting from } \lambda = 0 \\ \mu=\lambda^{T}\partial_{\dot{u}}g\,(\partial_{u}g_{0})^{-1}\Bigr|_{t = 0} \end{array}
  • Gradient: \(d_{p}J = \int_{0}^{T} \left(\partial_{p}f+\lambda^{T}(\tau)\partial_{p}g\right)dt+\mu\,\partial_{p}g_{0}\Bigr|_{t = 0}\)

Specific applications

With this framework, writing down the adjoint equation for different systems is just a matter of substitution. I’ll try to make this clearer with some notable examples, deriving the backpropagation algorithm for feedforward neural networks, the adjoint of linear dynamical systems, and continuous-depth Deep Learning, ODEnet.

Backpropagation

To obtain the expression of the backpropagated gradients, we only have to substitute the specific expression of \(g\) , and use the formula of the adjoint system for vector recurrent states.

Training a feedforward neural network can be formulated as the following constrained optimization problem

\begin{array}{ll} \underset{p}{\mbox{minimize}} & J(u_{data},u_{L},p=W) \end{array}

\begin{array}{ll} \mbox{forward system} & g(u,W) = 0 = \begin{cases} g(u_{L},u_{L-1},p) = u_{L}-\sigma(W_{L}u_{L-1}) = 0 \\ g(u_{L-1},u_{L-2},p) = u_{L-1}-\sigma(W_{L-1}u_{L-2}) = 0 \\ \vdots \\ g(u_{2},u_{1},p) = u_{2}-\sigma(W_{2}u_{1}) = 0\\ g(u_{1},u_{0},p) = u_{1}-\sigma(W_{1}u_{0}) = 0 \\ g(u_{0},p) = u_{0}-u_{input} = 0 \end{cases} \end{array}

\begin{array}{ll} \mbox{backward/adjoint system} & \begin{cases} \partial_{u_{L}}J+\lambda_{L}^{T}\partial_{u_{L}}g_{L} = 0 \\ \lambda_{L}^{T}\partial_{u_{L-1}}g_{L}+\lambda_{L-1}^{T}\partial_{u_{L-1}}g_{L-1} =0 \\ \vdots \\ \lambda_{1}^{T}\partial_{u_{0}}g_{1}+\lambda_{0}^{T}\partial_{u_{0}}g_{0} = 0 \end{cases} \rightarrow \begin{cases} \partial_{u_{L}}J+\lambda_{L}^{T} = 0 \\ -\lambda_{L}^{T}\sigma'(W_{L}u_{L-1})W_{L}+\lambda_{L-1}^{T} =0 \\ \vdots \\ -\lambda_{1}^{T}\sigma'(W_{1}u_{0})W_{1}+\lambda_{0}^{T} = 0 \end{cases} \end{array}

where \(\sigma'(\cdot)\) denotes the diagonal matrix of activation derivatives.

So the gradient of the cost function \(J\) with respect to the weights of layer \(L-k\) is

\begin{equation} d_{W_{L-k}} J_{aug} =\lambda_{L-k}^{T}\partial_{W_{L-k}} g_{L-k}=-\lambda_{L-k}^{T}\partial_{W_{L-k}} \sigma(W_{L-k}u_{L-k-1}) \end{equation}
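A minimal NumPy sketch of this backward recursion for a small fully connected network, assuming \(\sigma = \tanh\), no biases, the cost \(J = \tfrac{1}{2}\|u_{L}-u_{data}\|^{2}\), and made-up layer sizes. It follows the sign convention above (\(\lambda_{L} = -\partial_{u_{L}}J^{T}\)), so each \(\lambda\) is minus the usual backpropagated error; the resulting weight gradients are checked against finite differences.

```python
import numpy as np

def forward(Ws, u_input):
    us, zs = [u_input], []
    for W in Ws:
        zs.append(W @ us[-1])              # pre-activation W_k u_{k-1}
        us.append(np.tanh(zs[-1]))         # u_k = σ(W_k u_{k-1}), i.e. g_k = 0
    return us, zs

def loss(Ws, u_input, u_data):
    us, _ = forward(Ws, u_input)
    return 0.5 * np.sum((us[-1] - u_data) ** 2)

def adjoint_gradients(Ws, u_input, u_data):
    us, zs = forward(Ws, u_input)
    lam = -(us[-1] - u_data)               # ∂_{u_L}J + λ_L^T = 0
    grads = [None] * len(Ws)
    for k in reversed(range(len(Ws))):     # Ws[k] maps us[k] to us[k + 1]
        s = (1.0 - np.tanh(zs[k]) ** 2) * lam   # σ'(W u) ⊙ λ
        grads[k] = -np.outer(s, us[k])     # d_W J = -λ^T ∂_W σ(W u)
        lam = Ws[k].T @ s                  # λ of the previous layer: W^T (σ' ⊙ λ)
    return grads

rng = np.random.default_rng(1)
sizes = [3, 5, 4, 2]
Ws = [rng.normal(size=(m, n)) / np.sqrt(n) for n, m in zip(sizes[:-1], sizes[1:])]
u_input, u_data = rng.normal(size=3), rng.normal(size=2)
grads = adjoint_gradients(Ws, u_input, u_data)

eps, checks = 1e-6, []
for W, G in zip(Ws, grads):                # central finite differences on every weight
    fd = np.zeros_like(W)
    for idx in np.ndindex(W.shape):
        old = W[idx]
        W[idx] = old + eps
        j_plus = loss(Ws, u_input, u_data)
        W[idx] = old - eps
        j_minus = loss(Ws, u_input, u_data)
        W[idx] = old
        fd[idx] = (j_plus - j_minus) / (2 * eps)
    checks.append(np.allclose(G, fd, atol=1e-7))
print(all(checks))  # True
```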

If you were computing forward derivatives in a feedforward neural network or a discretized evolutive PDE, you wouldn’t need the fancy derivation of the adjoint method as a constrained optimization problem to realize that you are wasting a lot of effort in vain. If \(u_{k}\) is the state of the forward system at time/layer \(k\), the gradient of the cost function evaluated at the final time/layer \(L\) with respect to the parameters at step \(k\) is a matrix product that you can evaluate either from right to left or from left to right:

\begin{array}{l} \partial_{p_{k}}J=\partial_{u_{L}}J^{T}\sigma'(W_{L}u_{L-1})W_{L}\,\sigma'(W_{L-1}u_{L-2})W_{L-1}\cdots\sigma'(W_{k+1}u_{k})W_{k+1}\,\partial_{p_{k}}u_{k} \\ \leftarrow \quad \mbox{propagation of } \partial_{p_{k}}u_{k} \quad (\mbox{size } N_{u_{k}}\times N_{p_{k}}) \\ \rightarrow \quad \mbox{propagation of } \partial_{u_{L}}J^{T}\sigma'(W_{L}u_{L-1})W_{L} \quad (\mbox{size } N_{u_{k}}) \end{array}

Obviously you would go from left to right, because \(\partial_{u_{L}}J\) contracts the heavy \(\partial_{p}u\) matrices into a vector. In terms of matrix algebra, \(\rightarrow\) is traversed with vector-Jacobian products and \(\leftarrow\) with Jacobian-vector products. The Jacobian-vector product answers the question “how much does my output change under a perturbation \(\Delta p\)?”, while the vector-Jacobian product answers “how much does a linear functional of the output change under a perturbation \(\Delta p\)?” [3]. A more extended, NN-focused explanation of this is explained neatly here.
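The two evaluation orders can be compared in a minimal NumPy sketch with random, made-up Jacobians (sizes chosen arbitrarily). Both give the same gradient, but right to left carries an \(N_{u}\times N_{p}\) matrix through every layer, while left to right only carries a vector of size \(N_{u}\) until the very last product. The multiplication counts are rough estimates for dense products.

```python
import numpy as np

rng = np.random.default_rng(0)
n_u, n_p, n_layers = 100, 300, 6
# ∂u_j/∂u_{j-1} for j = L, L-1, ..., k+1 (ordered as they appear in the product)
jacs = [rng.normal(size=(n_u, n_u)) / np.sqrt(n_u) for _ in range(n_layers)]
du_dp = rng.normal(size=(n_u, n_p))   # ∂_{p_k} u_k
dJ_du = rng.normal(size=n_u)          # ∂_{u_L} J

# Right to left (Jacobian-vector products): propagate an N_u x N_p matrix.
M = du_dp
for Jac in reversed(jacs):
    M = Jac @ M
grad_right_to_left = dJ_du @ M

# Left to right (vector-Jacobian products): propagate a vector of size N_u.
v = dJ_du
for Jac in jacs:
    v = v @ Jac
grad_left_to_right = v @ du_dp

mults_rl = n_layers * n_u * n_u * n_p + n_u * n_p
mults_lr = n_layers * n_u * n_u + n_u * n_p
print(np.allclose(grad_right_to_left, grad_left_to_right), mults_rl // mults_lr)  # True 200
```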

Linear dynamical systems

It is instructive to work out the case of fitting a data point at the end of the trajectory of a linear ODE system, because writing down the analytical expressions gives some useful insight into the relation between the forward and adjoint dynamics.

\begin{equation} J(u,p) = (u_{T}-\hat{u}(T,p))^{2} \end{equation}

\begin{array}{ll} \mbox{subject to} & \begin{cases} \dot{u} = pAu \\ u(t_{0}) = u_{0} \end{cases} \end{array}

where \(u_{T}\) is the target value and \(\hat{u}(T,p)\) is the model prediction, i.e. the solution of the ODE evaluated at \(t = T\). We just have to use the formula for the continuous case evaluated on discrete datapoints, and we get

\begin{cases} \partial_{\dot{u}}g = \mathbb 1 \\ \frac{d}{dt}(\partial_{\dot{u}}g) = 0 \\ \partial_{u}g = -pA \\ \partial_{p}g = -Au \\ \lambda(\tau = 0) = 2(u_{T}-\hat{u}(T,p)) \\ \mu=\lambda(\tau = T) \end{cases}

\begin{array}{ll} \mbox{adjoint system} & \begin{cases} \dot{\lambda}^{T}-\lambda^{T}pA = 0 \\ \lambda(\tau = 0) = 2(u_{T}-\hat{u}(T,p)) \end{cases} \end{array}

The solutions of both systems are matrix exponentials, with \(A\) as the coefficient matrix in the forward system and \(A^{T}\) in the backward system. Since \(A\) and \(A^{T}\) have the same eigenvalues, the exploding/vanishing trends are the same, but they rotate in opposite directions.

\begin{array}{ll} \mbox{forward solution} & u = e^{pAt}u_{0} \\ \mbox{adjoint solution} & \lambda = e^{pA^{T}\tau}\lambda_{0} = e^{pA^{T}\tau}\,2(u_{T}-\hat{u}(T,p)) \end{array}

The forward system propagates \(u_{0}\) forward with \(e^{pAt}\), and the adjoint system propagates \(-\partial_{u}f|_{T}\) backwards with \(e^{pA^{T}\tau}\).

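\begin{equation} \mbox{Gradient: } d_{p} J(u,p) = -\int_{\tau = 0}^{\tau = T} \left(e^{pA^{T}\tau}\,2(u_{T}-\hat{u}(T,p))\right)^{T}A\,\hat{u}(\tau)\,d\tau \end{equation}

where \(\hat{u}(\tau)\) denotes the forward trajectory evaluated at \(t = T-\tau\), and the minus sign comes from \(\partial_{p}g = -Au\).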
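A minimal NumPy sketch that evaluates this gradient numerically and compares it with central finite differences. The matrix \(A\), the values of \(p\), \(T\), \(u_{0}\) and the target \(u_{T}\), the hand-written Taylor-series matrix exponential and the trapezoidal quadrature are all illustrative assumptions; in practice a library routine such as SciPy’s `scipy.linalg.expm` would replace the hand-written exponential.

```python
import numpy as np

def expm(M, terms=30):
    """Matrix exponential by scaling and squaring a truncated Taylor series (small matrices only)."""
    norm = np.linalg.norm(M, 1)
    s = max(0, int(np.ceil(np.log2(norm)))) + 1 if norm > 0 else 0
    X = M / 2.0**s
    E, term = np.eye(len(M)), np.eye(len(M))
    for n in range(1, terms):
        term = term @ X / n
        E = E + term
    for _ in range(s):
        E = E @ E
    return E

A = np.array([[0.0, 1.0], [-1.0, -0.3]])      # hypothetical system matrix
p, T, u0 = 0.8, 2.0, np.array([1.0, 0.0])
u_target = np.array([0.2, -0.1])              # hypothetical data point u_T

def cost(p):
    u_hat_T = expm(p * A * T) @ u0            # forward solution e^{pAt} u_0 at t = T
    return np.sum((u_target - u_hat_T) ** 2)

def adjoint_gradient(p, n=2000):
    h = T / n
    step_fwd, step_adj = expm(p * A * h), expm(p * A.T * h)
    us = [u0]
    for _ in range(n):                        # forward trajectory at t_j = j h
        us.append(step_fwd @ us[-1])
    lam = 2.0 * (u_target - us[-1])           # λ(τ = 0) = 2 (u_T - û(T, p))
    integrand = np.empty(n + 1)
    for j in range(n + 1):                    # τ_j = j h, i.e. t = T - τ_j
        integrand[j] = -lam @ (A @ us[n - j]) # λ^T ∂_p g = -λ^T A û
        lam = step_adj @ lam                  # λ(τ) = e^{p A^T τ} λ(0)
    return h * (integrand.sum() - 0.5 * (integrand[0] + integrand[-1]))  # trapezoid rule

eps = 1e-6
fd = (cost(p + eps) - cost(p - eps)) / (2 * eps)
print(adjoint_gradient(p), fd)
```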

Continuous-depth limit Deep Learning, ODEnet

Extending the number of layers of a feedforward neural network to the continuous limit [5] is essentially the same as moving from a discretized ODE scheme to its continuous expression.

To make the NN/ODE analogy more explicit, let’s compare a basic feedforward neural network with the Euler discretization of a linear dynamical system.

\begin{array}{lllll} \mbox{forward system} & g(u,W) = 0 = & \begin{cases} u_{L}-\sigma(W_{L}u_{L-1}) = 0 \\ u_{L-1}-\sigma(W_{L-1}u_{L-2}) = 0 \\ \vdots \\ u_{2}-\sigma(W_{2}u_{1}) = 0\\ u_{1}-\sigma(W_{1}u_{0}) = 0 \\ u_{0}-u_{input} = 0 \end{cases} & \underset{\sigma(Wu)=(A\Delta t+\mathbb 1)u}{\rightarrow} & \begin{cases} u_{T}-(A\Delta t+\mathbb 1)u_{T-1} = 0 \\ u_{T-1}-(A\Delta t+\mathbb 1)u_{T-2} = 0 \\ \vdots \\ u_{2}-(A\Delta t+\mathbb 1)u_{1} = 0\\ u_{1}-(A\Delta t+\mathbb 1)u_{0} = 0 \\ u_{0}-u_{input} = 0 \end{cases} \end{array}
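A minimal sketch of this correspondence, assuming a hypothetical linear “layer” \(u_{k+1} = u_{k} + \Delta t\,Au_{k}\) that shares the same rotation matrix \(A\) in every layer. As the number of layers grows (and \(\Delta t = T/N\) shrinks), the network output approaches the exact ODE solution, which for \(u_{0} = (1, 0)\) is \(u(T) = (\cos T, -\sin T)\).

```python
import numpy as np

def residual_net(u0, A, T, n_layers):
    h = T / n_layers
    u = u0.copy()
    for _ in range(n_layers):
        u = u + h * (A @ u)     # one layer (A Δt + 1) u_k, i.e. one explicit Euler step
    return u

A = np.array([[0.0, 1.0], [-1.0, 0.0]])   # hypothetical shared weights: a rotation generator
u0, T = np.array([1.0, 0.0]), 1.0
exact = np.array([np.cos(T), -np.sin(T)])  # solution of u1' = u2, u2' = -u1
errors = {n: np.linalg.norm(residual_net(u0, A, T, n) - exact) for n in (10, 100, 1000, 10000)}
print(errors)  # the error shrinks roughly like 1/n
```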

If we use a more widespread variant of NN that employs skip connections, ResNets, the resemblance is even more striking:

\begin{array}{llll} \mbox{forward system} & \begin{cases} u_{L} = \sigma(W_{L}u_{L-1}) \\ u_{L-1} = \sigma(W_{L-1}u_{L-2}) \\ \vdots \\ u_{2} = \sigma(W_{2}u_{1}) \\ u_{1} = \sigma(W_{1}u_{0}) \\ u_{0} = u_{input} \end{cases} & \underset{skip \quad connections}{\rightarrow} & \begin{cases} u_{L} = \sigma(W_{L}u_{L-1})+u_{L-1} \\ u_{L-1} = \sigma(W_{L-1}u_{L-2})+u_{L-2} \\ \vdots \\ u_{2} = \sigma(W_{2}u_{1})+u_{1} \\ u_{1} = \sigma(W_{1}u_{0})\\ u_{0} = u_{input} \end{cases} \end{array}

Adding more layers to a ResNet, with increasingly smaller steps, the network state evolution equation

\begin{equation} u_{t+1} = A(u_{t},\theta)+u_{t} \end{equation}

starts resembling an Euler integration step. Taking this to the limit of small steps, we can parametrize the continuous dynamics of the neural network state with an ODE:

\begin{equation} \dot{u} = A(u,\theta) \end{equation}

Now the gradients must be calculated using the continuous formulation of the adjoint method, and the forward and backward passes are computed with regular ODE solvers.
\begin{equation} J(u,p) = \int_{0}^{T} \sum_{i} f\left(u_{Data,i},u_{i}(t,p,u_{input})\right)\delta(t-T)\,dt \end{equation}

\begin{array}{l} g(u,\dot{u},t,p) = \dot{u}-A(u,p,t) = 0 \\ g_{0}(u_{input},p) = u(t = 0)-u_{input} = 0 \end{array}

\begin{array}{ll} \mbox{general backward/adjoint system} & \begin{cases} \dot{\lambda}(\tau)^{T}\partial_{\dot{u}}g - \lambda(\tau)^{T}\frac{d}{dt}(\partial_{\dot{u}}g) + \lambda(\tau)^{T}\partial_{u}g = -\partial_{u}f \\ \lambda(\tau = 0) = 0 \\ \mu=\lambda^{T}\partial_{\dot{u}}g\,(\partial_{u}g_{0})^{-1} \end{cases} \end{array}

\begin{cases} \partial_{\dot{u}}g = \mathbb 1 \\ \frac{d}{dt}(\partial_{\dot{u}}g) = 0 \\ \partial_{u}g = -\partial_{u}A \\ \partial_{p}g = -\partial_{p}A \\ \lambda(\tau = 0) = -\sum_{i}^{N_{data}}\partial_{u}f \\ \partial_{p}g_{0} = 0 \quad (\mbox{so the } \mu \mbox{ term vanishes}) \end{cases}

\begin{array}{ll} \mbox{ODEnet backward/adjoint system} & \begin{cases} \dot{\lambda}(\tau)^{T} = \lambda(\tau)^{T}\partial_{u}A \\ \lambda(\tau = 0) = -\sum_{i}^{N_{data}}\partial_{u}f \end{cases} \end{array}

\begin{equation} \mbox{ODEnet gradient: } d_{p}J_{aug} = \int_{T}^{0} \lambda^{T}(t)\,\partial_{p}A(t)\,dt \end{equation}

Solving the Adjoint problem for “free” with automatic differentiation libraries

Usually, if we are optimizing over a simulation or a neural network, we won’t have to formulate the adjoint system explicitly: it is built automatically by the software once the forward pass has been defined, provided that all the operations in the forward pass are compatible with the automatic differentiation engine of the framework.

One of the most basic operations that automatic differentiation libraries implement is the vector-Jacobian product (e.g. \(v^{T}W\) for a linear layer \(u \mapsto Wu\)), which allows us to backpropagate the gradient of a scalar function back to the weights in the case of explicit layers (most NNs use explicit layers).

Nowadays most of these libraries implement the different variants of the adjoint system discussed here, and even simulation frameworks such as FEniCS or Ansys implement the automatic calculation of the adjoint state, so we rarely have to implement it ourselves.

Here I’ll show an example of solving the forward and backward problems for a linear system like the one solved analytically before, implementing in PyTorch a Runge-Kutta scheme, which we can think of as our layer, to compute the forward system.

Of course, this is only for demonstration purposes: in practice we should use the continuous implementation of the adjoint, which is discretization independent and can handle implicit schemes [4].


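import torch
from functools import partial

def make_system():

    # Damped linear oscillator x'' = -k x - gamma x', written as a first-order system in (x, x')
    def system_linear(x, t, params):

        k, gamma = params["k"], params["gamma"]

        dx1dt = x[0][1]
        dx2dt = -k*x[0][0] - gamma*x[0][1]

        return torch.cat([dx1dt.view(-1, 1), dx2dt.view(-1, 1)], dim=1)

    k = torch.tensor([1.0])                            # fixed stiffness (float, like the state)
    gamma = torch.nn.Parameter(torch.tensor([0.01]))   # damping: the parameter we fit
    params = {"k": k, "gamma": gamma}
    system = partial(system_linear, params=params)

    return system, params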

PyTorch simply builds a computational graph as we recursively compute the solution of the forward problem; that graph is then used to calculate the gradient that propagates backwards.

def _rk4_step(fun, yk, tk, h):

    k1 = fun(yk, tk)
    k2 = fun(yk + h/2*k1, tk + h/2)
    k3 = fun(yk + h/2*k2, tk + h/2)
    k4 = fun(yk + h*k3, tk + h)

    yk_next = yk + h/6*(k1+2*k2+2*k3+k4)

    return yk_next

def rk4(fun, y0, t, retain_grad = False):

    y = []

    h = t[1]-t[0]
    yk = y0
    y.append(yk)

    for i in range(1,len(t)):
        yknext = _rk4_step(fun, yk, t[i-1], h)
        yk = yknext

        if retain_grad:
            yk.retain_grad()

        y.append(yk)

    return y

The forward pass is the solution of our system for a given input/forcing:

def forward_pass(x0,T, system):

    out = rk4(system, x0, T, retain_grad = True)

    return out

...

out = forward_pass(x0, T, system)  # solve the forward problem g(u,p) = 0, in this case an explicit scheme
xpred = torch.cat([out[i] for i in indices_data], dim=0)  # concatenate the outputs at the datapoints of interest
J = torch.mean(torch.square(xpred - xdata))  # calculate the loss J
J.backward()  # solve the adjoint problem: gradients are stored in each parameter's .grad and then used by the optimizer

Fitting one data point at time \(T_{f}\). The inner trajectory is the ground truth; the red arrow, the gradient of the loss function at \(T_{f}\), is propagated backwards to calculate the gradient at every instant. The analytical expression for both systems was discussed in the linear system section.

More than one data point.
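To see what `J.backward()` is doing behind the scenes, here is a minimal NumPy sketch of the same kind of fit with the discrete adjoint written by hand. It swaps the RK4 scheme for an explicit Euler step to keep the Jacobians short, and the step size, horizon, observation indices and the “true” damping used to generate the data are all made-up values. With \(g_{j} = u_{j} - Mu_{j-1}\) and \(M = \mathbb 1 + hF(\gamma)\), the adjoint recursion is \(\lambda_{j-1} = M^{T}\lambda_{j} - \partial_{u_{j-1}}J^{T}\), and the gradient is \(d_{\gamma}J = \sum_{j}\lambda_{j}^{T}\partial_{\gamma}g_{j}\).

```python
import numpy as np

H, N_STEPS, K = 0.01, 1000, 1.0   # hypothetical step size, number of steps and stiffness

def simulate(gamma):
    # Explicit Euler for x'' = -k x - gamma x':  u_j = M u_{j-1}, with u = (x, x')
    M = np.eye(2) + H * np.array([[0.0, 1.0], [-K, -gamma]])
    us = np.zeros((N_STEPS + 1, 2))
    us[0] = [1.0, 0.0]
    for j in range(1, N_STEPS + 1):
        us[j] = M @ us[j - 1]
    return us, M

def loss(gamma, idx, data):
    us, _ = simulate(gamma)
    return np.mean((us[idx] - data) ** 2)            # mean squared error at the datapoints

def adjoint_grad(gamma, idx, data):
    us, M = simulate(gamma)
    dJ_du = np.zeros_like(us)
    dJ_du[idx] = 2.0 * (us[idx] - data) / data.size  # ∂_{u_j}J, non-zero only at datapoints
    lam = -dJ_du[N_STEPS]                            # λ_N = -∂_{u_N}J
    grad = 0.0
    for j in range(N_STEPS, 0, -1):
        grad += lam[1] * H * us[j - 1, 1]            # λ_j^T ∂_γ g_j, with ∂_γ g_j = (0, h x'_{j-1})
        lam = M.T @ lam - dJ_du[j - 1]               # λ_{j-1} = M^T λ_j - ∂_{u_{j-1}}J
    return grad

idx = np.array([200, 500, 1000])                     # hypothetical observation steps
data = simulate(0.3)[0][idx]                         # synthetic data from gamma = 0.3
gamma, eps = 0.05, 1e-6
fd = (loss(gamma + eps, idx, data) - loss(gamma - eps, idx, data)) / (2 * eps)
print(adjoint_grad(gamma, idx, data), fd)
```

The backward loop is exactly the recurrent adjoint system of the cheat sheet, with one contribution \(-\partial_{u_{j}}J\) injected at each datapoint.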

Appendix

Derivation for vector states

\begin{array}{lc} \underset{p}{\mbox{min}} & J(u,p) = f(u,p) \\ \end{array} \begin{array}{lc} \mbox{subject to} & g(u,p) = 0 \end{array} \begin{array}{l} \mbox{with } u\in\Re^{nu},p\in\Re^{np} \\ g: \Re^{nu}\times\Re^{np}\rightarrow \Re^{nu} \\ f: \Re^{nu}\times\Re^{np}\rightarrow \Re \\ \mbox{being } \partial_{u}g \mbox{ non singular} \end{array}

We are interested in finding the gradient \(d_{p}J\). To that end, we take derivatives of \(J\) and \(g\) with respect to the parameters \(p\):

\begin{array}{lc} d_{p}J = \partial_{p}f^{T} +\partial_{u}f^{T}\partial_{p}u \\ \partial_{u}g\partial_{p}u=-\partial_{p}g \end{array}

Substituting \(\partial_{p}u\) in \(d_{p}J\), we arrive at

\begin{array}{l} d_{p}J = \partial_{p}f^{T} -\partial_{u}f^{T}\partial_{u}g^{-1}\partial_{p}g
\end{array}

The key point is that \(\partial_{u}g^{-1}\) only appears in \(d_{p}J\) through the product \(\partial_{u}f^{T}\partial_{u}g^{-1}\). This makes it unnecessary to solve the larger system \(\partial_{u}g\partial_{p}u=-\partial_{p}g\); instead, we can find the vector \(\lambda\in\Re^{nu}\) that satisfies \(-\partial_{u}f^{T}\partial_{u}g^{-1}=\lambda^{T}\), arriving at the adjoint system.

\begin{equation} -\partial_{u}f^{T}\partial_{u}g^{-1}=\lambda^{T}\rightarrow \underset{adjoint \quad system}{\partial_{u}g^{T}\lambda = -\partial_{u}f} \end{equation}

There is a more practical and general derivation of the adjoint method through the use of Lagrange multipliers. Let’s start by building the augmented objective function, appending the constraint with the Lagrange multiplier \(\lambda^{T}\). The important thing to realize here is that, because \(g(u,p)=0\), \(\lambda^{T}\) can take any value we want without altering the solution of the problem.

\begin{array}{lc} \underset{p}{\mbox{min}} & J_{aug}(u,p) = f(u,p) + \lambda^{T}g(u,p)
\end{array}

Taking derivatives of the augmented function with respect to the parameters \(p\), and factoring out \(\partial_{p}u\), we arrive at

\begin{equation} d_{p}J_{aug} = \partial_{p}f+\partial_{u}f\partial_{p}u+\lambda^{T}(\partial_{p}g+\partial_{u}g\partial_{p}u) \end{equation}

\begin{equation} d_{p}J_{aug} = \partial_{p}f +\lambda^{T}\partial_{p}{g} + (\partial_{u}f+\lambda^{T}\partial_{u}g)\partial_{p}u \end{equation}

Now we exercise our freedom of choice for \(\lambda^{T}\) and choose it such that it cancels the \(\partial_{p}u\) term, arriving again at the adjoint equation.

\begin{equation} \partial_{u}g^{T}\lambda = -\partial_{u}f \end{equation}

Derivation for discrete recurrent systems (e.g. FF NN, discretized ODE)

As in the case of neural networks and dynamical systems, we are especially interested in systems where the state is transformed sequentially. Let’s assume for simplicity that the system is evaluated at its final stage: the last layer in a NN, the final position in a control problem.

\begin{array}{ll} \underset{p}{\mbox{minimize}} & J(u_{T},p) \end{array} \begin{array}{lll} \mbox{subject to} & g(u,p) = 0 = & \begin{cases} g(u_{T},u_{T-1},p) = 0 \\ g(u_{T-1},u_{T-2},p) = 0 \\ \vdots \\ g(u_{2},u_{1},p) = 0\\ g(u_{1},u_{0},p) = 0 \\ g(u_{0},p) = 0 \end{cases} \end{array}

The derivation in the last section is still valid in this case, but doesn’t give much insight. Let’s create the augmented objective function by adding each constraint equation with its own Lagrange multiplier.

\begin{equation} J_{aug}(u_{T},p) = J(u_{T},p) + \lambda_{T}g(u_{T},u_{T-1},p) + \lambda_{T-1}g(u_{T-1},u_{T-2},p) + \ldots + \lambda_{1}g(u_{1},u_{0},p)+ \lambda_{0}g(u_{0},p) \end{equation}

Let’s again take derivatives with respect to \(p\). Our objective is to get rid of the \(d_{p}u_{k}\) terms, so let’s factor out each one of them:

\begin{array}{l} d_{p} J_{aug} = \partial_{p} J + \partial_{u_{T}}J\,d_{p}u_{T}+ \lambda_{T}(\partial_{u_{T}}g_{T}d_{p}u_{T} + \partial_{u_{T-1}}g_{T}d_{p}u_{T-1}+\partial_{p}g_{T})+ \\ \lambda_{T-1}(\partial_{u_{T-1}}g_{T-1}d_{p}u_{T-1} + \partial_{u_{T-2}}g_{T-1}d_{p}u_{T-2}+\partial_{p}g_{T-1})+ \\ \ldots \\ \lambda_{1}(\partial_{u_{1}}g_{1}d_{p}u_{1} + \partial_{u_{0}}g_{1}d_{p}u_{0}+\partial_{p}g_{1})+ \lambda_{0}(\partial_{u_{0}}g_{0}d_{p}u_{0} + \partial_{p}g_{0}) \end{array}

\begin{array}{l} d_{p} J_{aug} = \partial_{p} J + (\partial_{u_{T}}J+\lambda_{T}\partial_{u_{T}}g_{T})d_{p}u_{T}+ (\lambda_{T}\partial_{u_{T-1}}g_{T}+\lambda_{T-1}\partial_{u_{T-1}}g_{T-1})d_{p}u_{T-1} + \\ \ldots \\ (\lambda_{1}\partial_{u_{0}}g_{1}+\lambda_{0}\partial_{u_{0}}g_{0})d_{p}u_{0}+ \lambda_{T}\partial_{p}g_{T}+\lambda_{T-1}\partial_{p}g_{T-1}+\ldots + \lambda_{1}\partial_{p}g_{1}+ \lambda_{0}\partial_{p}g_{0} \end{array}

Each parenthesis now contains its own Lagrange multiplier, which we can choose to make that term vanish. Doing so gives a system of equations that can be solved backwards, starting from \(\lambda_{T}\) and moving towards \(\lambda_{0}\).

It is worth noting that this is simply a more formal restatement of the backpropagation algorithm used in deep learning.

\begin{array}{ll} \mbox{backward/adjoint system} & \begin{cases} \partial_{u_{T}}J+\lambda_{T}\partial_{u_{T}}g_{T} = 0 \\ \lambda_{T}\partial_{u_{T-1}}g_{T}+\lambda_{T-1}\partial_{u_{T-1}}g_{T-1} =0 \\ \vdots \\ \lambda_{1}\partial_{u_{0}}g_{1}+\lambda_{0}\partial_{u_{0}}g_{0} = 0 \end{cases} \end{array}

\begin{array}{ll} \mbox{Gradient} & d_{p} J_{aug} = \partial_{p} J + \lambda_{T}\partial_{p}g_{T}+\lambda_{T-1}\partial_{p}g_{T-1}+\ldots + \lambda_{1}\partial_{p}g_{1}+ \lambda_{0}\partial_{p}g_{0} \end{array}

\begin{array}{ll} \mbox{forward/state system} & \begin{cases} g(u_{T},u_{T-1},p) = 0 \\ g(u_{T-1},u_{T-2},p) = 0 \\ \vdots \\ g(u_{2},u_{1},p) = 0\\ g(u_{1},u_{0},p) = 0 \\ g(u_{0},p) = 0 \end{cases} \end{array}
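Here is a minimal sketch of the backward/adjoint system for a hypothetical recurrence \(u_{k} = \tanh(Wu_{k-1})\). In the notation above this means \(g_{k} = u_{k} - \tanh(Wu_{k-1})\) and \(g_{0} = u_{0} - x\), with \(p = W\) shared by every step and \(J = \frac{1}{2}\lVert u_{T}-y\rVert^{2}\). The model, the names and the sizes are assumptions chosen for illustration. Because \(\partial_{u_{k}}g_{k} = I\), each adjoint equation can be solved explicitly for \(\lambda_{k-1}\). The multipliers are stored as column vectors, so the row-vector products \(\lambda\,\partial g\) become transposes. The result is checked against finite differences.

```python
import numpy as np

# Hypothetical recurrent system (an unrolled, weight-shared network):
#   g_0 = u_0 - x = 0,  g_k = u_k - tanh(W u_{k-1}) = 0 for k = 1..T,  p = W
#   J = 0.5 * ||u_T - y||^2

def forward(W, x, T):
    us = [x]
    for _ in range(T):
        us.append(np.tanh(W @ us[-1]))
    return us

def cost(W, x, y, T):
    return 0.5 * np.sum((forward(W, x, T)[-1] - y) ** 2)

def adjoint_grad(W, x, y, T):
    us = forward(W, x, T)
    lam = -(us[T] - y)                  # d_{u_T} J + lambda_T d_{u_T} g_T = 0, with d_{u_T} g_T = I
    grad = np.zeros_like(W)
    for k in range(T, 0, -1):
        s = 1.0 - us[k] ** 2            # tanh'(W u_{k-1}), because u_k = tanh(W u_{k-1})
        # lambda_k d_W g_k, using d g_k / d W_ij = -s_i u_{k-1,j} e_i
        grad -= np.outer(s * lam, us[k - 1])
        # lambda_{k-1} = -lambda_k d_{u_{k-1}} g_k, since d_{u_{k-1}} g_{k-1} = I
        lam = W.T @ (s * lam)
    return grad                          # g_0 does not depend on W, so lambda_0 adds nothing

rng = np.random.default_rng(0)
n, T = 3, 5
W = 0.5 * rng.standard_normal((n, n))
x, y = rng.standard_normal(n), rng.standard_normal(n)
g_adj = adjoint_grad(W, x, y, T)

# central finite differences, one entry of W at a time
eps = 1e-6
g_fd = np.zeros_like(W)
for i in range(n):
    for j in range(n):
        E = np.zeros_like(W)
        E[i, j] = eps
        g_fd[i, j] = (cost(W + E, x, y, T) - cost(W - E, x, y, T)) / (2 * eps)
print(np.max(np.abs(g_adj - g_fd)))
```

The loop over `k` is, step for step, the backward pass of backpropagation through a network with tied weights.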

Derivation for time-dependent states

Now let's deal with the case where the state is a time-dependent function. This is the problem we face, for instance, in optimal control or when fitting the parameters of an ODE to data.

\begin{array}{lc} \underset{p}{\mbox{minimize}} & J(u(t),p) = \int_{0}^{T} f(u(t),p) \,dt \\ \end{array} \begin{array}{lc} \mbox{subject to} & \begin{cases} g(\dot{u},u,p) = 0 \\ g_{0}(u_{0},p) = 0 \end{cases} \\ \end{array}

Let's build the augmented objective function by adding the constraints with their respective Lagrange multipliers, taking into account that \(\lambda(t)\) is now a function of time.

\begin{equation} J_{aug} = \int_{0}^{T} \left[f(u(t),p) + \lambda(t)^{T}g(\dot{u},u,p)\right] dt + \mu\, g_{0}(u_{0},p) \end{equation}

We want \(d_{p}J_{aug}\). The strategy is the same as before: take derivatives with respect to \(p\) and use the Lagrange multipliers to get rid of the \(\partial_{p}u(t)\) terms.

\begin{equation} d_{p}J_{aug} = \int_{0}^{T} \left[\partial_{p}f + \partial_{u}f\,\partial_{p}u + \lambda(t)^{T}\left(\partial_{\dot{u}}g\,\underbrace{\partial_{p}\dot{u}}_{=\dot{\partial_{p}u}} + \partial_{u}g\,\partial_{p}u + \partial_{p}g\right)\right] dt + \mu\left(\partial_{u}g_{0}\,\partial_{p}u + \partial_{p}g_{0}\right)\Bigr|_{t = 0} \end{equation}

The mischievous term here is \(\int_{0}^{T} \lambda(t)\partial_{\dot{u}}g\,\dot{\partial_{p}u}\,dt\). We need everything in terms of \(\partial_{p}u\), not \(\dot{\partial_{p}u}\). Because the term appears as an inner product with \(\lambda\), we can move the time derivative onto the other factor, i.e. use the adjoint of \(\frac{d}{dt}\). In more familiar terms, we integrate by parts to move \(\frac{d}{dt}\) from \(\partial_{p}u\) onto \(\lambda\partial_{\dot{u}}g\).

\begin{equation} \int_{0}^{T} \lambda(t)\partial_{\dot{u}}g\,\dot{\partial_{p}u}\,dt = \lambda\partial_{\dot{u}}g\,\partial_{p}u\Bigr|_{t=0}^{t=T}-\int_{0}^{T} \frac{d}{dt}\left(\lambda\partial_{\dot{u}}g\right)\partial_{p}u\,dt \end{equation}
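In operator terms, this is the statement that for the inner product \(\langle v,w\rangle = \int_{0}^{T} v^{T}w\,dt\) on \([0,T]\), the adjoint of \(\frac{d}{dt}\) is \(-\frac{d}{dt}\) up to boundary terms. Taking \(v^{T} = \lambda\partial_{\dot{u}}g\) and \(w = \partial_{p}u\) (column by column) gives the integration by parts above:

\begin{equation} \left\langle v, \frac{dw}{dt} \right\rangle = v^{T}w\Bigr|_{t=0}^{t=T} - \left\langle \frac{dv}{dt}, w \right\rangle \end{equation}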

We are free to choose \(\lambda\) and \(\mu\). Setting \(\lambda(T) = 0\) removes the boundary term containing \(\partial_{p}u(T)\). At \(t=0\), we choose \(\mu\) so that \((\mu\partial_{u}g_{0}-\lambda\partial_{\dot{u}}g)\partial_{p}u\bigr|_{t=0}\) vanishes.

\begin{equation} d_{p}J_{aug} = \int_{0}^{T} \left[\partial_{p}f + \partial_{u}f\,\partial_{p}u + \lambda(t)\left(\partial_{\dot{u}}g\,\dot{\partial_{p}u} + \partial_{u}g\,\partial_{p}u + \partial_{p}g\right)\right] dt + \mu\left(\partial_{u}g_{0}\,\partial_{p}u + \partial_{p}g_{0}\right)\Bigr|_{t = 0} \end{equation} \begin{equation} d_{p}J_{aug} = \int_{0}^{T} \left[\left(\partial_{u}f - \dot{\lambda}\partial_{\dot{u}}g - \lambda\frac{d}{dt}\left(\partial_{\dot{u}}g\right) + \lambda\partial_{u}g\right)\partial_{p}u + \partial_{p}f + \lambda\partial_{p}g\right] dt + \left(\mu\partial_{u}g_{0}-\lambda\partial_{\dot{u}}g\right)\partial_{p}u\Bigr|_{t = 0} + \mu\partial_{p}g_{0}\Bigr|_{t = 0} \end{equation}
\begin{array}{lc} \mbox{forward/state system} & \begin{cases} g(\dot{u},u,p) = 0 \\ g_{0}(u_{0},p) = 0 \end{cases} \\ \end{array}

Setting to zero the term that multiplies \(\partial_{p}u\) inside the integral gives the continuous adjoint system. Because its condition is given at \(t=T\), we change the time variable to \(\tau = T - t\). The system is then integrated from \(\tau = 0\) onwards, i.e. backwards in \(t\).

\begin{array}{l} t \rightarrow \tau \\ \mbox{with} \quad t = T-\tau \end{array}

\begin{array}{lc} \mbox{backward/adjoint system} & \begin{cases} \frac{d\lambda}{d\tau}\partial_{\dot{u}}g + \lambda(\tau)\frac{d}{d\tau}\left(\partial_{\dot{u}}g\right) + \lambda(\tau)\partial_{u}g = -\partial_{u}f \\ \lambda(\tau = 0) = 0 \\ \mu=\lambda\partial_{\dot{u}}g\,(\partial_{u}g_{0})^{-1}\Bigr|_{t=0} \end{cases} \end{array}

\begin{array}{lc} \mbox{Gradient} & d_{p}J_{aug} = \int_{0}^{T} \left(\partial_{p}f+\lambda\partial_{p}g\right)dt+\mu\partial_{p}g_{0}\Bigr|_{t = 0} \end{array}

It is worth noting that discretizing this system in time recovers a backward recursion of the same form as the adjoint system of the previous section.
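Here is a minimal sketch of the continuous adjoint. It assumes a hypothetical explicit ODE \(g = \dot{u} - F(u,p)\) with \(F(u,p) = -p_{0}u + p_{1}\), \(g_{0} = u(0) - u_{0}\), and a running cost \(f = \frac{1}{2}(u - u_{ref})^{2}\). Here \(\partial_{\dot{u}}g = 1\), so in \(t\) the adjoint equation reduces to \(\dot{\lambda} = \partial_{u}f - \lambda\partial_{u}F\) with \(\lambda(T) = 0\). Instead of substituting \(\tau = T - t\) explicitly, the code steps the \(t\)-form backwards with a negative RK4 step, which is equivalent. As in the Neural ODE paper, the state is re-integrated backwards alongside \(\lambda\) and a gradient accumulator instead of being stored. The test problem, the hand-written RK4 and all names are assumptions.

```python
import numpy as np

# Hypothetical test problem (not from the post):
#   state: g = du/dt - F(u, p) = 0, F(u, p) = -p[0] * u + p[1];  g_0 = u(0) - u0 = 0
#   cost:  J = int_0^T f(u) dt, f(u) = 0.5 * (u - u_ref)^2
u0, u_ref, T, N = 2.0, 0.5, 3.0, 2000

def F(u, p):
    return -p[0] * u + p[1]

def rk4(rhs, z, t0, t1, n):
    """Fixed-step classic RK4 from t0 to t1 (t1 < t0 integrates backwards in t)."""
    h = (t1 - t0) / n
    t = t0
    for _ in range(n):
        k1 = rhs(t, z)
        k2 = rhs(t + h / 2, z + h / 2 * k1)
        k3 = rhs(t + h / 2, z + h / 2 * k2)
        k4 = rhs(t + h, z + h * k3)
        z = z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return z

def cost(p):
    # forward/state system augmented with q' = f(u), so that q(T) = J
    rhs = lambda t, z: np.array([F(z[0], p), 0.5 * (z[0] - u_ref) ** 2])
    return rk4(rhs, np.array([u0, 0.0]), 0.0, T, N)[1]

def adjoint_grad(p):
    uT = rk4(lambda t, z: np.array([F(z[0], p)]), np.array([u0]), 0.0, T, N)[0]

    def rhs(t, z):
        u, lam = z[0], z[1]
        du = F(u, p)                        # state, re-integrated backwards instead of stored
        dlam = (u - u_ref) - lam * (-p[0])  # lambda' = d_u f - lambda d_u F
        dG = lam * np.array([-u, 1.0])      # G' = lambda d_p F, so G(0) = int_0^T lambda d_p g dt
        return np.concatenate(([du, dlam], dG))

    z0 = rk4(rhs, np.array([uT, 0.0, 0.0, 0.0]), T, 0.0, N)  # lambda(T) = 0, G(T) = 0
    return z0[2:]  # d_p f = 0 and mu d_p g_0 = 0 here, so the gradient is G(0)

p = np.array([0.8, 0.3])
eps = 1e-5
fd = np.array([(cost(p + eps * e) - cost(p - eps * e)) / (2 * eps) for e in np.eye(2)])
print(adjoint_grad(p), fd)
```

Re-integrating \(u\) backwards saves memory, but it can drift for stiff or unstable dynamics. Storing or checkpointing the forward trajectory is the usual alternative.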

Time-dependent states with data evaluated at discrete times

In the case of ODEnet, or when fitting an ODE to data, the cost function \(J\) is usually evaluated at discrete data points. These enter the integral of the previous section as Dirac deltas and therefore appear in the right-hand side of the adjoint equation. A \(\delta\) source term in an ODE acts like a sudden impulse: it makes the solution jump, producing a discontinuity.

To solve the adjoint system, we therefore solve each interval between data points separately. The jump is added at the end of each piece, and the result acts as the initial condition for the next piece.

\begin{equation} J(u,p) = \int_{0}^{T} \sum_{i}^{N_{data}}\left(u_{t_{i}}-\hat{u}(t_{i},p)\right)^{2} \delta(t-t_{i})\, dt \end{equation}

\begin{equation} \partial_{u}f = -2\sum_{i}^{N_{data}}\left(u_{t_{i}}-\hat{u}(t_{i},p)\right) \delta(t-t_{i}) \end{equation}

\begin{equation} \frac{d\lambda}{d\tau}\partial_{\dot{u}}g + \lambda\frac{d}{d\tau}\left(\partial_{\dot{u}}g\right) + \lambda\partial_{u}g = - \sum_{i}^{N_{data}}A(u_{data},\hat{u})\,\delta(\tau-\tau_{i}), \qquad \tau_{i} = T - t_{i} \end{equation}

\begin{equation} \lambda(\tau_{i}^{+})-\lambda(\tau_{i}^{-}) = -A(u_{data},\hat{u}) = 2\left(u_{t_{i}}-\hat{u}(t_{i},p)\right) \end{equation}
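Here is a minimal sketch of the piecewise solve. It uses a hypothetical ODE \(\dot{\hat{u}} = -p_{0}\hat{u} + p_{1}\), made-up observations \(u_{t_{i}}\), and \(J = \sum_{i}(u_{t_{i}}-\hat{u}(t_{i},p))^{2}\). Between data points the adjoint is smooth and source-free; at each \(t_{i}\) it jumps by the coefficient \(A\) of the delta. The code works in \(t\) rather than \(\tau\), so the jump reads \(\lambda(t_{i}^{-}) = \lambda(t_{i}^{+}) - A\). This is the same condition as \(\lambda(\tau_{i}^{+}) - \lambda(\tau_{i}^{-}) = -A\), because \(\tau = T - t\) swaps the two sides. The data, step counts and names are assumptions.

```python
import numpy as np

# Hypothetical problem (not from the post): du/dt = F(u, p) = -p[0] * u + p[1], u(0) = u0,
# observed at times t_i; J = sum_i (d_i - u(t_i))^2, i.e. f is a sum of deltas.
u0 = 2.0
t_obs = np.array([0.5, 1.2, 2.0, 3.0])   # T = t_obs[-1]
d_obs = np.array([1.4, 1.0, 0.8, 0.6])   # made-up measurements
STEPS = 400                               # RK4 steps per interval between observations

def F(u, p):
    return -p[0] * u + p[1]

def rk4(rhs, z, t0, t1, n=STEPS):
    """Fixed-step classic RK4 from t0 to t1 (t1 < t0 integrates backwards in t)."""
    h = (t1 - t0) / n
    t = t0
    for _ in range(n):
        k1 = rhs(t, z)
        k2 = rhs(t + h / 2, z + h / 2 * k1)
        k3 = rhs(t + h / 2, z + h / 2 * k2)
        k4 = rhs(t + h, z + h * k3)
        z = z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return z

def forward(p):
    """Model state at each observation time, integrated interval by interval."""
    z, t, us = np.array([u0]), 0.0, []
    for ti in t_obs:
        z = rk4(lambda s, y: np.array([F(y[0], p)]), z, t, ti)
        us.append(z[0])
        t = ti
    return np.array(us)

def cost(p):
    return np.sum((d_obs - forward(p)) ** 2)

def adjoint_grad(p):
    us = forward(p)

    def rhs(t, z):
        # source-free adjoint between data points: lambda' = -lambda d_u F, G' = lambda d_p F
        u, lam = z[0], z[1]
        return np.array([F(u, p), lam * p[0], -lam * u, lam])

    t_left = np.concatenate(([0.0], t_obs[:-1]))  # left end of each interval
    z = np.zeros(4)                                # [u, lambda, G_0, G_1]; lambda(T^+) = 0
    for i in reversed(range(len(t_obs))):
        z[0] = us[i]                               # restart u from the stored forward value
        A = -2.0 * (d_obs[i] - us[i])              # coefficient of delta(t - t_i) in d_u f
        z[1] -= A                                  # jump: lambda(t_i^-) = lambda(t_i^+) - A
        z = rk4(rhs, z, t_obs[i], t_left[i])       # smooth backward solve to the previous point
    return z[2:]                                   # G(0) = int_0^T lambda d_p g dt

p = np.array([0.8, 0.3])
eps = 1e-5
fd = np.array([(cost(p + eps * e) - cost(p - eps * e)) / (2 * eps) for e in np.eye(2)])
print(adjoint_grad(p), fd)
```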

Footnotes and References

  1. PDE-constrained optimization and the adjoint method: https://cs.stanford.edu/~ambrad/adjoint_tutorial.pdf

  2. A brief, practical explanation of the adjoint calculation for function operators, among many other things.

    J. Nathan Kutz, Advanced Differential Equations: Asymptotics & Perturbations

  3. A neat and extensive explanation of this in the context of NNs.

    Zico Kolter, David Duvenaud, Matt Johnson, Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond

  4. PyTorch implementation of the adjoint method for continuous states.

    https://github.com/rtqichen/torchdiffeq

  5. ODEnet paper

    Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud, Neural Ordinary Differential Equations

  6. Notes on adjoint methods

    Steven G. Johnson, Notes on Adjoint Methods for 18.335