Suppose a product $w_1 \cdots w_d$ is trained to approach a target value $s$ with gradient descent. How fast will it approach, and how does this depend on the degree $d$?

Setup

More concretely, let’s define a squared loss

$$L := \frac{1}{2}\,(w_1 \cdots w_d - s)^2,$$

and make each “weight” $w_i$ evolve over time according to its gradient

$$\frac{dw_i}{dt} = -\frac{\partial L}{\partial w_i} = (s - w_1 \cdots w_d) \prod_{j \ne i} w_j.$$

This corresponds to what would happen if we were somehow training a linear network of depth $d$ and width 1 on a 1D linear regression task where the ground truth is $f(x) = sx$. But more generally, as shown in Deep linear networks, this also roughly describes the dynamics with which linear networks of any width learn the singular modes of the linear relationship between inputs and outputs.
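To make this concrete, here is a minimal simulation sketch (my own, not part of the setup above beyond the loss itself): it Euler-integrates the gradient flow for a product of $d$ weights. The function name and all numerical values are arbitrary choices.

```python
# Minimal sketch (assumed values): Euler-integrate the gradient flow
# dw_i/dt = (s - w_1...w_d) * prod_{j != i} w_j for L = 0.5*(w_1...w_d - s)^2.
import numpy as np

def gradient_flow(w0, s, dt=1e-3, steps=200_000):
    w = np.array(w0, dtype=float)
    for _ in range(steps):
        p = np.prod(w)
        # dL/dw_i = (p - s) * prod_{j != i} w_j, written without dividing by w_i
        grads = (p - s) * np.array([np.prod(np.delete(w, i)) for i in range(len(w))])
        w -= dt * grads
    return w

print(gradient_flow([0.1, 0.2, 0.3], s=4.0))  # the product should approach s = 4
```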

Degree $d = 2$

When the degree is 2, we have just

$$\begin{cases} \dfrac{dw_1}{dt} = (s - w_1 w_2)\, w_2 \\[4pt] \dfrac{dw_2}{dt} = (s - w_1 w_2)\, w_1. \end{cases}$$

Phase 1: growth

Let’s first think about the early stage, when $w_1 w_2 \le s/2$, where we have

$$\begin{cases} \dfrac{dw_1}{dt} = \Theta(s\, w_2) \\[4pt] \dfrac{dw_2}{dt} = \Theta(s\, w_1). \end{cases}$$

Intuitively, we should expect $w_1$ and $w_2$ to approach each other (at least in relative terms): indeed, suppose that initially $w_1 \gg w_2$; that would cause $w_2$ to grow much faster than $w_1$ until it mostly catches up to it. We describe this process of catching up in more detail in the section How fast do weights equalize? below, but for now let’s just assume that $w_1$ and $w_2$ have caught up, which means there is now just a single variable $w = w_1 = w_2$. The early stage corresponds to $w^2 \le s/2$, and

$$\frac{dw}{dt} = \Theta(s\, w) \quad\implies\quad \frac{d(w^2)/dt}{w^2} = \Theta(s).$$

The relative derivative is constant, so this solves to an exponential $w^2(t) = e^{\Theta(s t)}\, w^2(0)$, which means that it would take time $\Theta\!\left(\frac{\log(s/\epsilon)}{s}\right)$ to go from a small value $w^2 = \epsilon$ to the halfway point $w^2 = s/2$.

Phase 2: approach

In the later stage of learning, when $w^2 \ge s/2$, we have

$$\frac{d(w^2)}{dt} = 2\,(s - w^2)\, w^2,$$

so

$$\frac{d(s - w^2)/dt}{s - w^2} = -2 w^2 = -\Theta(s),$$

which solves to $s - w^2(t) = e^{-\Theta(s)\, t}\,\big(s - w^2(0)\big)$, which means that it would take time $\Theta\!\left(\frac{\log(1/\epsilon)}{s}\right)$ to go from the halfway point $w^2 = s/2$ to $w^2 = s\,(1 - \epsilon)$.
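Here is a rough numerical check of these two timescales (my own sketch, with arbitrary initial value and tolerance): with $w_1 = w_2 = w$, it integrates $dw/dt = (s - w^2)\,w$ and times each phase.

```python
# Up to constants, the time to reach w^2 = s/2 should look like log(s/eps)/s,
# and the time to then reach w^2 = (1 - eps)*s should look like log(1/eps)/s.
def d2_phase_times(s, w2_init=1e-4, eps=1e-3, dt=1e-4):
    w, t, t_half = w2_init ** 0.5, 0.0, None
    while w * w < (1 - eps) * s:
        w += dt * (s - w * w) * w   # Euler step of dw/dt = (s - w^2) * w
        t += dt
        if t_half is None and w * w >= s / 2:
            t_half = t
    return t_half, t - t_half

for s in (1.0, 4.0, 16.0):
    growth, approach = d2_phase_times(s)
    print(f"s={s:5.1f}  growth={growth:.3f}  approach={approach:.3f}")
```

Both reported times should shrink roughly like $1/s$ (times log factors) as $s$ grows.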

Learning rate

#to-write

  • derive the maximum learning rate
  • adapt the conclusion in terms of training steps
  • but think about how this would change things with width > 1 when there are several different objectives with different singular values $s$: make the point that larger singular values are generally favored (which is not the case when $d = 1$)
  • (note that in Deep linear networks we didn’t worry about the learning rate because it was all relative between different singular values; on the other hand you can’t make claims about (absolute) acceleration without worrying about those learning rates)

General case

In general, similar dynamics will force the weights $w_1, \ldots, w_d$ to approach each other in relative terms,1 and we would be left with one variable $w$ with

$$\frac{dw}{dt} = (s - w^d)\, w^{d-1} \quad\implies\quad \frac{d(w^d)}{dt} = d\,(s - w^d)\, w^{2d-2}.$$
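As a quick consistency check (my own sketch, arbitrary values): if all weights start out exactly equal, the full $d$-weight gradient flow stays on the diagonal $w_1 = \cdots = w_d$ and matches this reduced 1-D ODE.

```python
# Compare the full d-weight gradient flow (equal init) with the reduced
# 1-D ODE dw/dt = (s - w^d) * w^(d-1).
import numpy as np

def full_vs_reduced(d=4, s=3.0, w0=0.5, dt=1e-4, steps=50_000):
    w_full = np.full(d, w0)
    w_red = w0
    for _ in range(steps):
        p = np.prod(w_full)
        w_full = w_full + dt * (s - p) * p / w_full   # (s - prod) * prod_{j != i} w_j
        w_red += dt * (s - w_red**d) * w_red**(d - 1)
    return w_full, w_red

print(full_vs_reduced())  # every entry of the full vector should match the reduced w
```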

Phase 1: growth

When $w^d \le s/2$, treating $d$ as a constant, we have

$$\frac{d(w^d)/dt}{w^d} = \Theta\!\left(s\, w^{d-2}\right) = \Theta\!\left(s\,(w^d)^{1 - 2/d}\right).$$

Depending on the sign of the exponent $1 - 2/d$ in this relative derivative, $w^d$ will experience different types of growth:

  • If $d < 2$, then $w^d$ will grow polynomially. In particular, the only relevant case is $d = 1$, which gives $w(t) = w(0) + \Theta(s)\, t$. So going from a small value $w \approx 0$ to the halfway point $w = s/2$ takes $\Theta(1)$ time.
  • If $d = 2$, as seen above, $w^2$ grows exponentially, and going from $\epsilon$ to $s/2$ takes $\Theta\!\left(\frac{\log(s/\epsilon)}{s}\right)$ time.
  • If $d > 2$, then $w^d$ will grow hyperbolically as $w^d(t) = \Theta\!\left(\frac{1}{s\,(t_0 - t)}\right)^{1 + \frac{2}{d-2}}$ (blowing up as $t$ approaches some $t_0$), and going from a small value $w^d = \epsilon$ to the halfway point $w^d = s/2$ takes $\Theta\!\left(\frac{1}{\epsilon^{1 - 2/d}\, s}\right)$ time.

In particular, if $w^d$ is starting from a fairly low value, degrees $d > 2$ will take a long time to take off, and in general the trade-off depends on $s$.
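The sketch below (mine, using the reduced 1-D dynamics and arbitrary values) illustrates this: it measures the time to reach the halfway point for several degrees and several initial sizes.

```python
# Time to go from w^d = eps to the halfway point w^d = s/2.
# For d > 2 this should grow roughly like 1/(eps^(1 - 2/d) * s) as eps shrinks,
# while for d = 1 it stays O(1) and for d = 2 it only grows logarithmically.
def time_to_halfway(d, s, eps, dt=1e-4):
    w, t = eps ** (1.0 / d), 0.0
    while w**d < s / 2:
        w += dt * (s - w**d) * w**(d - 1)
        t += dt
    return t

s = 2.0
for eps in (1e-2, 1e-3, 1e-4):
    print(eps, [round(time_to_halfway(d, s, eps), 2) for d in (1, 2, 3, 4)])
```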

Phase 2: approach

When $w^d \ge s/2$, treating $d$ as a constant, we have

$$\frac{d(s - w^d)/dt}{s - w^d} = -\Theta\!\left(w^{2d-2}\right) = -\Theta\!\left((w^d)^{2 - 2/d}\right) = -\Theta\!\left(s^{2 - 2/d}\right),$$

so in any case the approach is exponential, with

$$s - w^d(t) = e^{-\Theta\left(s^{2 - 2/d}\right)\, t}\,\big(s - w^d(0)\big),$$

which means that going from the halfway point $w^d = s/2$ to $w^d = (1 - \epsilon)\, s$ takes

$$\Theta\!\left(\frac{\log(1/\epsilon)}{s^{2 - 2/d}}\right)$$

time. In this phase, bigger degree is always better.
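A quick numerical check of this phase (my own sketch, arbitrary tolerance): the approach time should be roughly $\log(1/\epsilon)/s^{2-2/d}$, i.e. independent of $s$ for $d = 1$ and increasingly helped by a large $s$ as $d$ grows.

```python
# Time to go from the halfway point w^d = s/2 to w^d = (1 - eps)*s
# under the reduced dynamics dw/dt = (s - w^d) * w^(d-1).
def approach_time(d, s, eps=1e-3, dt=1e-4):
    w, t = (s / 2) ** (1.0 / d), 0.0
    while w**d < (1 - eps) * s:
        w += dt * (s - w**d) * w**(d - 1)
        t += dt
    return t

for d in (1, 2, 3):
    print(d, [round(approach_time(d, s), 3) for s in (1.0, 4.0, 16.0)])
```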

Learning rate

#to-write same as above

Nonlinearities

Now suppose that we’re trying to approach $s$ with a monomial like $w_1^2 w_2$, where some of the weights are squared, and the loss is still the square

$$L := \frac{1}{2}\,(w_1^2 w_2 - s)^2.$$

This is roughly the situation that would arise if you’re trying to approximate the function $f(x) = s x^2$ using a depth-2 network with a quadratic nonlinearity $\phi(h) := h^2$ on its hidden layer.

Then the dynamics are

$$\begin{cases} \dfrac{dw_1}{dt} = 2\,(s - w_1^2 w_2)\, w_1 w_2 \\[4pt] \dfrac{dw_2}{dt} = (s - w_1^2 w_2)\, w_1^2. \end{cases}$$

Again, in the growth phase where $w_1^2 w_2 \ll s$, we have

$$\begin{cases} \dfrac{dw_1}{dt} = \Theta(s\, w_1 w_2) \\[4pt] \dfrac{dw_2}{dt} = \Theta(s\, w_1^2), \end{cases}$$

so again, $w_1$ and $w_2$ will tend to equalize, because if e.g. $w_1 \gg w_2$, then $w_2$ will grow much faster. So we can approximate the dynamics by

$$\frac{dw}{dt} = \Theta\!\left((s - w^3)\, w^2\right).$$

Not too surprisingly, this is exactly the same dynamics as we got for the problem of approximating $s$ by $w_1 w_2 w_3$. So, depending on how a nonlinearity acts, it seems like it could make a neural network act like it has more depth than it actually has!
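Here is a small comparison sketch (mine, with arbitrary values): it runs the gradient flow for the quadratic monomial $w_1^2 w_2$ side by side with the depth-3 product $w_1 w_2 w_3$, from comparable initializations. The two losses should shrink on a similar timescale, since at the $\Theta$ level both reduce to $dw/dt \sim (s - w^3)\,w^2$.

```python
# Euler-integrate both gradient flows and record the losses.
def loss_curve_quadratic(s, w0, dt, steps):
    w1 = w2 = w0
    losses = []
    for _ in range(steps):
        r = s - w1 * w1 * w2
        w1, w2 = w1 + dt * 2 * r * w1 * w2, w2 + dt * r * w1 * w1
        losses.append(r * r / 2)
    return losses

def loss_curve_depth3(s, w0, dt, steps):
    w1 = w2 = w3 = w0
    losses = []
    for _ in range(steps):
        r = s - w1 * w2 * w3
        w1, w2, w3 = w1 + dt * r * w2 * w3, w2 + dt * r * w1 * w3, w3 + dt * r * w1 * w2
        losses.append(r * r / 2)
    return losses

s, w0, dt, steps = 2.0, 0.3, 1e-3, 20_000
lq = loss_curve_quadratic(s, w0, dt, steps)
l3 = loss_curve_depth3(s, w0, dt, steps)
for i in (0, 2_000, 5_000, 10_000):
    print(i, round(lq[i], 4), round(l3[i], 4))
```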

How fast do weights equalize?

Let’s study how fast the ratio $w_1/w_2$ approaches 1 when the weights are subject to the dynamics

$$\begin{cases} \dfrac{dw_1}{dt} = (s - w_1 w_2)\, w_2 \\[4pt] \dfrac{dw_2}{dt} = (s - w_1 w_2)\, w_1. \end{cases}$$

For this, it’s convenient to look at the relative derivative of the ratio $w_1/w_2$ itself, and also to compare it with the relative derivative of the product $w_1 w_2$, since this lets us compare the timescales at which equalizing and learning happen.

We have

$$\frac{d(w_1 w_2)/dt}{w_1 w_2} = \frac{dw_1/dt}{w_1} + \frac{dw_2/dt}{w_2} = (s - w_1 w_2)\left(\frac{w_2}{w_1} + \frac{w_1}{w_2}\right)$$

$$\frac{d(w_1/w_2)/dt}{w_1/w_2} = \frac{dw_1/dt}{w_1} - \frac{dw_2/dt}{w_2} = (s - w_1 w_2)\left(\frac{w_2}{w_1} - \frac{w_1}{w_2}\right),$$

which at the least confirms that if $w_1 > w_2$ then $w_1/w_2$ is decreasing over time, and vice versa. Let’s assume we start out with $w_1 \gg w_2 > 0$; then we have

$$\frac{d(w_1 w_2)/dt}{w_1 w_2} \approx (s - w_1 w_2)\,\frac{w_1}{w_2} \qquad\qquad \frac{d(w_1/w_2)/dt}{w_1/w_2} \approx -(s - w_1 w_2)\,\frac{w_1}{w_2}.$$

That is, $w_1 w_2$ increases at the same relative rate as $w_1/w_2$ decreases, which means that (to a first approximation) we can expect $w_1$ and $w_2$ to have roughly equalized by the time $w_1 w_2$ reaches $s$, as long as we start out with

$$\underbrace{\frac{s}{w_1 w_2}}_{\text{how much } w_1 w_2 \text{ needs to grow}} \;\ge\; \underbrace{\frac{w_1}{w_2}}_{\text{how much } w_1/w_2 \text{ needs to decrease}} \quad\iff\quad w_1 \le \sqrt{s}.$$
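A quick numerical check of this heuristic (my own sketch, arbitrary values): start with $w_1 \gg w_2 > 0$ and look at how close the ratio ends up to 1 once the product has essentially reached $s$; equalization should roughly require $w_1 \le \sqrt{s}$.

```python
# Integrate the d = 2 gradient flow until the residual s - w1*w2 is tiny,
# then report the final ratio w1/w2.
def final_ratio(w1, w2, s, dt=1e-4, max_steps=1_000_000):
    for _ in range(max_steps):
        r = s - w1 * w2
        if abs(r) < 1e-6 * s:
            break
        w1, w2 = w1 + dt * r * w2, w2 + dt * r * w1
    return w1 / w2

s = 100.0  # sqrt(s) = 10
for w1_init in (3.0, 10.0, 30.0):
    print(w1_init, round(final_ratio(w1_init, 1e-3, s), 2))
```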
  1. or at least, it seems like their absolute values would