% Environments
\newcommand{\al}[1]{\begin{align}#1\end{align}} % need this for \tag{} to work
\renewcommand{\r}{\mathrm} % BAD!! does cursed things with accents :((
% Delimiters
% (I needed to create my own because the MathJax version of \DeclarePairedDelimiter doesn't have \mathopen{} and that messes up the spacing)
% .. one-part
\newcommand{\p}[1]{\mathopen{}\left( #1 \right)}
\renewcommand{\b}[1]{\mathopen{}\left[ #1 \right]}
\newcommand{\set}[1]{\mathopen{}\left\{ #1 \right\}}
\newcommand{\abs}[1]{\mathopen{}\left\lvert #1 \right\rvert}
\newcommand{\floor}[1]{\mathopen{}\left\lfloor #1 \right\rfloor}
\newcommand{\ceil}[1]{\mathopen{}\left\lceil #1 \right\rceil}
\newcommand{\inner}[1]{\mathopen{}\left\langle #1 \right\rangle}
\newcommand{\norm}[1]{\mathopen{}\left\lVert #1 \strut \right\rVert}
\newcommand{\mix}[1]{\mathopen{}\left\lfloor #1 \right\rceil}
%% .. two-part
\newcommand{\inco}[2]{#1 \mathop{}\middle|\mathop{} #2}
\newcommand{\co}[2]{ {\left.\inco{#1}{#2}\right.}}
\newcommand{\cond}{\co} % deprecated
\newcommand{\at}[2]{ {\left.#1\strut\right|_{#2}}}
\newcommand{\para}[2]{#1\strut \mathop{}\middle\|\mathop{} #2}
% Greek
% the following cause issues with real LaTeX tho :/ maybe consider naming it \fhi instead?
\let\fi\phi % because it looks like an f
\let\phi\varphi % because it looks like a p
% Miscellaneous
% .. operators
\DeclareMathOperator*{\argmin}{arg\thinspace min}
\DeclareMathOperator*{\argmax}{arg\thinspace max}
% .. functions
% .. analysis
\newcommand{\df}[2]{ {\f{\d #1}{\d #2}}}
\newcommand{\ds}[2]{ {\sl{\d #1}{\d #2}}}
\newcommand{\ddf}[3]{ {\f{\dd{#1} #2}{\p{\d #3}^{#1}}}}
\newcommand{\dds}[3]{ {\sl{\dd{#1} #2}{\p{\d #3}^{#1}}}}
\newcommand{\partf}[2]{\f{\part #1}{\part #2}}
\newcommand{\parts}[2]{\sl{\part #1}{\part #2}}
% .. sets
\newcommand{\Rge}{\R_{\ge 0}}
\newcommand{\Rgt}{\R_{> 0}}
\newcommand{\pmo}{\set{\pm 1}}
\newcommand{\zpmo}{\set{0,\pm 1}}
% .... set operations
\newcommand{\inc}[1]{\union \set{#1}} % "including"
\newcommand{\exc}[1]{\setminus \set{#1}} % "except"
% .. over and under
\newcommand{\tld}{\widetilde} % deprecated
\newcommand{\HAT}{\widehat} % deprecated
\newcommand{\rt}[1]{ {\sqrt{#1}}}
% .... two-part
\renewcommand{\sl}[2]{#1 /\mathopen{}#2}
% .. arrows
% .. operators and relations
\newcommand{\OX}[1]{^{\ox #1}}
% .. punctuation and spacing
% Levels of closeness
% .. vanilla versions (is it within a constant?)
% .. dotted versions (is it equal in the limit?)
% .. log versions (is it equal up to log?)
% Logic and bit operations
\DeclareMathOperator{\1}{\mathbb{1}} % use \mathbbm instead if using real LaTeX
% Linear algebra
\newcommand{\spn}{\mathrm{span}} % do NOT use \span because it causes misery with amsmath
% .. named tensors
\newcommand{\namedtensorstrut}{\vphantom{fg}} % milder than \mathstrut
\newcommand{\name}[1]{\mathsf{\namedtensorstrut #1}}
\newcommand{\nbin}[2]{\mathbin{\underset{\substack{#1}}{\namedtensorstrut #2}}}
% Probability
% .. operators
% ... information theory
% .. other divergences
% Complexity classes
% .. keywords
% .. classical
% .. probabilistic
% .. circuits
% .. resources
% .. custom
% Boolean analysis
% \newcommand{\Exp}[1]{\operatorname{E}_{#1}\mathopen{}}
\DeclareMathOperator{\CDT}{\mathrm{CDT}} % canonical
\DeclareMathOperator{\PDT}{\mathrm{PDT}} % partial decision tree
% .. functions (small caps sadly doesn't work)
% Dynamic optimality
% Alignment
% In "text"
% remove these last two if using real LaTeX
% Fonts
% .. bold
% .. calligraphic
% .. typewriter
Personal summary of Exact solutions to the nonlinear dynamics of learning in deep linear neural networks.
This paper shows that you can decompose the learning dynamics of a fully linear network into (mostly independent) singular modes that learn at different rates, depending on the singular values of the input-output correlation matrix. In fact, it’s as if the network were “simulating” another network for which the correlation between input and output coordinates is diagonal.
General setup
The model is $\ndef{\input}{input}\ndef{\hidden}{hidden}\ndef{\output}{output}\ndef{\data}{data}\newcommand{\We}{W^\mathrm{e}}\newcommand{\Wd}{W^\mathrm{d}}x\ndot\input\We\ndot\hidden \Wd$ (using named tensor notation), where
- $\We \in \R^{\input \times \hidden}$ “encodes” the input $x$ into a hidden layer,
- $\Wd \in \R^{\hidden\times \output}$ “decodes” the hidden layer into the output,
and both $\We$ and $\Wd$ are initialized as independent Gaussians.
It is trained on $N$ data points:
- inputs: $X \ce (x^{(1)}, \ldots, x^{(N)}) \in \R^{\data \times \input}$,
- outputs: $Y \ce (y^{(1)}, \ldots, y^{(N)}) \in \R^{\data \times \output}$.
The square loss is then
%\ce \sum_{d=1}^D \norm{y^{(d)}-\Wd \ndot\hidden \We\ndot\input x^{(d)}}_\output^2
\CL \ce \sum_\data\norm{Y - X \ndot\input \We\ndot\hidden \Wd}_\output^2.
Taking the differential, we get
\partial \CL
&= \sum_\data 2 \p{Y - X \ndot\input \We\ndot\hidden \Wd}\ndot\output\partial\p{Y - X \ndot\input \We\ndot\hidden \Wd}\\
&= -2 \sum_\data \p{Y - X \ndot\input \We\ndot\hidden \Wd}\ndot\output\b{X\ndot\input\p{\partial\We \ndot\hidden \Wd + \We \ndot\hidden \partial\Wd}}\\
&= -2 \b{X\ndot\data\p{Y - X \ndot\input \We\ndot\hidden \Wd} } \ndot{\input\\\output}\p{\partial\We \ndot\hidden \Wd + \We \ndot\hidden \partial\Wd}.
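The last line of this derivation can be checked numerically. Below is a minimal NumPy sketch (not from the paper), assuming each named axis is stored as a fixed array dimension: $X$ has shape (data, input), $Y$ has shape (data, output), $\We$ has shape (input, hidden) and $\Wd$ has shape (hidden, output), so that named contractions become ordinary matrix products. The helpers `loss` and `differential` are hypothetical names; the check compares the formula with a central finite difference of the loss along a random direction $(\partial\We, \partial\Wd)$.

```python
import numpy as np

def loss(X, Y, We, Wd):
    """Square loss: sum over data points of the squared output-norm of Y - X We Wd."""
    R = Y - X @ We @ Wd
    return np.sum(R ** 2)

def differential(X, Y, We, Wd, dWe, dWd):
    """The derived formula: -2 [X^T (Y - X We Wd)] contracted with (dWe Wd + We dWd)."""
    G = X.T @ (Y - X @ We @ Wd)      # contracts the data axis, shape (input, output)
    dM = dWe @ Wd + We @ dWd         # first-order change of the product, shape (input, output)
    return -2 * np.sum(G * dM)       # contracts both the input and output axes

rng = np.random.default_rng(0)
N, n_in, n_hid, n_out = 20, 4, 3, 5
X = rng.standard_normal((N, n_in))
Y = rng.standard_normal((N, n_out))
We = rng.standard_normal((n_in, n_hid))
Wd = rng.standard_normal((n_hid, n_out))
dWe = rng.standard_normal(We.shape)
dWd = rng.standard_normal(Wd.shape)

# Central finite difference of the loss along the direction (dWe, dWd)
h = 1e-6
numeric = (loss(X, Y, We + h * dWe, Wd + h * dWd)
           - loss(X, Y, We - h * dWe, Wd - h * dWd)) / (2 * h)
analytic = differential(X, Y, We, Wd, dWe, dWd)
print(numeric, analytic)  # should agree to several significant digits
```

With these shapes, a contraction like $X \ndot\data R$ is simply `X.T @ R`, which is the only translation rule the other sketches in this post rely on.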
Let’s assume that $X \ndot\data X = I_\input$ (i.e. the input coordinates have been standardized to have mean $0$, variance $1$ and be uncorrelated, with the usual $1/N$ normalization absorbed into $X$). Then we have
%\CE =
X\ndot\data\p{Y - X \ndot\input \We\ndot\hidden \Wd} = \UB{X\ndot\data Y}_{\ec\ \Sigma} - \We \ndot\hidden \Wd,
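where the “input-output correlations” matrix $\Sigma \ce X \ndot\data Y \in \R^{\input \times \output}$ completely captures the problem (indeed, the loss is then $\sum_{\input,\output}\p{\Sigma - \We\ndot\hidden\Wd}^2$ plus a constant that doesn’t depend on the weights). The matrix product $\We \ndot\hidden \Wd$ is “trying” to match these correlations $\Sigma$, and under gradient flow the learning dynamics are given by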
\frac{\d\We}{\d t} = -\frac{\partial\CL}{\partial\We} &= 2\p{\Sigma - \We\ndot\hidden\Wd}\ndot\output\Wd\\
\frac{\d\Wd}{\d t} = -\frac{\partial\CL}{\partial\Wd} &= 2\p{\Sigma - \We\ndot\hidden\Wd}\ndot\input\We.
We’ll drop the factors of $2$ for simplicity (this amounts to rescaling time).
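Under whitening, the residual correlations only depend on $\Sigma$. The sketch below checks this numerically with the same array conventions (axes stored as (data, input), (input, hidden), and so on). To get $X \ndot\data X = I_\input$ it takes $X$ to be the orthonormal factor of a reduced QR decomposition of random data, which is just one convenient choice; the velocities keep the factor $2$, as in the equations above.

```python
import numpy as np

rng = np.random.default_rng(1)
N, n_in, n_hid, n_out = 50, 4, 3, 5

# Whitened inputs: a reduced QR decomposition gives X with orthonormal columns, so X^T X = I
X, _ = np.linalg.qr(rng.standard_normal((N, n_in)))
Y = rng.standard_normal((N, n_out))
We = rng.standard_normal((n_in, n_hid))
Wd = rng.standard_normal((n_hid, n_out))

Sigma = X.T @ Y                  # input-output correlations, shape (input, output)
M = We @ Wd                      # current product We . Wd, shape (input, output)

# With X^T X = I, the residual correlations only depend on Sigma and M
lhs = X.T @ (Y - X @ M)
print(np.allclose(lhs, Sigma - M))

# Velocities written in terms of Sigma (factor 2 kept)
dWe_dt = 2 * (Sigma - M) @ Wd.T  # contracts the output axis
dWd_dt = 2 * We.T @ (Sigma - M)  # contracts the input axis

# Minus the gradient of the full loss, computed from the raw data X, Y
neg_grad_We = 2 * X.T @ (Y - X @ M) @ Wd.T
neg_grad_Wd = 2 * We.T @ X.T @ (Y - X @ M)
print(np.allclose(dWe_dt, neg_grad_We), np.allclose(dWd_dt, neg_grad_Wd))
```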
Diagonal case
Suppose that $|\input| = |\output|\ec d$, and that $\Sigma$ is diagonal with entries $s_1, \ldots, s_d$. Then we can decompose the dynamics into (somewhat) independent parts by grouping the weights according to which input/output coordinate they connect to:
\frac{\d\We_{\input(i)}}{\d t} &= \p{s_i - \We_{\input(i)}\ndot\hidden\Wd_{\output(i)}}\Wd_{\output(i)} - \sum_{j \ne i} \p{\We_{\input(i)}\ndot\hidden\Wd_{\output(j)}}\Wd_{\output(j)}\\
\frac{\d\Wd_{\output(i)}}{\d t} &= \UB{\p{s_i - \We_{\input(i)}\ndot\hidden\Wd_{\output(i)}}\We_{\input(i)}}_\text{``feature benefit''} - \UB{\sum_{j \ne i} \p{\We_{\input(j)}\ndot\hidden\Wd_{\output(i)}}\We_{\input(j)}}_\text{``interference''}.
That is, there are two forces:
- feature benefit: the dot product between $\We_{\input(i)}$ and $\Wd_{\output(i)}$ is “trying” to be $s_i$ (they tend to align when $s_i > 0$),
- interference: $\We_{\input(i)}$ and $\Wd_{\output(j)}$ are “trying” to be perpendicular when $i \ne j$ (they tend to repel each other).
This is very similar to the setup in Toy Models of Superposition, except that
- there are no ReLUs at the end,
- the encoding and decoding weights $\We$ and $\Wd$ are not tied together,
- the target values $s_i$ can be different from $1$.
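The split into “feature benefit” and “interference” can be checked term by term. In this NumPy sketch, row $i$ of `We` plays the role of $\We_{\input(i)}$ and column $i$ of `Wd` plays the role of $\Wd_{\output(i)}$; the function name `split_velocities`, the sizes and the values of $s_i$ are illustrative choices, and the factors of $2$ are dropped as in the text.

```python
import numpy as np

def split_velocities(We, Wd, s):
    """Split the velocities for a diagonal Sigma = diag(s) into feature benefit and interference.

    Row i of We is We_input(i); column i of Wd is Wd_output(i).
    Returns (enc_benefit, enc_interf, dec_benefit, dec_interf): row i of the encoder
    arrays and column i of the decoder arrays belong to coordinate i.
    """
    d = len(s)
    M = We @ Wd                      # M[i, j] = We_input(i) . Wd_output(j)
    enc_benefit = np.stack([(s[i] - M[i, i]) * Wd[:, i] for i in range(d)])
    enc_interf = np.stack([sum(M[i, j] * Wd[:, j] for j in range(d) if j != i)
                           for i in range(d)])
    dec_benefit = np.stack([(s[i] - M[i, i]) * We[i] for i in range(d)], axis=1)
    dec_interf = np.stack([sum(M[j, i] * We[j] for j in range(d) if j != i)
                           for i in range(d)], axis=1)
    return enc_benefit, enc_interf, dec_benefit, dec_interf

rng = np.random.default_rng(2)
s = np.array([3.0, 2.0, 1.0, 0.5])
We = rng.standard_normal((4, 3))
Wd = rng.standard_normal((3, 4))
R = np.diag(s) - We @ Wd             # Sigma - We Wd

eb, ei, db, di = split_velocities(We, Wd, s)
# "feature benefit" minus "interference" recovers the full velocities
print(np.allclose(eb - ei, R @ Wd.T), np.allclose(db - di, We.T @ R))
```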
Speed of learning
To simplify things, let’s assume that all $s_i$ are positive and that $\We$ and $\Wd$ start out with
- $\We_{\input(i)}$ perpendicular to $\Wd_{\output(j)}$ when $i \ne j$,
- $\We_{\input(i)}=\Wd_{\output(i)}$.
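Then the interference terms vanish and the two update equations become identical, so $\We_{\input(i)}$ remains equal to $\Wd_{\output(i)}$ throughout training (and since each vector only gets rescaled, the vectors for different $i$ stay perpendicular, so the interference stays zero). Because of this, we’ll just refer to both as $W_i$. Introducing a new variable $u_i \ce W_i\ndot\hidden W_i = \norm{W_i}_\hidden^2 \in \R$ to track their squared norm, we get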
\frac{\d u_i}{\d t}
&= 2\frac{\d W_i}{\d t}\ndot\hidden W_i\\
&= 2\b{\p{s_i - W_i\ndot\hidden W_i}W_i}\ndot\hidden W_i\\
&= 2\p{s_i - u_i}u_i.
This is the logistic differential equation, whose solution is
u_i(t) = \frac{s_i}{1+e^{-2s_i(t-t_0)}}.
Here $t_0$ is determined by the initial value $u_i(0)$. This means that, for small $\eps$, $u_i$ takes time
- $\approx \log(s_i/\eps)/(2s_i)$ to go from $\eps$ to $s_i/2$;
- $\approx \log(s_i/\eps)/(2s_i)$ to go from $s_i/2$ to $s_i-\eps$.
In other words, the rate at which the values of the correlation matrix $\Sigma$ get learned is (roughly, up to the logarithmic factor) proportional to those values!
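Here is a small simulation of this special case, as a hedged sketch: it integrates the (factor-$2$-free) dynamics with forward Euler, starting from mutually perpendicular vectors with $\We_{\input(i)} = \Wd_{\output(i)}$ and $u_i(0) = 10^{-3}$, and compares $u_i$ with the logistic solution. The step size, sizes, seed and the helper `logistic_time` (which inverts the closed form to get exact transition times) are assumptions made for illustration.

```python
import numpy as np

def logistic_time(s, a, b):
    """Exact time for u' = 2u(s - u) to go from u = a to u = b (0 < a < b < s)."""
    # Invert u(t) = s / (1 + exp(-2 s (t - t0))):  t - t0 = log(u / (s - u)) / (2 s)
    tau = lambda u: np.log(u / (s - u)) / (2 * s)
    return tau(b) - tau(a)

rng = np.random.default_rng(3)
s = np.array([2.0, 1.0, 0.5])
d, n_hid = len(s), 5
Sigma = np.diag(s)

# Assumed init: perpendicular directions, We_input(i) = Wd_output(i), u_i(0) = u0
Q, _ = np.linalg.qr(rng.standard_normal((n_hid, d)))   # orthonormal columns
u0 = 1e-3
We = np.sqrt(u0) * Q.T          # row i is We_input(i)
Wd = We.T.copy()                # column i is Wd_output(i)

dt, steps = 5e-4, 20_000        # forward Euler up to t = 10
history = np.empty((steps, d))
for k in range(steps):
    R = Sigma - We @ Wd                                  # factor 2 dropped
    We, Wd = We + dt * R @ Wd.T, Wd + dt * We.T @ R      # simultaneous update
    history[k] = np.einsum("ih,hi->i", We, Wd)           # u_i = We_input(i) . Wd_output(i)

ts = dt * np.arange(1, steps + 1)
t0 = np.log(s / u0 - 1) / (2 * s)                        # from u(0) = u0
exact = s / (1 + np.exp(-2 * s * (ts[:, None] - t0)))
print(np.abs(history - exact).max())                     # small Euler error
print(np.allclose(We, Wd.T))                             # encoder and decoder stay tied
print(logistic_time(s, 1e-3, s / 2), np.log(s / 1e-3) / (2 * s))
```

The two halves of each transition take exactly the same time under the closed form, and the time to reach $s_i/2$ shrinks roughly like $1/s_i$.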
Speed of lowering interference
This section is my own speculation.
Let’s keep the assumption that $\We_{\input(i)}=\Wd_{\output(i)}$ at the start (and therefore throughout training), but relax the assumption that $W_i$ starts out perpendicular to $W_j$ when $i \ne j$. Then we can study how fast the interference decreases.
It turns out that the right quantity to look at is the cosine similarity
\eps_{ij} \ce \frac{W_i \ndot\hidden W_j}{\norm{W_i}_\hidden \norm{W_j}_\hidden}.
Indeed, $\eps_{ij}$ doesn’t change when $W_i$ or $W_j$ gets scaled by a positive factor (without changing direction), and the “feature benefit” term $\p{s_i - \norm{W_i}_\hidden^2}W_i$ only rescales $W_i$, so it doesn’t affect $\eps_{ij}$. So for the purposes of studying how fast $\eps_{ij}$ decreases at any point in time, the only relevant changes in $W_i$ are
\frac{\d W_i}{\d t} = -\sum_{k \ne i}\p{W_i \ndot\hidden W_k}W_k
Then we have
&\frac{\d \p{W_i \ndot\hidden W_j}}{\d t}\\
&\quad= \frac{\d W_i}{\d t}\ndot \hidden W_j + W_i \ndot\hidden \frac{\d W_j}{\d t}\\
&\quad= -\p{W_i \ndot \hidden W_j}\p{\norm{W_i}_\hidden^2 + \norm{W_j}_\hidden^2} - \UB{2\sum_{k \ne i,j}\p{W_i \ndot \hidden W_k}\p{W_j \ndot \hidden W_k}}_\text{second-order effects}
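Under these dynamics, the squared norms only change to second order, since $\frac{\d \norm{W_i}_\hidden^2}{\d t} = -2\sum_{k \ne i}\p{W_i \ndot\hidden W_k}^2$. So, ignoring the second-order effects, we get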
\frac{\d \eps_{ij}}{\d t} \approx -\eps_{ij}\p{\norm{W_i}_\hidden^2 + \norm{W_j}_\hidden^2}.
So $\abs{\eps_{ij}}$ decays exponentially (at least while the norms stay roughly constant), at rate $\norm{W_i}_\hidden^2 + \norm{W_j}_\hidden^2$, which is within a factor $2$ of the larger of the two squared norms. This means that the angle between $W_i$ and $W_j$ starts approaching $90\degree$ faster once one of $W_i$ or $W_j$ gets big, and, as we’ve seen, how soon that happens depends on how big $s_i$ and $s_j$ are.
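The first-order approximation can be compared with the exact derivative of $\eps_{ij}$ at a single point. This NumPy sketch assumes tied weights (`We = W`, `Wd = W.T`) whose rows have different lengths and are perpendicular up to a small random perturbation; it uses the full velocity, feature benefit included, to confirm that the feature benefit drops out. The sizes, lengths, perturbation size and seed are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(4)
s = np.array([3.0, 1.0, 0.3])
d, n_hid = len(s), 6

# Assumed state: rows of different lengths, perpendicular up to a small perturbation
Q, _ = np.linalg.qr(rng.standard_normal((n_hid, d)))
W = (Q * np.array([1.5, 0.8, 0.2])).T          # row i is W_i
W = W + 1e-4 * rng.standard_normal(W.shape)    # small but nonzero interference

# Full velocity (feature benefit + interference, factor 2 dropped); row i is dW_i/dt
V = (np.diag(s) - W @ W.T) @ W

# Exact d(eps_ij)/dt, differentiating eps_ij = g_ij / sqrt(n2_i n2_j) with g = W W^T
g, dg = W @ W.T, V @ W.T + W @ V.T
n2, dn2 = np.diag(g), np.diag(dg)              # squared norms and their derivatives
eps = g / np.sqrt(np.outer(n2, n2))
deps_exact = dg / np.sqrt(np.outer(n2, n2)) - 0.5 * eps * (
    dn2[:, None] / n2[:, None] + dn2[None, :] / n2[None, :])

# First-order prediction: d(eps_ij)/dt ~ -eps_ij (|W_i|^2 + |W_j|^2)
deps_approx = -eps * (n2[:, None] + n2[None, :])

off = ~np.eye(d, dtype=bool)                   # only pairs i != j matter
rel_err = np.abs(deps_exact - deps_approx)[off].max() / np.abs(deps_approx)[off].max()
print(rel_err)                                 # second-order effects are negligible here
print(n2[:, None] + n2[None, :])               # decay rate of each pair
```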
Summary of the dynamics
So we can summarize the dynamics of learning this way:
- the input-output correlations get learned roughly one by one, in decreasing order of their goal values $\abs{s_i}$;
- when $\We_{\input(i)}$ and $\Wd_{\output(i)}$ get learned, they quickly align and acquire the same length ($\We_{\input(i)} \approx \Wd_{\output(i)}$ if $s_i>0$ and $\We_{\input(i)} \approx -\Wd_{\output(i)}$ if $s_i < 0$);
- they eventually reach length $\sqrt{|s_i|}$;
- when $\We_{\input(i)}$/$\Wd_{\output(i)}$ get learned, they simultaneously repel all the other vectors into their perpendicular space.
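The whole story can be seen in a single run. This sketch (an illustration under assumed settings, not an experiment from the paper) starts from small i.i.d. Gaussian weights that are not tied together, integrates the dynamics with forward Euler, and reports when each mode gets halfway to its goal, the final lengths, how well encoder and decoder vectors agree, and the remaining off-diagonal terms. Initialization scale, step size, sizes and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(5)
s = np.array([3.0, 1.5, 0.5])
d, n_hid = len(s), 6
Sigma = np.diag(s)

# Assumed init: small i.i.d. Gaussian entries, encoder and decoder NOT tied
We = 1e-3 * rng.standard_normal((d, n_hid))
Wd = 1e-3 * rng.standard_normal((n_hid, d))

dt, steps = 1e-3, 30_000          # forward Euler up to t = 30, factor 2 dropped
u = np.empty((steps, d))
for k in range(steps):
    R = Sigma - We @ Wd
    We, Wd = We + dt * R @ Wd.T, Wd + dt * We.T @ R
    u[k] = np.einsum("ih,hi->i", We, Wd)          # We_input(i) . Wd_output(i)

# Time at which each mode first gets halfway to its goal value
half_times = dt * (1 + np.argmax(u >= s / 2, axis=0))
print(half_times)                                 # should increase: largest s_i first
print(np.linalg.norm(We, axis=1), np.sqrt(s))     # lengths approach sqrt(s_i)
print(np.abs(We - Wd.T).max())                    # encoder and decoder end up aligned
print(np.abs(We @ Wd - Sigma).max())              # interference (cross terms) is gone
```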
Non-diagonal case
The striking thing about fully linear networks is that they’re completely rotationally invariant. For us, this means that even if $\Sigma$ is not diagonal, things happen in exactly the same way, but in the bases given by the singular value decomposition of $\Sigma$.
Suppose that $|\input| = |\output|\ec d$. Let’s write $\Sigma$ as a sum of singular modes (some of which might have $s_i = 0$ if $\Sigma$ is not full-rank):
\Sigma = \sum_{i=1}^d s_i(U_{\inmode(i)} \odot V_{\outmode(i)}) = U \ndot\inmode S \ndot\outmode V
where $U \in \R^{\inmode\times\input}$ and $V \in \R^{\outmode \times \output}$ are orthonormal bases of $\R^{\input}$ and $\R^{\output}$, and $S \in \R^{\inmode \times \outmode}$ is diagonal with values $s_1, \ldots, s_d$. Let’s also express $\We$ and $\Wd$ in the new bases:
\We = U \ndot\inmode\ol\We \qquad \Wd = \ol\Wd \ndot\outmode V.
Then we have
\partial \CL
&= -2\p{\Sigma - \We \ndot\hidden \Wd}\ndot{\input\\\output}\p{\partial\We \ndot\hidden \Wd + \We \ndot\hidden \partial \Wd}\\
&= -2\b{U \ndot{\inmode}\p{S - \ol\We \ndot\hidden \ol\Wd}\ndot\outmode V}\\
&\qquad\ndot{\input\\\output}\b{U \ndot\inmode \p{\partial\ol\We \ndot\hidden \ol\Wd + \ol\We \ndot\hidden \partial \ol\Wd} \ndot\outmode V}\\
&= -2\p{S - \ol\We \ndot\hidden \ol\Wd}\ndot{\inmode\\\outmode}\p{\partial\ol\We \ndot\hidden \ol\Wd + \ol\We \ndot\hidden \partial \ol\Wd},
(where the last equality follows from the orthonormality of $U$ and $V$), so the new learning dynamics are
\frac{\d\ol\We}{\d t} = -\frac{\partial\CL}{\partial\ol\We} &= 2\p{S - \ol\We\ndot\hidden\ol\Wd}\ndot\outmode\ol\Wd\\
\frac{\d\ol\Wd}{\d t} = -\frac{\partial\CL}{\partial\ol\Wd} &= 2\p{S - \ol\We\ndot\hidden\ol\Wd}\ndot\inmode\ol\We.
This is exactly what we had earlier, except that $\Sigma, \We, \Wd$ have been replaced by $S, \ol\We,\ol\Wd$ respectively, and $S$ is guaranteed to be diagonal. That is, by performing a simple change of basis, we’ve completely reduced the problem to the diagonal case, for a new “virtual” linear neural network.
Also note that by rotational symmetry, if the initializations of $\We$ and $\Wd$ are i.i.d. Gaussians, then so are the initializations of $\ol\We$ and $\ol\Wd$.
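The change of basis can be checked by running the same (factor-$2$-free) Euler dynamics twice: once on $\Sigma$ with the original weights, and once on the diagonal $S$ with the rotated weights $U \ndot\input \We$ and $\Wd \ndot\output V$. `numpy.linalg.svd` returns $\Sigma$ as `Uc @ diag(sv) @ Vh`, with the singular vectors as the columns of `Uc` and the rows of `Vh`, so this sketch sets `U = Uc.T` and `V = Vh` to match the row convention used here. The helper `euler`, the sizes, step size and seed are illustrative assumptions.

```python
import numpy as np

def euler(target, We, Wd, dt=5e-3, steps=2000):
    """Forward Euler for dWe/dt = (T - We Wd) Wd^T, dWd/dt = We^T (T - We Wd), factor 2 dropped."""
    for _ in range(steps):
        R = target - We @ Wd
        We, Wd = We + dt * R @ Wd.T, Wd + dt * We.T @ R
    return We, Wd

rng = np.random.default_rng(6)
d, n_hid = 4, 5
Sigma = rng.standard_normal((d, d))

# numpy gives Sigma = Uc @ diag(sv) @ Vh; transposing Uc makes the singular directions rows,
# so that Sigma = U ._inmode S ._outmode V becomes Sigma = U.T @ S @ V as plain matrices
Uc, sv, Vh = np.linalg.svd(Sigma)
U, S, V = Uc.T, np.diag(sv), Vh
print(np.allclose(Sigma, U.T @ S @ V))

We0 = 0.1 * rng.standard_normal((d, n_hid))
Wd0 = 0.1 * rng.standard_normal((n_hid, d))

We, Wd = euler(Sigma, We0, Wd0)              # original network, target Sigma
bWe, bWd = euler(S, U @ We0, Wd0 @ V.T)      # "virtual" network, diagonal target S

# Same trajectory, up to the change of basis
print(np.allclose(bWe, U @ We), np.allclose(bWd, Wd @ V.T))
```

Each Euler step commutes exactly with the change of basis, so the two runs agree up to rounding; this is the sense in which the original network is “simulating” the virtual one.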