Personal summary of Attention Is All You Need by Vaswani, Shazeer, Parmar, Uszkoreit, Jones, Gomez, Kaiser, and Polosukhin, with a lot of inspiration from A Mathematical Framework for Transformer Circuits by Elhage, Nanda, Olsson et al. This is more of an opinion piece than a faithful description. Also it’s a work in progress; I’m probably misunderstanding a lot of things (even really basic and important things), and I don’t have a good sense of scale.
Simplified structure
Essentially transformers look like this:
#figure draw three gray question marks instead of one
Almost all of the components are local to one token and work exactly the same way for each token. The only exception is the attention part, which moves information across tokens, and can choose to move different information depending on distance (this is illustrated by arrows of different colors).
The model takes a sequence of tokens $\newcommand{\dtokens}{ {d_\mathrm{tokens}}}\p{t^{(1)}, \ldots, t^{(n)}} \in [\dtokens]^n$ (say, $\dtokens = 50000$), and outputs a guess for the next token $t^{(n+1)}$ in the form of a vector $\tau^{(n+1)} \in \R^{\dtokens}$, where the coordinates of $\tau^{(n+1)}$ are logits. Logits are log odds in a loose sense: $\tau^{(n+1)}_i$ is the log of the probability of the next token being $i$, up to an additive constant shared by all $i$, so the difference $\tau^{(n+1)}_i - \tau^{(n+1)}_j$ is the log of the odds of token $i$ against token $j$. That is, according to the model,
$$\Pr\b{t^{(n+1)}=i} = \frac{\exp\p{\tau_i^{(n+1)}}}{\sum_{j=1}^\dtokens \exp\p{\tau_j^{(n+1)}}}.$$
Because they’re working additively on log odds, transformers can easily perform Bayesian reasoning, which is all about working multiplicatively on odds.
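As a small sanity check of the logits-to-probabilities formula and the "adding logits is multiplying odds" remark, here is a minimal sketch in plain Python. The 3-token vocabulary, the uniform prior and the 4:1 likelihood ratio are made-up numbers for illustration, not anything a real model produced.

```python
import math

def softmax(logits):
    """Turn a list of logits into probabilities (numerically stable)."""
    m = max(logits)                       # subtracting the max avoids overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy vocabulary of 3 tokens (the text uses d_tokens = 50000).
prior_logits = [0.0, 0.0, 0.0]            # uniform prior over the next token
evidence = [math.log(4.0), 0.0, 0.0]      # evidence favoring token 0 by 4:1

# Adding logits multiplies the unnormalized probabilities,
# which is exactly a Bayesian update on odds.
posterior = softmax([p + e for p, e in zip(prior_logits, evidence)])
print(posterior)  # token 0 gets 4/6, the other two get 1/6 each
```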
In fact, the model doesn’t give logits for the $(n+1)\nth$ token only: it also makes a logit prediction $\tau^{(i+1)}$ based on each prefix $\p{t^{(1)}, \ldots, t^{(i)}}$ ($i \in [n]$) of the sequence, which describes its guess for the $(i+1)\nth$ token based on seeing just the first $i$ tokens. This is somehow more natural because of how it’s structured internally, and allows it to be trained and (as far as I can tell) run more quickly.
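To make the training benefit concrete: a single sequence of $n$ tokens yields $n-1$ (prefix, next token) examples whose answers are already known, all of which can be scored in one pass. The sketch below only enumerates those pairs. The helper name `prefix_targets` and the token ids are hypothetical.

```python
def prefix_targets(tokens):
    """Return (prefix, next_token) pairs for every proper prefix of tokens."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]

tokens = [17, 4, 92, 8]                   # made-up token ids
pairs = prefix_targets(tokens)
for prefix, target in pairs:
    print(prefix, "->", target)
```

The prediction made from the full sequence, $\tau^{(n+1)}$, has no known target inside the sequence, which is why there are $n-1$ pairs rather than $n$.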
Encoding and decoding
The encoder transforms the tokens $t^{(i)}$ into “state vectors” in a lower-dimensional space $\newcommand{\dstate}{ {d_\mathrm{state}}}\R^{\dstate}$ (say $\dstate = 512$), in such a way that spatial directions carry semantic information. For example, there could be directions correlating with
- gender of a noun
- tense of a verb
- positive/negative feeling
- words related to water vs dryness
- etc.
As far as I can tell, it would have been possible to use state vectors in $\R^{\dtokens}$ (with tokens encoded one-hot), but that makes running/training the transformer much more expensive, while the encoding is very cheap to pre-train and doesn’t lose much relevant information.
These state vectors will be maintained in a residual stream (indicated by the thick black line). They will act as a sort of scratchpad that can be read from and written to, accumulating information and making inferences from the current token as well as past tokens, in order to figure out what the next token is.
The decoder takes the result from $\R^{\dstate}$ back into the logits $\tau^{(i+1)} \in \R^{\dtokens}$.
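Here is a rough NumPy sketch of the encoder and decoder as plain matrices. The names `W_E` and `W_U` are my own labels, the weights are random stand-ins for learned ones, and the sizes are shrunk so the arrays are easy to inspect. It also checks the remark above that encoding is the same thing as multiplying a one-hot vector by a matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
d_tokens, d_state = 10, 4                 # toy sizes; the text uses 50000 and 512

# Hypothetical names; random stand-ins for learned weights.
W_E = rng.normal(size=(d_state, d_tokens))   # encoder: one column per token
W_U = rng.normal(size=(d_tokens, d_state))   # decoder: one row per token

tokens = np.array([3, 7, 3])              # made-up token ids

# Encoding: pick column t of W_E for each token t ...
states = W_E[:, tokens].T                 # shape (n, d_state)

# ... which equals multiplying W_E by the one-hot encoding of each token.
one_hot = np.eye(d_tokens)[tokens]        # shape (n, d_tokens)
states_via_one_hot = one_hot @ W_E.T      # shape (n, d_state)

# Decoding: map each state back to d_tokens logits.
logits = states @ W_U.T                   # shape (n, d_tokens)
```

In the real model the states are transformed by the layers before decoding; this sketch skips that step entirely.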
In between the encoder and the decoder, the state vectors in $\R^\dstate$ get transformed through a series of $N=6$ layers. Each layer has the same structure (but independent weights): an attention module followed by a perceptron module.
Attention module
The role of attention is to move information forward in time between states corresponding to different tokens, i.e. it moves information from left to right. Information never flows back in time.
Say that the states coming into the attention layer are $x^{(1)}, \ldots, x^{(n)} \in \R^{\dstate}$. Then each state $x^{(i)}$ will perform “queries” on the previous tokens’ states $x^{(i-1)}$, $x^{(i-2)}$, etc, and when there is a “match”, it will copy-and-paste some value from that state.
More precisely, for each state $x^{(i)}$ and each previous token’s state $x^{(i-k)}$, it will do several queries of the following form (called “attention heads”):
- match: compute the product $p =\p{x^{(i-k)}}^\transp M x^{(i)} \in \R$, for some fairly low-rank matrix $M \in \R^{\dstate \times \dstate}$,
- the column space of $M$ describes what the query is looking for in $x\P{i-k}$, while its row space determines whether $x\P{i}$ is interested in receiving the result of the query,
- copy-paste: if $p$ is large enough, it will add $Vx^{(i-k)}$ to $x^{(i)}$, where $V \in \R^{\dstate \times \dstate}$ is another low-rank matrix
- the row space of $V$ describes what information the copy is pulling from $x\P{i-k}$, while its column space describes where to paste it in $x\P{i}$.
The values of $M$ and $V$ differ between queries, and they can also differ based on the distance $k$ between them within the sequence. But they don’t depend on $i$: attention is “translation-invariant”.
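The match and copy-paste steps above can be sketched for a single head as follows. A few things here are assumptions rather than part of the description above: the hard "if $p$ is large enough" test is replaced by the softmax weighting that the original paper uses, each token may also attend to itself ($k=0$), the dependence on the distance $k$ is ignored, and $M$ and $V$ are built as products of thin matrices (`W_Q`, `W_K`, `W_V`, `W_O`, standard names but random values here) to make them low-rank.

```python
import numpy as np

def attention_head(X, M, V):
    """One causal attention head in the (M, V) form used above.

    X is an (n, d_state) array whose row i is the state x^(i).
    Returns the (n, d_state) array of updates to add to the states.
    """
    n = X.shape[0]
    # scores[i, j] = x^(j)^T M x^(i): how well source j matches destination i.
    scores = X @ M.T @ X.T
    # Causality: destination i may only look at sources j <= i.
    allowed = np.tril(np.ones((n, n), dtype=bool))
    scores = np.where(allowed, scores, -np.inf)
    # Softmax over sources j (soft version of "p is large enough").
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)
    # Copy-paste: add the weighted sum of V x^(j) into position i.
    return weights @ (X @ V.T)

rng = np.random.default_rng(0)
n, d_state, d_head = 5, 8, 2              # toy sizes
W_Q = rng.normal(size=(d_head, d_state))
W_K = rng.normal(size=(d_head, d_state))
W_V = rng.normal(size=(d_head, d_state))
W_O = rng.normal(size=(d_state, d_head))
M = W_K.T @ W_Q / np.sqrt(d_head)         # rank at most d_head
V = W_O @ W_V                             # rank at most d_head

X = rng.normal(size=(n, d_state))
out = attention_head(X, M, V)
X_new = X + out                           # written back into the residual stream
```

Changing a later state never affects the update computed for an earlier position, which is the "information never flows back in time" property.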
Perceptron module
The perceptron module is just a neural network with one hidden layer (with $\newcommand{\dff}{ {d_\mathrm{per}}}\dff=2048$ neurons) and ReLU activations. That is, for each state vector $x^{(i)}$, its output is
$$y^{(i)} = W_2\ \ReLU(W_1x^{(i)}+b_1)+b_2$$
with $W_1 \in \R^{\dff\times \dstate}$, $W_2 \in \R^{\dstate \times \dff}$, $b_1 \in \R^\dff$, and $b_2 \in \R^{\dstate}$. This output $y^{(i)}$ is then added back to $x^{(i)}$.
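A direct NumPy transcription of this formula, as a sketch: the weights are random placeholders (a trained model would have learned values), and the function name `perceptron_module` is mine. The states are stored as rows of `X`, so the same weights are applied to every token independently.

```python
import numpy as np

def perceptron_module(X, W1, b1, W2, b2):
    """Apply y = W2 ReLU(W1 x + b1) + b2 to each row x of X and add it back."""
    hidden = np.maximum(0.0, X @ W1.T + b1)   # (n, d_per), ReLU activations
    Y = hidden @ W2.T + b2                    # (n, d_state)
    return X + Y                              # residual connection

rng = np.random.default_rng(0)
n, d_state, d_per = 3, 512, 2048              # sizes from the text, n is arbitrary
W1 = rng.normal(scale=0.02, size=(d_per, d_state))
b1 = np.zeros(d_per)
W2 = rng.normal(scale=0.02, size=(d_state, d_per))
b2 = np.zeros(d_state)

X = rng.normal(size=(n, d_state))
out = perceptron_module(X, W1, b1, W2, b2)
n_params = W1.size + b1.size + W2.size + b2.size
```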
Modifications and omitted details
Positional encoding
- explain how the positional encodings allow specialization of the heads based on distance (for analysis purposes we might as well take this specialization as an explicit part of the architecture rather than something that has to be found by gradient descent)
- how cursed tokens are and detokenization (reforming words / phrases that are split across several tokens)
Layer norm
#to-write In actuality, things get renormed, and since things get added over the layers, the earlier computations kind of “decay” in importance
Most of the intuitions are from the Neel Nanda ML street talk interview.
The role of attention
In previous models like convolutional networks, the architecture dictated how to combine information between positions (e.g. do this many convolutions with this kernel size, this stride, etc.). Instead, transformers devote a fraction of their parameters to figuring out how to move information, and let the model figure out what is most useful.
The residual stream is a bottleneck
The dimensions of the residual stream are a very precious resource.
- In GPT-2 small, there are about 50000 different tokens, and each MLP layer has about 3000 neurons, yet this must be compressed into just a 768-dimensional residual stream.
- In the Othello paper, the model uses 128 out of 512 dimensions to represent the Othello board, but this probably only happens because the problem is set up to make representing the game board insanely important. In normal situations, transformers may rely on other tricks, like using attention heads, to avoid using up residual dimensions (e.g. GPT-4 seems to be able to play semi-valid chess by just attending to previous moves).
- You can’t (say) double the size of the residual stream without doubling the number of parameters overall, since it doubles the size of the matrices you need for computing the queries/keys/values in each head, and it doubles the number of weights going into each neuron in the feed-forward layer.
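The last point can be checked with back-of-the-envelope arithmetic. This sketch counts only the big weight matrices in one layer and ignores biases, layer norm and the encoder/decoder. The head count, head size and feed-forward width (8, 64, 2048) are the original paper's base configuration rather than anything stated in these notes, and they are held fixed while the residual stream is doubled.

```python
def layer_params(d_state, n_heads, d_head, d_ff):
    """Rough weight count for one layer (no biases, no layer norm)."""
    # Query, key, value and output matrices: each is d_state x (n_heads * d_head).
    attention = 4 * d_state * n_heads * d_head
    # Feed-forward: d_state -> d_ff and d_ff -> d_state.
    mlp = 2 * d_state * d_ff
    return attention + mlp

base = layer_params(d_state=512, n_heads=8, d_head=64, d_ff=2048)
wide = layer_params(d_state=1024, n_heads=8, d_head=64, d_ff=2048)
print(base, wide, wide / base)            # the ratio is exactly 2.0
```

Every term is linear in `d_state`, so the count doubles. If the heads or the feed-forward layer are widened along with the residual stream, as is common in practice, the count grows faster than that.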
Depth vs width trade-offs
Most computations don’t need all $2N$ layers of nonlinearities and can just sit in the residual stream for a while, so it doesn’t actually change that much if you e.g. halve the number of layers but double the number of attention heads / feed-forward neurons in each layer. The total number of parameters is what matters the most.
- make a table of sizes / neuron counts / weight counts estimates for the original transformer, GPT-2, GPT-3, etc.
- tie in with scaling laws
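As a first step towards that table, here is a sketch of the depth vs width trade described above: halving the number of layers while doubling the heads and feed-forward neurons per layer leaves the rough weight count unchanged. The counting rule is a simplification (no biases, layer norm or encoder/decoder), and the head size of 64 and 8 heads come from the original paper's base model, not from these notes.

```python
def total_params(n_layers, d_state, n_heads, d_head, d_ff):
    """Rough weight count for a stack of identical layers."""
    attention = 4 * d_state * n_heads * d_head   # Q, K, V, O matrices
    mlp = 2 * d_state * d_ff                     # two feed-forward matrices
    return n_layers * (attention + mlp)

deep = total_params(n_layers=6, d_state=512, n_heads=8, d_head=64, d_ff=2048)
shallow = total_params(n_layers=3, d_state=512, n_heads=16, d_head=64, d_ff=4096)
print(deep, shallow)                      # both are 18874368
```

Equal weight counts say nothing by themselves about equal performance; that part is the claim in the paragraph above, not something this arithmetic shows.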