Neural Ordinary Differential Equations

Introduction

As I am truly fascinated by applied mathematics in general, I was intrigued when I was first exposed to deep learning: the amalgamation of two of my favourite fields, mathematics and computer science. As I continued to explore, I noticed a pattern. Computer hardware evolved rapidly, algorithms became faster, and architectures became better and more powerful, yet the mathematical concepts seemed to be left behind. We accelerated training with GPUs and came up with more and more model architectures, but the equations and algorithms were always designed with computer science and applications in mind, and that was for the good as well.

Imagine my fascination, then, upon discovering a novel approach to model architecture that finally reopened the ancient treasure trove of mathematics! It was a breath of fresh air, with some heavy mathematics to tackle. Hence I am writing this essay to convey, to the best of my ability, what I have learned, and to do it in the most intuitive way possible. (To anyone wondering: no, I do not think the improvements that came before this concept ignored mathematics or were math-deprived; this was only my perspective as far as I had studied ML.) A warning is in order: this essay may get long, and familiarity with the previous concepts in this blog, as well as with rudimentary integration and differential equations, is necessary.

A differential approach

We begin by looking at where differential equations can be utilised in this field at all. Differential equations are one of the most useful concepts in mathematics, physics, chemistry, biology, economics and more: basically every field that deals with change. In a deep neural network, the process is simple. We initialise a hidden state and perform a forward pass through a layer (or block), and the result is the next hidden state. Let us encapsulate the forward pass in a function f. It takes in the previous hidden state and the parameters (which training will tweak) and outputs the next hidden state. We do this because different model architectures perform the forward pass in different ways. Abstracting the pass into f means this method can be applied to any model, regardless of the architecture!

$$h_{t+1} = f(h_t, \theta_t)$$

We are missing an important (and quite revolutionary) step here: adding the input back to the output. This skip (residual) connection was introduced with ResNets to make very deep networks trainable and to improve model quality, and it worked like magic. So we add the input back at the end of the forward pass. And why stop there? We can add it after every block, every layer. So here we are:

$$h_{t+1} = h_t + f(h_t, \theta_t)$$

Now comes the neat part. Remember how I said differential equations study change? (Well, calculus itself does, anyway.) Notice that the only thing changing here is our hidden state, $h_t$. Why not simply move it to the left side?

$$h_{t+1} - h_t = f(h_t, \theta_t)$$

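Here we are essentially parametrising the change in our hidden state, and a little Calculus 1 helps. The residual update is a step of size 1 from layer $t$ to layer $t+1$. Suppose we instead take steps of size $\Delta t$ and scale the update accordingly, so that $h(t+\Delta t) = h(t) + \Delta t\, f(h(t), t, \theta)$. Taking the limit $\Delta t \to 0$ then moves us from the world of the discrete to the world of the continuous. (Since the hidden state now evolves continuously, we parametrise h with respect to t, or time, and write it as h(t). We also use a single set of parameters θ for all t.)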

$$\lim_{\Delta t \to 0} \frac{h(t+\Delta t) - h(t)}{\Delta t} = f(h(t), t, \theta) \quad\Longrightarrow\quad \frac{dh(t)}{dt} = f(h(t), t, \theta)$$

Here we have our desired differential equation! Solving it from an initial time to a final time gives us the final output. As we know by now, that output must be compared with the ground truth to calculate our loss. Let L be our loss function (again, this is left general, and any loss function can be used). Let $x(t_0)$ be our input at time $t_0$, and let $x(t_1)$ be our output at time $t_1$. The final output of a forward pass is the "sum" of all the changes x(t) goes through (how the input changes after each layer), added to our starting point. Since we are in the world of the continuous, that sum becomes an integral.

$$\int_{t_0}^{t_1} \frac{dx(t)}{dt}\,dt = \int_{t_0}^{t_1} f(x(t), t, \theta)\,dt \quad\Longrightarrow\quad x(t_1) = x(t_0) + \int_{t_0}^{t_1} f(x(t), t, \theta)\,dt$$

Hence our final loss is:

$$L(x(t_1)) = L\left(x(t_0) + \int_{t_0}^{t_1} f(x(t), t, \theta)\,dt\right)$$
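Here is a minimal Python sketch that makes the jump from the residual update to the ODE concrete. It assumes a toy scalar dynamics $f(h, t, \theta) = \theta h$, chosen only because its exact solution $h(t) = h(0)e^{\theta t}$ is known; `residual_stack` is a hypothetical helper. One residual update over $[0, 1]$ is exactly $h_{t+1} = h_t + f(h_t, \theta)$. Splitting the interval into more, smaller updates is Euler's method, and the result approaches the exact value of $x(t_1) = x(t_0) + \int f\,dt$.

```python
import math

def f(h, t, theta):
    # Toy dynamics chosen for illustration (an assumption): dh/dt = theta * h
    return theta * h

def residual_stack(h0, theta, num_layers, t1=1.0):
    """Apply num_layers residual updates h <- h + dt * f(h, t, theta).

    With num_layers=1 and t1=1.0 this is exactly one residual block,
    h_{t+1} = h_t + f(h_t, theta). More, smaller updates
    (dt = t1 / num_layers) are Euler's method for dh/dt = f(h(t), t, theta).
    """
    dt = t1 / num_layers
    h, t = h0, 0.0
    for _ in range(num_layers):
        h = h + dt * f(h, t, theta)
        t += dt
    return h

h0, theta = 1.0, -0.5
exact = h0 * math.exp(theta)  # closed-form solution of dh/dt = theta * h at t = 1
for layers in (1, 10, 100, 1000):
    approx = residual_stack(h0, theta, layers)
    print(f"{layers:4d} residual updates: h(1) = {approx:.6f}, error = {abs(approx - exact):.1e}")
```

With θ = -0.5, the single residual update gives 0.5, while the exact value is $e^{-0.5} \approx 0.6065$. The error shrinks roughly in proportion to the step size, as expected for Euler's method.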

My usual approach is to go through a numerical example with an input, but in this case the theory differs quite a bit from the practical procedure. The theory is quite complex and initially unintuitive (at least compared to the other models), so I will first work through it without an example.

Reverse Mode Automatic Differentiation: A differential approach

We have formulated a differential equation and used it to parametrise our model, but that was just the forward pass. As in deep learning in general, backpropagation is the main, and often the most complex, part. The gradients of our loss could be obtained the traditional way, by backpropagating through every operation of the ODE solver. However, this has a high memory cost (every intermediate step must be stored) and introduces additional numerical error. The authors of the paper came up with a different approach: the adjoint sensitivity method. This method is not new; it was first introduced in 1962. It has many similarities with backpropagation and differs only slightly. There are a few different perspectives for understanding what we are up to here. I will use the most straightforward one, since the others introduce terms which, although useful, are not relevant here.

The adjoint sensitivity method works by introducing an adjoint state, which acts as an intermediary between the Loss and other parameters we would want the gradients of. Let’s say our hidden state, parametrised by time, is h(t). Our adjoint state is simply the gradient of the loss with respect to this hidden state.

$$a(t) = \frac{dL}{dh(t)}$$

The adjoint state can also be seen as a Lagrange multiplier. Why that is so is not necessary in this context: it would lead us into the whole field of constrained optimization and wouldn't help us going forward. I recommend going through this video if you are interested. Moving on, our main concern is how the adjoint state changes over time, since that is what lets us solve for the gradients we are looking for. To find it, we simply need to remember the chain rule. Recall the chain rule for a discrete forward pass, and for a continuous one:

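Discrete:
$$\frac{dL}{dh_t} = \frac{dL}{dh_{t+e}} \cdots \frac{dh_{t+3}}{dh_{t+2}}\,\frac{dh_{t+2}}{dh_{t+1}}\,\frac{dh_{t+1}}{dh_t}$$
Continuous:
$$\frac{dL}{dh(t)} = \frac{dL}{dh(t+e)}\,\frac{dh(t+e)}{dh(t)}$$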

A key point to remember is that the adjoint state can be defined for any quantity parametrised by time; here we just happen to be looking at the hidden state. Let's look at how the hidden state changes over a short time e, that is, from t to t+e. We simply add the integrated differential to the starting value (τ is the integration variable):

$$h(t+e) = h(t) + \int_{t}^{t+e} f(h(\tau), \tau, \theta)\,d\tau = T_e(h(t))$$

Let us call it $T_e(h(t))$ for ease of notation. We now plug this into our previous continuous chain rule, switching to the partial derivative sign since the equation is multivariate:

$$\frac{dL}{dh(t)} = \frac{dL}{dh(t+e)}\,\frac{\partial h(t+e)}{\partial h(t)} = a(t+e)\,\frac{\partial T_e(h(t))}{\partial h(t)} \quad\Longrightarrow\quad a(t) = a(t+e)\,\frac{\partial T_e(h(t))}{\partial h(t)}$$
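The chain rule $a(t) = a(t+e)\,\partial T_e(h(t))/\partial h(t)$ can be applied literally once the forward pass is discretised. Approximate $T_e$ by a single Euler step, $h \mapsto h + e\,f(h)$, and walk backwards through the stored states, multiplying by $\partial T_e/\partial h = 1 + e\,\partial f/\partial h$ at each step. The sketch below does this for a toy scalar dynamics $f(h) = \sin(h)$ with loss $L = h(t_1)$ (both assumptions for illustration) and checks the result against finite differences. Note that it must keep every intermediate state in memory. This is the "traditional" backpropagation through the solver that the adjoint method avoids.

```python
import math

def f(h):
    return math.sin(h)  # toy dynamics (assumption): dh/dt = sin(h)

def df_dh(h):
    return math.cos(h)

def euler_forward(h0, t1, num_steps):
    """Discretised forward pass with T_e(h) = h + e * f(h); keeps every intermediate state."""
    e = t1 / num_steps
    states = [h0]
    for _ in range(num_steps):
        h = states[-1]
        states.append(h + e * f(h))
    return states

def reverse_accumulate(states, t1, dL_dhN):
    """Apply a(t) = a(t + e) * dT_e/dh(t) from the last state back to the first."""
    e = t1 / (len(states) - 1)
    a = dL_dhN
    for h in reversed(states[:-1]):
        a = a * (1.0 + e * df_dh(h))  # dT_e/dh = 1 + e * df/dh, evaluated at h(t)
    return a  # a(t_0) = dL/dh(t_0)

h0, t1, n = 0.5, 2.0, 200
states = euler_forward(h0, t1, n)
grad = reverse_accumulate(states, t1, dL_dhN=1.0)  # loss L = h(t_1), so a(t_1) = 1

# Finite-difference check on the same discretised forward pass
delta = 1e-6
fd = (euler_forward(h0 + delta, t1, n)[-1] - euler_forward(h0 - delta, t1, n)[-1]) / (2 * delta)
print(f"reverse accumulation: {grad:.8f}, finite difference: {fd:.8f}")
print(f"states kept in memory: {len(states)}")
```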

Our final goal is to find out how this adjoint state changes over time, that is:

$$\frac{da(t)}{dt}$$

We need this quantity so that we can sum (integrate) over it to find the gradients of the loss with respect to the underlying function (we will see that later). Next, we solve for the derivative using limits.

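Taking the limit:
$$\frac{da(t)}{dt} = \lim_{e \to 0} \frac{a(t+e) - a(t)}{e}$$
Substituting a(t):
$$= \lim_{e \to 0} \frac{a(t+e) - a(t+e)\,\frac{\partial T_e(h(t))}{\partial h(t)}}{e}$$
Substituting $T_e(h(t))$ with its Taylor series expansion around h(t), where $O(e^2)$ denotes the terms multiplied by higher powers of e:
$$= \lim_{e \to 0} \frac{a(t+e) - a(t+e)\,\frac{\partial}{\partial h(t)}\left(h(t) + e\,f(h(t), t, \theta) + O(e^2)\right)}{e}$$
Unpacking the partial derivatives, where I denotes the identity matrix:
$$= \lim_{e \to 0} \frac{a(t+e) - a(t+e)\left(I + e\,\frac{\partial f(h(t), t, \theta)}{\partial h(t)} + O(e^2)\right)}{e}$$
Opening up the brackets:
$$= \lim_{e \to 0} \frac{a(t+e) - a(t+e) - e\,a(t+e)\,\frac{\partial f(h(t), t, \theta)}{\partial h(t)} + O(e^2)}{e}$$
Cancellation of a(t+e):
$$= \lim_{e \to 0} \frac{-e\,a(t+e)\,\frac{\partial f(h(t), t, \theta)}{\partial h(t)} + O(e^2)}{e}$$
Splitting the fraction:
$$= \lim_{e \to 0} \left(\frac{-e\,a(t+e)\,\frac{\partial f(h(t), t, \theta)}{\partial h(t)}}{e} + \frac{O(e^2)}{e}\right)$$
Cancellation of the denominator:
$$= \lim_{e \to 0} \left(-a(t+e)\,\frac{\partial f(h(t), t, \theta)}{\partial h(t)} + O(e)\right)$$
Taking the limit:
$$\frac{da(t)}{dt} = -a(t)\,\frac{\partial f(h(t), t, \theta)}{\partial h(t)}$$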

With this straightforward proof, we have the equation we need to calculate the gradient of the loss with respect to a time-dependent quantity. Recall the original definition of a(t), where t can be any point in time. Suppose $a(t_0)$ is the desired gradient:

$$a(t_0) = \frac{dL}{dh(t_0)}$$

We solve for it just as we did during the forward pass, by adding the integral of the differential to a known value. Except this time, we start from the output value, $a(t_N)$, and integrate not from $t_0$ to $t_N$ but backwards, from $t_N$ to $t_0$. This can be seen as integrating backwards in time.

$$a(t_0) = a(t_N) + \int_{t_N}^{t_0} \frac{da(t)}{dt}\,dt \quad\Longrightarrow\quad a(t_0) = a(t_N) - \int_{t_N}^{t_0} a(t)\,\frac{\partial f(h(t), t, \theta)}{\partial h(t)}\,dt$$
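Below is a minimal Python sketch of this backward integration. It again uses the toy dynamics $dh/dt = \sin(h)$ with loss $L = h(t_N)$, both assumptions for illustration. Rather than storing h(t) from the forward pass, it recomputes h(t) backwards alongside a(t) by integrating the pair [h, a] from $t_N$ to $t_0$. That is the approach taken in the Neural ODE paper. The solver is a hand-written fixed-step RK4 (an assumption; any ODE solver would do). For a one-dimensional autonomous ODE, $dh(t_N)/dh(t_0) = f(h(t_N))/f(h(t_0))$, which gives an exact value to compare against.

```python
import math

def rk4_step(fun, t, state, dt):
    """One classical Runge-Kutta (RK4) step for d(state)/dt = fun(t, state); state is a tuple."""
    def shift(s, k, c):
        return tuple(si + c * ki for si, ki in zip(s, k))
    k1 = fun(t, state)
    k2 = fun(t + dt / 2, shift(state, k1, dt / 2))
    k3 = fun(t + dt / 2, shift(state, k2, dt / 2))
    k4 = fun(t + dt, shift(state, k3, dt))
    return tuple(s + dt / 6 * (p + 2 * q + 2 * r + w)
                 for s, p, q, r, w in zip(state, k1, k2, k3, k4))

def integrate(fun, state, t_start, t_end, num_steps=1000):
    """Fixed-step RK4 from t_start to t_end; if t_end < t_start, dt is negative (backwards in time)."""
    dt = (t_end - t_start) / num_steps
    t = t_start
    for _ in range(num_steps):
        state = rk4_step(fun, t, state, dt)
        t += dt
    return state

# Toy dynamics (assumption): f(h) = sin(h), so df/dh = cos(h); loss L = h(t_N)
def forward_rhs(t, s):
    return (math.sin(s[0]),)

def adjoint_rhs(t, s):
    h, a = s
    return (math.sin(h),        # dh/dt = f(h): recompute h(t) on the way back
            -a * math.cos(h))   # da/dt = -a(t) * df/dh

h0, t_N = 0.5, 2.0
(hN,) = integrate(forward_rhs, (h0,), 0.0, t_N)
h0_rec, a0 = integrate(adjoint_rhs, (hN, 1.0), t_N, 0.0)  # start from a(t_N) = dL/dh(t_N) = 1

exact = math.sin(hN) / math.sin(h0)  # dh(t_N)/dh(t_0) for a 1-D autonomous ODE
print(f"a(t_0) = {a0:.8f}, exact = {exact:.8f}, recovered h(t_0) = {h0_rec:.8f}")
```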

It is very important to understand that this method and equation let us calculate the gradient of the loss with respect to a quantity that evolves through time. Under certain conditions, we can put in any function (parametrised by time), get that function's adjoint state and calculate the gradients backward in time. There are essentially three variables: the state h(t), our parameters θ, and time t itself (through the initial and end times $t_0$ and $t_N$). Hence our three variables have three differential equations. The parameters do not change during the forward pass, and time differentiated with respect to itself gives 1, so:

$$\frac{\partial \theta(t)}{\partial t} = 0, \qquad \frac{\partial t(t)}{\partial t} = 1$$

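We can define an adjoint state for each variable, each with its own integral. Since the time gradients aren't required in the model, I won't go into their details. Suppose we require the gradient of the loss with respect to θ at the initial time $t_0$, that is, $a_\theta(t_0)$. We follow the above equations and use the fact that this adjoint starts at zero at the final time, $a_\theta(t_N) = 0$: no part of the trajectory remains after $t_N$ through which θ could affect the loss. This gives:

$$a_\theta(t_0) = a_\theta(t_N) - \int_{t_N}^{t_0} a(t)\,\frac{\partial f(h(t), t, \theta)}{\partial \theta}\,dt \quad\Longrightarrow\quad a_\theta(t_0) = -\int_{t_N}^{t_0} a(t)\,\frac{\partial f(h(t), t, \theta)}{\partial \theta}\,dt$$

or in other words:

$$\frac{dL}{d\theta} = -\int_{t_N}^{t_0} a(t)\,\frac{\partial f(h(t), t, \theta)}{\partial \theta}\,dt$$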

With this, we have officially backpropagated through the entire network! The proofs and derivations may have been straightforward (and hopefully intuitive), but it is still hard to see how we implement this in practice. There are more nuances to how it is actually done. So we shall go over a numerical example and precisely follow the forward and backward propagation to get a complete and thorough understanding of the concept.
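Before the full numerical example, here is a small Python sketch of the parameter gradient in action. The same backward solve can carry $a_\theta$ along with h and a. It assumes the toy dynamics $f(h, t, \theta) = \theta h$ with loss $L = h(t_N)$, $t_0 = 0$ and $t_N = T$. These are chosen because $dL/d\theta = T\,h(0)\,e^{\theta T}$ is then known in closed form. Here $\partial f/\partial h = \theta$, $\partial f/\partial\theta = h$, a(t) is the adjoint of the state, and $a_\theta$ starts from zero at $t_N$. The RK4 helpers are hand-written, not a library API.

```python
import math

def rk4_step(fun, t, state, dt):
    """One classical Runge-Kutta (RK4) step for d(state)/dt = fun(t, state); state is a tuple."""
    def shift(s, k, c):
        return tuple(si + c * ki for si, ki in zip(s, k))
    k1 = fun(t, state)
    k2 = fun(t + dt / 2, shift(state, k1, dt / 2))
    k3 = fun(t + dt / 2, shift(state, k2, dt / 2))
    k4 = fun(t + dt, shift(state, k3, dt))
    return tuple(s + dt / 6 * (p + 2 * q + 2 * r + w)
                 for s, p, q, r, w in zip(state, k1, k2, k3, k4))

def integrate(fun, state, t_start, t_end, num_steps=1000):
    """Fixed-step RK4; a negative dt (t_end < t_start) integrates backwards in time."""
    dt = (t_end - t_start) / num_steps
    t = t_start
    for _ in range(num_steps):
        state = rk4_step(fun, t, state, dt)
        t += dt
    return state

theta, h0, T = 0.7, 1.5, 1.0  # toy values (assumptions), with t_0 = 0 and t_N = T

def forward_rhs(t, s):
    return (theta * s[0],)  # f(h, t, theta) = theta * h

def augmented_rhs(t, s):
    h, a, a_theta = s
    return (theta * h,    # recompute h(t) backwards
            -a * theta,   # da/dt       = -a * df/dh,     with df/dh = theta
            -a * h)       # da_theta/dt = -a * df/dtheta, with df/dtheta = h

(hT,) = integrate(forward_rhs, (h0,), 0.0, T)
_, dL_dh0, dL_dtheta = integrate(augmented_rhs, (hT, 1.0, 0.0), T, 0.0)  # a(T) = 1, a_theta(T) = 0

print(f"dL/dtheta = {dL_dtheta:.8f}, closed form = {T * h0 * math.exp(theta * T):.8f}")
print(f"dL/dh(0)  = {dL_dh0:.8f}, closed form = {math.exp(theta * T):.8f}")
```

Integrating $-a\,\partial f/\partial\theta$ from T down to 0 yields $+\int_0^T a\,h\,dt$. This is why the minus sign in front of $\int_{t_N}^{t_0}$ produces a positive gradient here.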

Solving a Neural-ODE numerically

The equations above help us understand how a neural network can be modelled with differential equations, but they may not be enough to build intuition for the entire pipeline. That pipeline covers how an input flows through the model, and how we then backpropagate through the network to arrive at the gradients. To build a stronger foundation in Neural ODEs, I will go through a numerical example. The backpropagation method was derived in the last section, but the practical implementation is handled a little differently: we built the theory in the last section, and here we apply it. Let us begin by defining our input. Let's call it y0, an n-dimensional (row) vector.

$$y_0 = \begin{bmatrix} y_1 & \cdots & y_n \end{bmatrix}_{1\times n}$$

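Now we must define a function f, which parametrises the change in our input. In this case (and also to keep the calculations simple), I will simply use a dense layer:

$$f(y(t), t, \theta) = \sigma\left(y(t)\,\mathbf{W}_{n\times n} + \mathbf{B}_{1\times n}\right), \quad \text{where } \theta = \{\mathbf{W}_{n\times n}, \mathbf{B}_{1\times n}\} \text{ and } \sigma = \text{ReLU}$$

$$\mathbf{W}_{n\times n} = \begin{bmatrix} w_{11} & \cdots & w_{1n} \\ w_{21} & \cdots & w_{2n} \\ \vdots & \ddots & \vdots \\ w_{n1} & \cdots & w_{nn} \end{bmatrix}, \qquad \mathbf{B}_{1\times n} = \begin{bmatrix} b_1 & b_2 & b_3 & \cdots & b_n \end{bmatrix}$$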

In the above equation, t simply represents time. We evaluate the solution at T+1 time points (0-indexed), stored in a vector of equidistant points starting from 0. Neural ODEs are well suited to modelling sequential data (data that evolves through time). The input and the output have the same dimensions because f gives the rate of change of the state, so it must have the same shape as the state itself.

$$t = \begin{bmatrix} t_0 & t_1 & t_2 & \cdots & t_T \end{bmatrix}_{1\times(T+1)}$$

Our next step is to get to the output, which we will call y1. Recall from the previous sections that we are going through T time steps in a continuous manner. So we integrate f from 0 to T, starting from the input y0, and add the result back to y0:

$$y_1 = y_0 + \int_{t_0}^{t_T} f(y(t), t, \theta)\,dt = y_0 + \int_{0}^{T} \sigma\left(y(t)\,\mathbf{W}_{n\times n} + \mathbf{B}_{1\times n}\right)dt$$

Any numerical ODE solver can be used to evaluate this integral, such as a Runge-Kutta method or Euler's method. The choice is left to the programmer.
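Here is a minimal NumPy sketch of this forward pass. It assumes a hand-written fixed-step solver applied between consecutive points of the time grid t, with classical RK4 by default and Euler's method as the alternative. The helper names `dense_relu_dynamics` and `solve_fixed_step` are hypothetical, and the random weights are placeholders. Row vectors are stored as 1-D arrays, so `y @ W` is the product yW.

```python
import numpy as np

def dense_relu_dynamics(y, t, W, B):
    """f(y, t, theta) = ReLU(y W + B) with theta = {W, B}; t is unused by this f."""
    return np.maximum(y @ W + B, 0.0)

def euler_step(f, y, t, dt, W, B):
    return y + dt * f(y, t, W, B)

def rk4_step(f, y, t, dt, W, B):
    k1 = f(y, t, W, B)
    k2 = f(y + dt / 2 * k1, t + dt / 2, W, B)
    k3 = f(y + dt / 2 * k2, t + dt / 2, W, B)
    k4 = f(y + dt * k3, t + dt, W, B)
    return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def solve_fixed_step(f, y0, t_grid, W, B, step=rk4_step, substeps=10):
    """Integrate dy/dt = f from t_grid[0] to t_grid[-1]; returns y at every grid point."""
    ys = [y0]
    y = y0
    for t_a, t_b in zip(t_grid[:-1], t_grid[1:]):
        dt = (t_b - t_a) / substeps
        for k in range(substeps):
            y = step(f, y, t_a + k * dt, dt, W, B)
        ys.append(y)
    return np.stack(ys)  # shape (T + 1, n): the state keeps dimension n throughout

rng = np.random.default_rng(0)
n, T = 4, 5
y0 = rng.normal(size=n)
W = 0.1 * rng.normal(size=(n, n))
B = 0.1 * rng.normal(size=n)
t = np.linspace(0.0, T, T + 1)  # equidistant grid t_0 = 0, ..., t_T = T

traj_rk4 = solve_fixed_step(dense_relu_dynamics, y0, t, W, B)
traj_euler = solve_fixed_step(dense_relu_dynamics, y0, t, W, B, step=euler_step)
y1 = traj_rk4[-1]
print("y1 (RK4):  ", y1)
print("y1 (Euler):", traj_euler[-1])
```

A quick sanity check: with W = 0, the dynamics reduce to the constant ReLU(B), so both solvers must return exactly y0 + T·ReLU(B).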

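Once we have our output, the next logical step is to define and use a loss function. The choice of losses is vast, but for simplicity I will use one of the most basic: the mean-squared error (MSE). Let L be our loss function and true be our ground-truth vector (label).

$$\text{true} = \begin{bmatrix} tr_1 & tr_2 & \cdots & tr_n \end{bmatrix}_{1\times n}, \qquad L(y_1) = \frac{1}{n}\left\lVert \text{true} - y_1 \right\rVert^2 = \frac{1}{n}\left\lVert \text{true} - \left(y_0 + \int_{0}^{T} \sigma\left(y(t)\,\mathbf{W}_{n\times n} + \mathbf{B}_{1\times n}\right)dt\right)\right\rVert^2$$

Now we need the gradient of the loss with respect to the output. Thanks to the simple form of the mean-squared error, it is straightforward to calculate:

$$\frac{dL}{dy_1} = \frac{d}{dy_1}\,\frac{1}{n}\left\lVert \text{true} - y_1 \right\rVert^2 = \frac{2}{n}\left(y_1 - \text{true}\right)$$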
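A finite-difference check is a quick way to confirm this gradient, including its sign. Below is a minimal NumPy sketch. It assumes $L = \frac{1}{n}\sum_i (tr_i - y_{1,i})^2$, the mean over the n entries, and uses hypothetical values for y1 and true.

```python
import numpy as np

def mse(y1, true):
    """Mean-squared error L(y1) = (1/n) * sum_i (true_i - y1_i)**2."""
    return np.mean((true - y1) ** 2)

def mse_grad(y1, true):
    """dL/dy1 = (2/n) * (y1 - true): positive where y1 overshoots the target."""
    return 2.0 * (y1 - true) / y1.size

def finite_difference_grad(fn, x, eps=1e-6):
    """Central finite differences of a scalar function fn, one entry of x at a time."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        dx = np.zeros_like(x)
        dx[i] = eps
        grad[i] = (fn(x + dx) - fn(x - dx)) / (2 * eps)
    return grad

y1 = np.array([0.5, 1.5, -1.0])   # hypothetical model output
true = np.array([1.0, 1.0, 0.0])  # hypothetical label
analytic = mse_grad(y1, true)
numeric = finite_difference_grad(lambda v: mse(v, true), y1)
print(analytic)  # approximately [-0.333, 0.333, -0.667]
print(numeric)
```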

The next step is just as we discussed in the previous section: we backpropagate by integrating backwards in time. Remember that our loss essentially depends on three inputs. These are the initial input y0, the time vector t, and the function f (through its parameters θ), which defines the dynamics our input goes through. (In code, the function passed in can be a class.) This means we can calculate the gradient of the loss with respect to each of these inputs, so we have to define an adjoint state for each one of them. Once we have an adjoint state, we find its derivative with respect to time and integrate it to get our final gradients (as was done in the theoretical section). To simplify the calculation and avoid calling the ODE solver again and again, we define an augmented state. This is just a vector made up of the variables whose gradients we want. We also define an augmented adjoint state, which contains the adjoint dynamics of all these variables. A single backward ODE solve then produces every gradient at once.

$$\text{augmented state: } \begin{bmatrix} \mathbf{y} \\ \boldsymbol{\theta} \\ t \end{bmatrix}$$

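How this state vector changes through time is governed by:

$$\frac{d}{dt}\begin{bmatrix} \mathbf{y}(t) \\ \boldsymbol{\theta}(t) \\ t(t) \end{bmatrix} = f_{\text{aug}}\left(\begin{bmatrix} \mathbf{y}(t) \\ \boldsymbol{\theta}(t) \\ t(t) \end{bmatrix}\right) = \begin{bmatrix} f(\mathbf{y}(t), t, \theta) \\ \mathbf{0} \\ 1 \end{bmatrix}$$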

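The 0 and 1 entries appear because the parameters θ do not change during a forward pass, while time differentiated with respect to itself gives 1. Following this result, we define an adjoint vector with the same structure:

$$\mathbf{a}_{\text{aug}} = \begin{bmatrix} \mathbf{a}_y & \mathbf{a}_\theta & \mathbf{a}_t \end{bmatrix} = \begin{bmatrix} \frac{dL}{dy(t)} & \frac{dL}{d\theta(t)} & \frac{dL}{dt(t)} \end{bmatrix}$$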

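Another important quantity is the Jacobian of $f_{\text{aug}}$ with respect to our augmented state. Only its first row is nonzero, because the dynamics of θ and t are constant:

$$\frac{\partial f_{\text{aug}}}{\partial [\mathbf{y}, \boldsymbol{\theta}, t]} = \begin{bmatrix} \frac{\partial f}{\partial \mathbf{y}} & \frac{\partial f}{\partial \boldsymbol{\theta}} & \frac{\partial f}{\partial t} \\ \mathbf{0} & \mathbf{0} & \mathbf{0} \\ \mathbf{0} & \mathbf{0} & \mathbf{0} \end{bmatrix}$$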

Going further, we simply get the derivative of the augmented adjoint state:

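From
$$\frac{da(t)}{dt} = -a(t)\,\frac{\partial f(y(t), t, \theta)}{\partial y(t)}$$
we get
$$\frac{d\mathbf{a}_{\text{aug}}(t)}{dt} = -\begin{bmatrix} \mathbf{a}_y & \mathbf{a}_\theta & \mathbf{a}_t \end{bmatrix}\begin{bmatrix} \frac{\partial f}{\partial \mathbf{y}} & \frac{\partial f}{\partial \boldsymbol{\theta}} & \frac{\partial f}{\partial t} \\ \mathbf{0} & \mathbf{0} & \mathbf{0} \\ \mathbf{0} & \mathbf{0} & \mathbf{0} \end{bmatrix} = -\begin{bmatrix} \mathbf{a}_y\frac{\partial f}{\partial \mathbf{y}} & \mathbf{a}_y\frac{\partial f}{\partial \boldsymbol{\theta}} & \mathbf{a}_y\frac{\partial f}{\partial t} \end{bmatrix}(t)$$
Since only the first row of the Jacobian is nonzero, every component is multiplied by $\mathbf{a}_y$. The adjoint dynamics of θ and t are both driven by the adjoint of the state.

We derived the gradient of the loss with respect to the output, dL/dy1, because it is the starting value $\mathbf{a}_y(T)$ of the augmented adjoint state when we integrate backwards. Integrating $\mathbf{a}_y$ all the way back gives dL/dy0, although this value may not be used.

The first component of the augmented adjoint gives the gradient with respect to the hidden states, and the second gives the gradient with respect to our parameters. The final one gives the gradient with respect to the time vector. We do not need the final values of the first and last components (dL/dy0 and dL/dt), so we can discard them at the end. Note, however, that $\mathbf{a}_y$ must still be integrated alongside $\mathbf{a}_\theta$, because the dynamics of $\mathbf{a}_\theta$ depend on it. Let's focus on the second term, which describes the dynamics of the adjoint state of our parameters. We obtain the final gradient for the parameters by integrating backwards in time, from T to 0, starting from $\mathbf{a}_\theta(T) = 0$ (here $t_N = T$ and $t_0 = 0$):

$$\frac{dL}{d\theta} = -\int_{t_N}^{t_0} a(t)\,\frac{\partial f(h(t), t, \theta)}{\partial \theta}\,dt = -\int_{T}^{0} \mathbf{a}_y(t)\,\frac{\partial f}{\partial \boldsymbol{\theta}}\,dt = \int_{0}^{T} \frac{dL}{dy(t)}\,\frac{\partial \sigma\left(y(t)\,\mathbf{W}_{n\times n} + \mathbf{B}_{1\times n}\right)}{\partial [\mathbf{W}, \mathbf{B}]}\,dt$$

Write $z(t) = y(t)\mathbf{W} + \mathbf{B}$ and let $\sigma'(z) = \mathbb{1}[z > 0]$ be the elementwise ReLU derivative. The two parts of this integral are then:

$$\frac{dL}{d\mathbf{W}} = \int_{0}^{T} y(t)^{\top}\left(\mathbf{a}_y(t) \odot \sigma'(z(t))\right)dt, \qquad \frac{dL}{d\mathbf{B}} = \int_{0}^{T} \mathbf{a}_y(t) \odot \sigma'(z(t))\,dt$$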
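For the dense ReLU layer, the three products in the augmented adjoint dynamics can be written explicitly. With $z = y\mathbf{W} + \mathbf{B}$ and $\sigma'(z) = \mathbb{1}[z > 0]$ (taking the ReLU derivative at 0 as 0, a convention assumed here), they are:

- $\mathbf{a}_y\,\partial f/\partial y = (\mathbf{a}_y \odot \sigma'(z))\,\mathbf{W}^\top$
- $\mathbf{a}_y\,\partial f/\partial \mathbf{W} = y^\top(\mathbf{a}_y \odot \sigma'(z))$
- $\mathbf{a}_y\,\partial f/\partial \mathbf{B} = \mathbf{a}_y \odot \sigma'(z)$

The NumPy sketch below builds the augmented right-hand side from these products and checks each one against finite differences. The helper names and values are hypothetical. Since f here does not depend on t explicitly, the $\mathbf{a}_t$ component is left out.

```python
import numpy as np

def f(y, W, B):
    """Dense ReLU dynamics f(y, t, theta) = ReLU(y W + B); no explicit t dependence."""
    return np.maximum(y @ W + B, 0.0)

def vjps(y, a_y, W, B):
    """Return a_y df/dy, a_y df/dW and a_y df/dB for f = ReLU(y W + B)."""
    g = a_y * (y @ W + B > 0)  # a_y times ReLU'(z) elementwise, with ReLU'(0) taken as 0
    return g @ W.T, np.outer(y, g), g

def augmented_rhs(y, a_y, W, B):
    """dy/dt together with d[a_y, a_W, a_B]/dt = -[a_y df/dy, a_y df/dW, a_y df/dB]."""
    a_fy, a_fW, a_fB = vjps(y, a_y, W, B)
    return f(y, W, B), -a_fy, -a_fW, -a_fB

def fd_vjp(fun, x, a, eps=1e-6):
    """Central finite-difference estimate of a . d fun(x)/dx, entry by entry."""
    out = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        dx = np.zeros_like(x)
        dx[idx] = eps
        out[idx] = (a @ fun(x + dx) - a @ fun(x - dx)) / (2 * eps)
    return out

# Hypothetical values; no pre-activation z = yW + B sits at the ReLU kink
y = np.array([1.0, -0.5, 2.0])
a_y = np.array([0.3, -0.7, 1.1])
W = np.array([[0.5, -1.0, 0.2],
              [0.3, 0.4, -0.6],
              [-0.2, 0.1, 0.5]])
B = np.array([0.1, 0.2, -0.3])

a_fy, a_fW, a_fB = vjps(y, a_y, W, B)
checks = [np.allclose(a_fy, fd_vjp(lambda v: f(v, W, B), y, a_y)),
          np.allclose(a_fW, fd_vjp(lambda M: f(y, M, B), W, a_y)),
          np.allclose(a_fB, fd_vjp(lambda c: f(y, W, c), B, a_y))]
print("finite-difference checks passed:", checks)
```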

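With this, we have finally found the gradient with respect to the parameters! Since there were a lot of equations, the flow of inputs and gradients may feel a little foggy. Here is the entire process, step by step:

1. Forward pass: starting from y0, solve $dy/dt = f(y(t), t, \theta)$ from t = 0 to t = T with an ODE solver to obtain the output y1.
2. Loss: compute L(y1) and its gradient dL/dy1.
3. Initialise the augmented adjoint system at t = T: the state y(T) = y1, $\mathbf{a}_y(T) = dL/dy_1$ and $\mathbf{a}_\theta(T) = 0$.
4. Backward pass: integrate the augmented dynamics from T back to 0 in a single ODE solve, recomputing y(t) alongside the adjoints.
5. Read off the parameter gradient $dL/d\theta = \mathbf{a}_\theta(0)$ (and $dL/dy_0 = \mathbf{a}_y(0)$, if needed).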
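Here is the whole pipeline as one minimal, self-contained NumPy sketch. It has four pieces:

- a fixed-step RK4 solver (an assumption; any ODE solver would do),
- the dense ReLU dynamics and the MSE loss,
- a single backward solve of the augmented system [y, a_y, a_W, a_B] from T to 0, starting from a_y(T) = dL/dy1 and zero parameter adjoints,
- a finite-difference check of the parameter gradients against the forward pass.

The numbers are hypothetical. They keep y positive, and make column 1 of W and B[1] negative, so every pre-activation keeps the same sign along the trajectory. This keeps the finite-difference check clear of the ReLU kink. The helper names are illustrative and not taken from any library.

```python
import numpy as np

def f(y, W, B):
    """Dense-layer dynamics from the worked example: f(y, t, theta) = ReLU(y W + B)."""
    return np.maximum(y @ W + B, 0.0)

def aug_rhs(state, W, B):
    """Augmented dynamics for [y, a_y, a_W, a_B]; every adjoint row is driven by a_y."""
    y, a_y, _, _ = state
    g = a_y * (y @ W + B > 0)  # a_y times ReLU'(z), elementwise
    return [f(y, W, B), -(g @ W.T), -np.outer(y, g), -g]

def rk4(rhs, state, t_start, t_end, steps=200):
    """Fixed-step RK4 on a list of arrays; t_end < t_start integrates backwards in time."""
    dt = (t_end - t_start) / steps
    for _ in range(steps):
        k1 = rhs(state)
        k2 = rhs([s + dt / 2 * k for s, k in zip(state, k1)])
        k3 = rhs([s + dt / 2 * k for s, k in zip(state, k2)])
        k4 = rhs([s + dt * k for s, k in zip(state, k3)])
        state = [s + dt / 6 * (p + 2 * q + 2 * r + w)
                 for s, p, q, r, w in zip(state, k1, k2, k3, k4)]
    return state

def forward(y0, W, B, T=1.0):
    """Forward pass: y1 = y0 + integral of f from 0 to T."""
    return rk4(lambda s: [f(s[0], W, B)], [y0], 0.0, T)[0]

def loss(y1, true):
    """Mean-squared error."""
    return np.mean((true - y1) ** 2)

def backward(y1, true, W, B, T=1.0):
    """One backward solve of the augmented system from T to 0."""
    a_y_T = 2.0 * (y1 - true) / y1.size                      # a_y(T) = dL/dy1
    start = [y1, a_y_T, np.zeros_like(W), np.zeros_like(B)]  # a_W(T) = a_B(T) = 0
    y0_rec, dL_dy0, dL_dW, dL_dB = rk4(lambda s: aug_rhs(s, W, B), start, T, 0.0)
    return y0_rec, dL_dy0, dL_dW, dL_dB

# Hypothetical values: y stays positive, column 1 of W and B[1] are negative,
# so z_0, z_2 > 0 and z_1 < 0 along the whole trajectory.
y0 = np.array([0.5, 1.0, 0.2])
W = np.array([[0.3, -0.2, 0.1],
              [0.1, -0.4, 0.2],
              [0.2, -0.1, 0.3]])
B = np.array([0.1, -0.2, 0.05])
true = np.array([1.0, 0.5, 1.0])

y1 = forward(y0, W, B)
y0_rec, dL_dy0, dL_dW, dL_dB = backward(y1, true, W, B)

def fd_grad(fn, x, eps=1e-6):
    """Central finite differences of the scalar function fn at every entry of x."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        dx = np.zeros_like(x)
        dx[idx] = eps
        grad[idx] = (fn(x + dx) - fn(x - dx)) / (2 * eps)
    return grad

fd_W = fd_grad(lambda M: loss(forward(y0, M, B), true), W)
fd_B = fd_grad(lambda c: loss(forward(y0, W, c), true), B)
print("loss:", loss(y1, true))
print("max |adjoint - finite difference| for W:", np.abs(dL_dW - fd_W).max())
print("max |adjoint - finite difference| for B:", np.abs(dL_dB - fd_B).max())
```

The backward solve never stores the forward trajectory. It recovers y(t) by integrating the state dynamics in reverse, which is why `y0_rec` should match `y0` closely.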

Conclusion

Neural ODEs are closely related to, and sometimes loosely conflated with, liquid neural networks, which also describe hidden states with differential equations. This method offers a different approach to problems that neural networks are already used for. It was quite popular when it came out in 2018 (it received a Best Paper Award at NeurIPS 2018), yet it has not seen adoption in mainstream models. There may be various reasons for this. One is that these models (like SSMs) are better suited to continuous data, and model data better when the process generating it is continuous in nature. This architecture is also closely related to the broader field of physics-informed neural networks, and I am hoping to explore this topic more! Thank you.
