$\require{mathtools} % %%% GENERIC MATH %%% % % Environments \newcommand{\al}[1]{\begin{align}#1\end{align}} % need this for \tag{} to work \renewcommand{\r}{\mathrm} % % Greek \newcommand{\eps}{\epsilon} \newcommand{\veps}{\varepsilon} \newcommand{\Om}{\Omega} \newcommand{\om}{\omega} \newcommand{\Th}{\Theta} \let\fi\phi % because it looks like an f \let\phi\varphi % because it looks like a p % % Miscellaneous shortcuts % .. over and under \newcommand{\ss}[1]{_{\substack{#1}}} \newcommand{\ob}{\overbrace} \newcommand{\ub}{\underbrace} \newcommand{\ol}{\overline} \newcommand{\tld}{\widetilde} \newcommand{\HAT}{\widehat} \newcommand{\f}{\frac} \newcommand{\s}[2]{#1 /\mathopen{}#2} \newcommand{\rt}[1]{ {\sqrt{#1}}} % .. relations \newcommand{\sr}{\stackrel} \newcommand{\sse}{\subseteq} \newcommand{\ce}{\coloneqq} \newcommand{\ec}{\eqqcolon} \newcommand{\ap}{\approx} \newcommand{\ls}{\lesssim} \newcommand{\gs}{\gtrsim} % .. miscer \newcommand{\q}{\quad} \newcommand{\qq}{\qquad} \newcommand{\heart}{\heartsuit} % % Delimiters % (I needed to create my own because the MathJax version of \DeclarePairedDelimiter doesn't have \mathopen{} and that messes up the spacing) % .. one-part \newcommand{\p}[1]{\mathopen{}\left( #1 \right)} \newcommand{\b}[1]{\mathopen{}\left[ #1 \right]} \newcommand{\set}[1]{\mathopen{}\left\{ #1 \right\}} \newcommand{\abs}[1]{\mathopen{}\left\lvert #1 \right\rvert} \newcommand{\floor}[1]{\mathopen{}\left\lfloor #1 \right\rfloor} \newcommand{\ceil}[1]{\mathopen{}\left\lceil #1 \right\rceil} \newcommand{\inner}[1]{\mathopen{}\left\langle #1 \right\rangle} % .... (use phantom to force at least the standard height of double bars) \newcommand{\norm}[1]{\mathopen{}\left\lVert #1 \vphantom{f} \right\rVert} \newcommand{\frob}[1]{\norm{#1}_\mathrm{F}} %% .. 
two-part \newcommand{\incond}[2]{#1 \mathop{}\middle|\mathop{} #2} \newcommand{\cond}[2]{ {\left.\incond{#1}{#2}\right.}} \newcommand{\pco}[2]{\p{\incond{#1}{#2}}} \newcommand{\bco}[2]{\b{\incond{#1}{#2}}} \newcommand{\setco}[2]{\set{\incond{#1}{#2}}} \newcommand{\at}[2]{ {\left.#1\right|_{#2}}} \newcommand{\pat}[2]{\p{\at{#1}{#2}}} \newcommand{\bat}[2]{\b{\at{#1}{#2}}} % ..... (use phantom to force at least the standard height of double bar) \newcommand{\oldpara}[2]{#1\vphantom{f} \mathop{}\middle\|\mathop{} #2} %\newcommand{\para}[2]{#1\vphantom{f} \mathop{}\middle\|\mathop{} #2} \newcommand{\para}[2]{\mathchoice{\begin{matrix}#1\\\hdashline#2\end{matrix}}{\begin{smallmatrix}#1\\\hdashline#2\end{smallmatrix}}{\begin{smallmatrix}#1\\\hdashline#2\end{smallmatrix}}{\begin{smallmatrix}#1\\\hdashline#2\end{smallmatrix}}} \newcommand{\ppa}[2]{\p{\para{#1}{#2}}} \newcommand{\bpa}[2]{\b{\para{#1}{#2}}} %\newcommand{\bpaco}[4]{\bpa{\incond{#1}{#2}}{\incond{#3}{#4}}} \newcommand{\bpaco}[4]{\bpa{\cond{#1}{#2}}{\cond{#3}{#4}}} % % Levels of closeness \newcommand{\scirc}[1]{\sr{\circ}{#1}} \newcommand{\sdot}[1]{\sr{.}{#1}} \newcommand{\slog}[1]{\sr{\log}{#1}} \newcommand{\createClosenessLevels}[7]{ \newcommand{#2}{\mathrel{(#1)}} \newcommand{#3}{\mathrel{#1}} \newcommand{#4}{\mathrel{#1\!\!#1}} \newcommand{#5}{\mathrel{#1\!\!#1\!\!#1}} \newcommand{#6}{\mathrel{(\sdot{#1})}} \newcommand{#7}{\mathrel{(\slog{#1})}} } \let\lt\undefined \let\gt\undefined % .. vanilla versions (is it within a constant?) 
\newcommand{\ez}{\scirc=} \newcommand{\eq}{\simeq} \newcommand{\eqq}{\mathrel{\eq\!\!\eq}} \newcommand{\eqqq}{\mathrel{\eq\!\!\eq\!\!\eq}} \newcommand{\lez}{\scirc\le} \newcommand{\lq}{\preceq} \newcommand{\lqq}{\mathrel{\lq\!\!\lq}} \newcommand{\lqqq}{\mathrel{\lq\!\!\lq\!\!\lq}} \newcommand{\gez}{\scirc\ge} \newcommand{\gq}{\succeq} \newcommand{\gqq}{\mathrel{\gq\!\!\gq}} \newcommand{\gqqq}{\mathrel{\gq\!\!\gq\!\!\gq}} \newcommand{\lz}{\scirc<} \newcommand{\lt}{\prec} \newcommand{\ltt}{\mathrel{\lt\!\!\lt}} \newcommand{\lttt}{\mathrel{\lt\!\!\lt\!\!\lt}} \newcommand{\gz}{\scirc>} \newcommand{\gt}{\succ} \newcommand{\gtt}{\mathrel{\gt\!\!\gt}} \newcommand{\gttt}{\mathrel{\gt\!\!\gt\!\!\gt}} % .. dotted versions (is it equal in the limit?) \newcommand{\ed}{\sdot=} \newcommand{\eqd}{\sdot\eq} \newcommand{\eqqd}{\sdot\eqq} \newcommand{\eqqqd}{\sdot\eqqq} \newcommand{\led}{\sdot\le} \newcommand{\lqd}{\sdot\lq} \newcommand{\lqqd}{\sdot\lqq} \newcommand{\lqqqd}{\sdot\lqqq} \newcommand{\ged}{\sdot\ge} \newcommand{\gqd}{\sdot\gq} \newcommand{\gqqd}{\sdot\gqq} \newcommand{\gqqqd}{\sdot\gqqq} \newcommand{\ld}{\sdot<} \newcommand{\ltd}{\sdot\lt} \newcommand{\lttd}{\sdot\ltt} \newcommand{\ltttd}{\sdot\lttt} \newcommand{\gd}{\sdot>} \newcommand{\gtd}{\sdot\gt} \newcommand{\gttd}{\sdot\gtt} \newcommand{\gtttd}{\sdot\gttt} % .. log versions (is it equal up to log?) 
\newcommand{\elog}{\slog=} \newcommand{\eqlog}{\slog\eq} \newcommand{\eqqlog}{\slog\eqq} \newcommand{\eqqqlog}{\slog\eqqq} \newcommand{\lelog}{\slog\le} \newcommand{\lqlog}{\slog\lq} \newcommand{\lqqlog}{\slog\lqq} \newcommand{\lqqqlog}{\slog\lqqq} \newcommand{\gelog}{\slog\ge} \newcommand{\gqlog}{\slog\gq} \newcommand{\gqqlog}{\slog\gqq} \newcommand{\gqqqlog}{\slog\gqqq} \newcommand{\llog}{\slog<} \newcommand{\ltlog}{\slog\lt} \newcommand{\lttlog}{\slog\ltt} \newcommand{\ltttlog}{\slog\lttt} \newcommand{\glog}{\slog>} \newcommand{\gtlog}{\slog\gt} \newcommand{\gttlog}{\slog\gtt} \newcommand{\gtttlog}{\slog\gttt} % % Miscellaneous \newcommand{\LHS}{\mathrm{LHS}} \newcommand{\RHS}{\mathrm{RHS}} % .. operators \DeclareMathOperator{\poly}{poly} \DeclareMathOperator{\polylog}{polylog} \DeclareMathOperator{\quasipoly}{quasipoly} \DeclareMathOperator{\negl}{negl} \DeclareMathOperator*{\argmin}{arg\,min} \DeclareMathOperator*{\argmax}{arg\,max} % .. functions \DeclareMathOperator{\id}{id} \DeclareMathOperator{\sign}{sign} \DeclareMathOperator{\err}{err} \DeclareMathOperator{\ReLU}{ReLU} % .. analysis \let\d\undefined \newcommand{\d}{\operatorname{d}\mathopen{}} \newcommand{\df}[2]{ {\f{\d #1}{\d #2}}} \newcommand{\ds}[2]{ {\s{\d #1}{\d #2}}} \newcommand{\part}{\partial} \newcommand{\partf}[2]{\f{\part #1}{\part #2}} \newcommand{\parts}[2]{\s{\part #1}{\part #2}} \newcommand{\grad}[1]{\mathop{\nabla\!_{#1}}} % .. 
sets of numbers \newcommand{\N}{\mathbb{N}} \newcommand{\Z}{\mathbb{Z}} \newcommand{\R}{\mathbb{R}} \newcommand{\C}{\mathbb{C}} \newcommand{\F}{\mathbb{F}} % %%% SPECIALIZED MATH %%% % % Logic \renewcommand{\and}{\wedge} \newcommand{\AND}{\bigwedge} \newcommand{\or}{\vee} \newcommand{\OR}{\bigvee} \newcommand{\xor}{\oplus} \newcommand{\XOR}{\bigoplus} \newcommand{\union}{\cup} \newcommand{\inter}{\cap} \newcommand{\UNION}{\bigcup} \newcommand{\INTER}{\bigcap} \newcommand{\comp}{\overline} \newcommand{\true}{\r{true}} \newcommand{\false}{\r{false}} \newcommand{\tf}{\set{\true,\false}} \DeclareMathOperator{\One}{\mathbb{1}} \DeclareMathOperator{\1}{\mathbb{1}} % % Linear algebra \renewcommand{\span}{\mathrm{span}} \DeclareMathOperator{\rank}{rank} \DeclareMathOperator{\proj}{proj} \DeclareMathOperator{\dom}{dom} \DeclareMathOperator{\Img}{Im} \newcommand{\transp}{\mathsf{T}} \renewcommand{\t}{^\transp} % ... named tensors \newcommand{\namedtensorstrut}{\vphantom{fg}} % milder than \mathstrut \newcommand{\name}[1]{\mathsf{\namedtensorstrut #1}} \newcommand{\nbin}[2]{\mathbin{\underset{\substack{#1}}{\namedtensorstrut #2}}} \newcommand{\ndot}[1]{\nbin{#1}{\odot}} \newcommand{\ncat}[1]{\nbin{#1}{\oplus}} \newcommand{\nsum}[1]{\sum\limits_{\substack{#1}}} \newcommand{\nfun}[2]{\mathop{\underset{\substack{#1}}{\namedtensorstrut\mathrm{#2}}}} \newcommand{\ndef}[2]{\newcommand{#1}{\name{#2}}} \newcommand{\nt}[1]{^{\transp(#1)}} % % Probability \newcommand{\Normal}{\mathcal{N}} \let\Pr\undefined \DeclareMathOperator*{\Pr}{Pr} \DeclareMathOperator*{\G}{\mathbb{G}} \DeclareMathOperator*{\Odds}{Od} \DeclareMathOperator*{\E}{E} \DeclareMathOperator*{\Var}{Var} \DeclareMathOperator*{\Cov}{Cov} \DeclareMathOperator*{\corr}{corr} \DeclareMathOperator*{\median}{median} \newcommand{\dTV}{d_{\mathrm{TV}}} \newcommand{\dHel}{d_{\mathrm{Hel}}} \newcommand{\dJS}{d_{\mathrm{JS}}} % ... 
information theory \let\H\undefined \DeclareMathOperator*{\H}{H} \DeclareMathOperator*{\I}{I} \DeclareMathOperator*{\D}{D} % %%% SPECIALIZED COMPUTER SCIENCE %%% % % Complexity classes % .. classical \newcommand{\Poly}{\mathsf{P}} \newcommand{\NP}{\mathsf{NP}} \newcommand{\PH}{\mathsf{PH}} \newcommand{\PSPACE}{\mathsf{PSPACE}} \renewcommand{\L}{\mathsf{L}} % .. probabilistic \newcommand{\formost}{\mathsf{Я}} \newcommand{\RP}{\mathsf{RP}} \newcommand{\BPP}{\mathsf{BPP}} \newcommand{\MA}{\mathsf{MA}} \newcommand{\AM}{\mathsf{AM}} \newcommand{\IP}{\mathsf{IP}} \newcommand{\RL}{\mathsf{RL}} % .. circuits \newcommand{\NC}{\mathsf{NC}} \newcommand{\AC}{\mathsf{AC}} \newcommand{\ACC}{\mathsf{ACC}} \newcommand{\TC}{\mathsf{TC}} \newcommand{\Ppoly}{\mathsf{P}/\poly} \newcommand{\Lpoly}{\mathsf{L}/\poly} % .. resources \newcommand{\TIME}{\mathsf{TIME}} \newcommand{\SPACE}{\mathsf{SPACE}} \newcommand{\TISP}{\mathsf{TISP}} \newcommand{\SIZE}{\mathsf{SIZE}} % .. keywords \newcommand{\co}{\mathsf{co}} \newcommand{\Prom}{\mathsf{Promise}} % % Boolean analysis \newcommand{\zo}{\set{0,1}} \newcommand{\pmo}{\set{\pm 1}} \newcommand{\zpmo}{\set{0,\pm 1}} \newcommand{\harpoon}{\!\upharpoonright\!} \newcommand{\rr}[2]{#1\harpoon_{#2}} \newcommand{\Fou}[1]{\widehat{#1}} \DeclareMathOperator{\Ind}{\mathrm{Ind}} \DeclareMathOperator{\Inf}{\mathrm{Inf}} \newcommand{\Der}[1]{\operatorname{D}_{#1}\mathopen{}} \newcommand{\Exp}[1]{\operatorname{E}_{#1}\mathopen{}} \DeclareMathOperator{\Stab}{\mathrm{Stab}} \DeclareMathOperator{\T}{T} \DeclareMathOperator{\sens}{\mathrm{s}} \DeclareMathOperator{\bsens}{\mathrm{bs}} \DeclareMathOperator{\fbsens}{\mathrm{fbs}} \DeclareMathOperator{\Cert}{\mathrm{C}} \DeclareMathOperator{\DT}{\mathrm{DT}} \DeclareMathOperator{\CDT}{\mathrm{CDT}} % canonical \DeclareMathOperator{\ECDT}{\mathrm{ECDT}} \DeclareMathOperator{\CDTv}{\mathrm{CDT_{vars}}} \DeclareMathOperator{\ECDTv}{\mathrm{ECDT_{vars}}} \DeclareMathOperator{\CDTt}{\mathrm{CDT_{terms}}} 
\DeclareMathOperator{\ECDTt}{\mathrm{ECDT_{terms}}} \DeclareMathOperator{\CDTw}{\mathrm{CDT_{weighted}}} \DeclareMathOperator{\ECDTw}{\mathrm{ECDT_{weighted}}} \DeclareMathOperator{\AvgDT}{\mathrm{AvgDT}} \DeclareMathOperator{\PDT}{\mathrm{PDT}} % partial decision tree \DeclareMathOperator{\DTsize}{\mathrm{DT_{size}}} \DeclareMathOperator{\W}{\mathbf{W}} % .. functions (small caps sadly doesn't work) \DeclareMathOperator{\Par}{\mathrm{Par}} \DeclareMathOperator{\Maj}{\mathrm{Maj}} \DeclareMathOperator{\HW}{\mathrm{HW}} \DeclareMathOperator{\Thr}{\mathrm{Thr}} \DeclareMathOperator{\Tribes}{\mathrm{Tribes}} \DeclareMathOperator{\RotTribes}{\mathrm{RotTribes}} \DeclareMathOperator{\CycleRun}{\mathrm{CycleRun}} \DeclareMathOperator{\SAT}{\mathrm{SAT}} \DeclareMathOperator{\UniqueSAT}{\mathrm{UniqueSAT}} % % Dynamic optimality \newcommand{\OPT}{\mathsf{OPT}} \newcommand{\Alt}{\mathsf{Alt}} \newcommand{\Funnel}{\mathsf{Funnel}} % % Alignment \DeclareMathOperator{\Amp}{\mathrm{Amp}} % %%% TYPESETTING %%% % % In text \renewcommand{\th}{^{\mathrm{th}}} \newcommand{\degree}{^\circ} % % Fonts % .. 
bold \newcommand{\BA}{\boldsymbol{A}} \newcommand{\BB}{\boldsymbol{B}} \newcommand{\BC}{\boldsymbol{C}} \newcommand{\BD}{\boldsymbol{D}} \newcommand{\BE}{\boldsymbol{E}} \newcommand{\BF}{\boldsymbol{F}} \newcommand{\BG}{\boldsymbol{G}} \newcommand{\BH}{\boldsymbol{H}} \newcommand{\BI}{\boldsymbol{I}} \newcommand{\BJ}{\boldsymbol{J}} \newcommand{\BK}{\boldsymbol{K}} \newcommand{\BL}{\boldsymbol{L}} \newcommand{\BM}{\boldsymbol{M}} \newcommand{\BN}{\boldsymbol{N}} \newcommand{\BO}{\boldsymbol{O}} \newcommand{\BP}{\boldsymbol{P}} \newcommand{\BQ}{\boldsymbol{Q}} \newcommand{\BR}{\boldsymbol{R}} \newcommand{\BS}{\boldsymbol{S}} \newcommand{\BT}{\boldsymbol{T}} \newcommand{\BU}{\boldsymbol{U}} \newcommand{\BV}{\boldsymbol{V}} \newcommand{\BW}{\boldsymbol{W}} \newcommand{\BX}{\boldsymbol{X}} \newcommand{\BY}{\boldsymbol{Y}} \newcommand{\BZ}{\boldsymbol{Z}} \newcommand{\Ba}{\boldsymbol{a}} \newcommand{\Bb}{\boldsymbol{b}} \newcommand{\Bc}{\boldsymbol{c}} \newcommand{\Bd}{\boldsymbol{d}} \newcommand{\Be}{\boldsymbol{e}} \newcommand{\Bf}{\boldsymbol{f}} \newcommand{\Bg}{\boldsymbol{g}} \newcommand{\Bh}{\boldsymbol{h}} \newcommand{\Bi}{\boldsymbol{i}} \newcommand{\Bj}{\boldsymbol{j}} \newcommand{\Bk}{\boldsymbol{k}} \newcommand{\Bp}{\boldsymbol{p}} \newcommand{\Bq}{\boldsymbol{q}} \newcommand{\Br}{\boldsymbol{r}} \newcommand{\Bs}{\boldsymbol{s}} \newcommand{\Bt}{\boldsymbol{t}} \newcommand{\Bu}{\boldsymbol{u}} \newcommand{\Bv}{\boldsymbol{v}} \newcommand{\Bw}{\boldsymbol{w}} \newcommand{\Bx}{\boldsymbol{x}} \newcommand{\By}{\boldsymbol{y}} \newcommand{\Bz}{\boldsymbol{z}} \newcommand{\Balpha}{\boldsymbol{\alpha}} \newcommand{\Bbeta}{\boldsymbol{\beta}} \newcommand{\Bgamma}{\boldsymbol{\gamma}} \newcommand{\Bdelta}{\boldsymbol{\delta}} \newcommand{\Beps}{\boldsymbol{\eps}} \newcommand{\Bveps}{\boldsymbol{\veps}} \newcommand{\Bzeta}{\boldsymbol{\zeta}} \newcommand{\Beta}{\boldsymbol{\eta}} \newcommand{\Btheta}{\boldsymbol{\theta}} \newcommand{\Biota}{\boldsymbol{\iota}} 
\newcommand{\Bkappa}{\boldsymbol{\kappa}} \newcommand{\Blambda}{\boldsymbol{\lambda}} \newcommand{\Bmu}{\boldsymbol{\mu}} \newcommand{\Bnu}{\boldsymbol{\nu}} \newcommand{\Bxi}{\boldsymbol{\xi}} \newcommand{\Bomicron}{\boldsymbol{\omicron}} \newcommand{\Bpi}{\boldsymbol{\pi}} \newcommand{\Brho}{\boldsymbol{\rho}} \newcommand{\Bsigma}{\boldsymbol{\sigma}} \newcommand{\Btau}{\boldsymbol{\tau}} \newcommand{\Bupsilon}{\boldsymbol{\upsilon}} \newcommand{\Bphi}{\boldsymbol{\phi}} \newcommand{\Bfi}{\boldsymbol{\fi}} \newcommand{\Bchi}{\boldsymbol{\chi}} \newcommand{\Bpsi}{\boldsymbol{\psi}} \newcommand{\Bomega}{\boldsymbol{\omega}} % .. calligraphic \newcommand{\CA}{\mathcal{A}} \newcommand{\CB}{\mathcal{B}} \newcommand{\CC}{\mathcal{C}} \newcommand{\CD}{\mathcal{D}} \newcommand{\CE}{\mathcal{E}} \newcommand{\CF}{\mathcal{F}} \newcommand{\CG}{\mathcal{G}} \newcommand{\CH}{\mathcal{H}} \newcommand{\CI}{\mathcal{I}} \newcommand{\CJ}{\mathcal{J}} \newcommand{\CK}{\mathcal{K}} \newcommand{\CL}{\mathcal{L}} \newcommand{\CM}{\mathcal{M}} \newcommand{\CN}{\mathcal{N}} \newcommand{\CO}{\mathcal{O}} \newcommand{\CP}{\mathcal{P}} \newcommand{\CQ}{\mathcal{Q}} \newcommand{\CR}{\mathcal{R}} \newcommand{\CS}{\mathcal{S}} \newcommand{\CT}{\mathcal{T}} \newcommand{\CU}{\mathcal{U}} \newcommand{\CV}{\mathcal{V}} \newcommand{\CW}{\mathcal{W}} \newcommand{\CX}{\mathcal{X}} \newcommand{\CY}{\mathcal{Y}} \newcommand{\CZ}{\mathcal{Z}} % .. 
typewriter \newcommand{\TA}{\mathtt{A}} \newcommand{\TB}{\mathtt{B}} \newcommand{\TC}{\mathtt{C}} \newcommand{\TD}{\mathtt{D}} \newcommand{\TE}{\mathtt{E}} \newcommand{\TF}{\mathtt{F}} \newcommand{\TG}{\mathtt{G}} \newcommand{\TH}{\mathtt{H}} \newcommand{\TI}{\mathtt{I}} \newcommand{\TJ}{\mathtt{J}} \newcommand{\TK}{\mathtt{K}} \newcommand{\TL}{\mathtt{L}} \newcommand{\TM}{\mathtt{M}} \newcommand{\TN}{\mathtt{N}} \newcommand{\TO}{\mathtt{O}} \newcommand{\TP}{\mathtt{P}} \newcommand{\TQ}{\mathtt{Q}} \newcommand{\TR}{\mathtt{R}} \newcommand{\TS}{\mathtt{S}} \newcommand{\TT}{\mathtt{T}} \newcommand{\TU}{\mathtt{U}} \newcommand{\TV}{\mathtt{V}} \newcommand{\TW}{\mathtt{W}} \newcommand{\TX}{\mathtt{X}} \newcommand{\TY}{\mathtt{Y}} \newcommand{\TZ}{\mathtt{Z}}$

Personal summary of Attention Is All You Need by Vaswani, Shazeer, Parmar, Uszkoreit, Jones, Gomez, Kaiser, and Polosukhin, with a lot of inspiration from A Mathematical Framework for Transformer Circuits by Elhage, Nanda, Olsson et al. This is more of an opinion piece than a faithful description. Also it’s a work in progress; I’m probably misunderstanding a lot of things (even really basic and important things), and I don’t have a good sense of scale.

Simplified structure

Essentially transformers look like this:

#figure draw three gray question marks instead of one

Almost all of the components are local to one token and work exactly the same way for each token. The only exception is the attention part, which moves information across tokens, and can choose to move different information depending on distance (this is illustrated by arrows of different colors).

Input and output

The model takes a sequence of tokens $\newcommand{\dtokens}{ {d_\mathrm{tokens}}}\p{t^{(1)}, \ldots, t^{(n)}} \in [\dtokens]^n$ (say, $\dtokens = 50000$), and outputs a guess for the next token $t^{(n+1)}$ in the form of a vector $\tau^{(n+1)} \in \R^{\dtokens}$, where the coordinates of $\tau^{(n+1)}$ are logits. Logits are unnormalized log-probabilities: up to an additive constant shared by all $i$, $\tau^{(n+1)}_i$ is the log of the probability of the next token being $i$, so the difference $\tau^{(n+1)}_i - \tau^{(n+1)}_j$ is the log of the odds of token $i$ against token $j$. That is, according to the model,

\[ \Pr\b{t^{(n+1)}=i} = \frac{\exp\p{\tau_i^{(n+1)}}}{\sum_{j=1}^\dtokens \exp\p{\tau_j^{(n+1)}}}. \]

Because they’re working additively on log odds, transformers can easily perform Bayesian reasoning, which is all about working multiplicatively on odds.
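To make the "additive on logits, multiplicative on odds" point concrete, here is a minimal sketch with three made-up tokens. The names `softmax`, `prior_logits` and `log_likelihoods` and all the numbers are illustrative assumptions, not anything taken from the paper.

```python
import math

def softmax(logits):
    """Turn a list of logits into probabilities (the formula above)."""
    m = max(logits)  # subtracting the max does not change the result, only avoids overflow
    exps = [math.exp(t - m) for t in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical toy example with 3 tokens.
prior_logits = [2.0, 1.0, 0.0]
# log P(evidence | token), made-up values
log_likelihoods = [math.log(0.1), math.log(0.6), math.log(0.3)]

# Additive update on logits ...
posterior_from_logits = softmax([t + l for t, l in zip(prior_logits, log_likelihoods)])

# ... matches the multiplicative Bayes rule on probabilities.
prior = softmax(prior_logits)
unnormalized = [p * math.exp(l) for p, l in zip(prior, log_likelihoods)]
posterior_from_bayes = [u / sum(unnormalized) for u in unnormalized]
```

Both posteriors agree up to floating-point error, which is why adding a vector to the logits can be read as incorporating one more piece of evidence.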

In fact, the model doesn’t give logits for the $(n+1)\th$ token only: it also makes a logit prediction $\tau^{(i+1)}$ based on each prefix $\p{t^{(1)}, \ldots, t^{(i)}}$ ($i \in [n]$) of the sequence, which describes its guess for the $(i+1)\th$ token based on seeing just the first $i$ tokens. This is somehow more natural because of how it’s structured internally, and allows it to be trained and (as far as I can tell[1]) run more quickly.

Encoding and decoding

The encoder transforms the tokens $t^{(i)}$ into “state vectors” in a lower-dimensional space $\newcommand{\dstate}{ {d_\mathrm{state}}}\R^{\dstate}$ (say $\dstate = 512$), in such a way that spatial directions carry semantic information. For example, there could be directions correlating with

  • gender of a noun
  • tense of a verb
  • positive/negative feeling
  • words related to water vs dryness
  • etc.

As far as I can tell, it would have been possible to use state vectors in $\R^{\dtokens}$ (with tokens encoded one-hot), but that makes running/training the transformer much more expensive, while the encoding is very cheap to pre-train and doesn’t lose much relevant information.

These state vectors will be maintained in a residual stream (indicated by the thick black line). They will act as a sort of scratchpad that can be read from and written to, accumulating information and making inferences from the current token as well as past tokens, in order to figure out what the next token is.

The decoder takes the result from $\R^{\dstate}$ back into the logits $\tau^{(i+1)} \in \R^{\dtokens}$.
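As a rough sketch of what the encoder and decoder do, assume (as in standard implementations) that the encoder is a lookup table with one row per token and the decoder is a single linear map. The weights `W_enc` and `W_dec` below are random placeholders and the sizes are toy values; positional information is ignored.

```python
import random

random.seed(0)
d_tokens, d_state = 10, 4  # toy sizes; the text uses 50000 and 512

# Hypothetical encoder/decoder weights (random here; learned in a real model).
W_enc = [[random.gauss(0, 1) for _ in range(d_state)] for _ in range(d_tokens)]
W_dec = [[random.gauss(0, 1) for _ in range(d_state)] for _ in range(d_tokens)]

def encode(token):
    """Row lookup: cheap, and equivalent to the one-hot product below."""
    return list(W_enc[token])

def encode_one_hot(token):
    """Same result, computed as (one-hot vector) times W_enc."""
    one_hot = [1.0 if j == token else 0.0 for j in range(d_tokens)]
    return [sum(one_hot[j] * W_enc[j][c] for j in range(d_tokens))
            for c in range(d_state)]

def decode(state):
    """Map a state vector in R^d_state to one logit per token."""
    return [sum(w * s for w, s in zip(row, state)) for row in W_dec]

x = encode(3)       # state vector for token 3
logits = decode(x)  # d_tokens logits
```

The lookup and the one-hot product give the same vector, which is one way to see why working in $\R^{\dstate}$ loses nothing that a linear map on one-hot tokens could have extracted.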

Layers

In between the encoder and the decoder, the state vectors in $\R^\dstate$ get transformed through a series of $N=6$ layers. Each layer has the same structure (but independent weights): an attention module followed by a perceptron module.

Attention module

The role of attention is to move information forward in time between states corresponding to different tokens, i.e. it moves information from left to right. Information never flows back in time.

Say that the states coming into the attention layer are $x^{(1)}, \ldots, x^{(n)} \in \R^{\dstate}$. Then each state $x^{(i)}$ will perform “queries” on the states $x^{(i-1)}$, $x^{(i-2)}$, etc., and when there is a “match”, it will copy some value from that state.

More precisely, for each state $x^{(i)}$ and each previous token’s state $x^{(i-k)}$, it will do several queries of the following form (called “attention heads”):

  • query: compute the product $p =\p{x^{(i-k)}}^\transp Q x^{(i)} \in \R$, for some fairly low-rank[2] matrix $Q \in \R^{\dstate \times \dstate}$,
  • copy: if $p$ is large enough[3], it will add $Ax^{(i-k)}$ to $x^{(i)}$, where $A \in \R^{\dstate \times \dstate}$ is another low-rank[2] matrix.

The values of $Q$ and $A$ differ between queries, and they can also differ based on the distance $k$ between them within the sequence. But they don’t depend on $i$: attention is “translation-invariant”.
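The query/copy description can be sketched as follows, under several simplifying assumptions: a single head, $Q$ and $A$ that do not depend on the distance $k$, no scaling of $p$, a softmax over positions as the soft version of "large enough", and each state also attending to itself (as standard causal attention does). The helper names `matvec`, `dot` and `attention_head` are made up for this illustration.

```python
import math

def matvec(M, v):
    """Matrix (list of rows) times vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def attention_head(xs, Q, A):
    """One simplified head: each state reads from itself and earlier states.

    xs: list of n state vectors; Q, A: d_state x d_state matrices.
    Returns x_i + sum_{j <= i} w_ij * A x_j, where the weights w_ij are a
    softmax over j of the scores p = x_j^T Q x_i.
    """
    out = []
    for i, x_i in enumerate(xs):
        Qx = matvec(Q, x_i)
        scores = [dot(xs[j], Qx) for j in range(i + 1)]  # only j <= i: no flow back in time
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]      # gradual "is p large enough?"
        total = sum(weights)
        update = [0.0] * len(x_i)
        for j, w in enumerate(weights):
            Ax = matvec(A, xs[j])
            update = [u + (w / total) * a for u, a in zip(update, Ax)]
        out.append([a + b for a, b in zip(x_i, update)])  # add the copy to x_i
    return out

# Toy usage: 3 tokens, d_state = 2, identity matrices.
I2 = [[1, 0], [0, 1]]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = attention_head(xs, I2, I2)
```

Changing the last state leaves the outputs for earlier positions untouched, which is the "information never flows back in time" property.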

Perceptron module

The perceptron module is just a neural network with one hidden layer (with $\newcommand{\dff}{ {d_\mathrm{per}}}\dff=2048$ neurons) and ReLU activations. That is, for each state vector $x^{(i)}$, its output is

\[ y^{(i)} =W_2\ \ReLU(W_1x^{(i)}+b_1)+b_2 \]

with $W_1 \in \R^{\dff\times \dstate}$, $W_2 \in \R^{\dstate \times \dff}$, $b_1 \in \R^\dff$, and $b_2 \in \R^{\dstate}$. This output $y^{(i)}$ is then added back to $x^{(i)}$.
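A minimal sketch of this module for a single state vector, using plain lists and tiny made-up weights (the text's sizes would be 512 and 2048). The function names are illustrative.

```python
def matvec(M, v):
    """Matrix (list of rows) times vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def relu(v):
    return [max(0.0, a) for a in v]

def perceptron_module(x, W1, b1, W2, b2):
    """Compute y = W2 ReLU(W1 x + b1) + b2 and add it back to x."""
    hidden = relu([h + b for h, b in zip(matvec(W1, x), b1)])
    y = [o + b for o, b in zip(matvec(W2, hidden), b2)]
    return [xi + yi for xi, yi in zip(x, y)]  # residual connection

# Toy usage with d_state = 2 and 3 hidden neurons (hypothetical weights).
W1 = [[1, 0], [0, 1], [-1, -1]]
b1 = [0.0, 0.0, 0.0]
W2 = [[1, 0, 0], [0, 1, 1]]
b2 = [0.5, 0.0]
result = perceptron_module([1.0, -2.0], W1, b1, W2, b2)  # [2.5, -1.0]
```

Here the hidden pre-activations are $(1, -2, 1)$, the ReLU zeroes the middle one, and the output $y = (1.5, 1)$ is added to $x = (1, -2)$.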

Modifications and omitted details

Positional encoding

#to-write

  • explain how the positional encodings allow specialization of the heads based on distance (for analysis purposes we might as well take this specialization as an explicit part of the architecture rather than something that has to be found by gradient descent)
  • how cursed tokens are and detokenization (reforming words / phrases that are split across several tokens)

Layer norm

#to-write In actuality, things get renormed, and since things get added over the layers, the earlier computations kind of “decay” in importance

Intuitions

Most of the intuitions are from Neel Nanda’s interview on Machine Learning Street Talk.

The role of attention

In previous models like convolutional networks, the architecture dictated how to combine information between positions (e.g. do this many convolutions with this kernel size, this stride, etc.). Instead, transformers devote a fraction of their parameters to figuring out how to move information, and let the model figure out what is most useful.

The residual stream is a bottleneck

The dimensions of the residual stream are a very precious resource.

  • In GPT-2 small, there are about 50000 different tokens, and each MLP layer has about 3000 neurons, yet this must be compressed into just a 768-dimensional residual stream.
  • In the Othello paper, the model uses 128 out of 512 dimensions to represent the Othello board, but this probably only happens because the problem is set up to make representing the game board insanely important. In normal situations, transformers may rely on other tricks, like using attention heads, to avoid using up residual dimensions (e.g. GPT-4 seems to be able to play semi-valid chess by just attending to previous moves).
  • You can’t (say) double the size of the residual stream without doubling the number of parameters overall, since it doubles the size of the matrices you need for computing the queries/keys/values in each head, and it doubles the number of weights going into each neuron in the feed-forward layer.
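The last point can be checked with a back-of-the-envelope count. This is only a sketch: it ignores biases, layer norm and positional encodings, counts the encoder and decoder as separate matrices, and the default sizes are illustrative values loosely modeled on GPT-2 small. The function name `weight_count` is made up.

```python
def weight_count(d_state, n_layers=12, n_heads=12, d_head=64,
                 d_ff=3072, d_tokens=50000):
    """Rough number of weights; every matrix has exactly one side of size d_state."""
    attention = n_heads * 4 * d_state * d_head  # query, key, value, output maps per head
    feed_forward = 2 * d_state * d_ff           # W_1 and W_2
    embeddings = 2 * d_tokens * d_state         # encoder and decoder
    return n_layers * (attention + feed_forward) + embeddings

small = weight_count(768)
doubled = weight_count(1536)  # same heads and neurons, residual stream twice as wide
```

Because each weight matrix has one dimension equal to $\dstate$, the count is linear in $\dstate$: with everything else held fixed, `doubled` is exactly twice `small`.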

Depth vs width trade-offs

Most computations don’t need all $2N$ layers of nonlinearities and can just sit in the residual stream for a while, so it doesn’t actually change that much if you e.g. halve the number of layers but double the number of attention heads / feed-forward neurons in each layer. The total number of parameters is what matters the most.

Proportions

#to-do

  • make a table of sizes / neuron counts / weight counts estimates for the original transformer, GPT-2, GPT-3, etc.
  • tie in with scaling laws

Footnotes

  1. Since information only goes forward in time, in order to produce the next token, you only need to perform the computations in the “column” of the model that corresponds to the last token, having already computed all the relevant attention keys and values from previous tokens in all layers. But perhaps this is not of much use since it doesn’t reduce the required sequential time, and running the model is much less expensive than training it anyway.

  2. I’m not sure whether it’s crucial that the matrix is low rank, or whether it could have been full rank, just at unnecessary extra cost.

  3. In practice, this is gradual and depends on how large $p$ is compared to its value in other heads. I’m not sure whether this dependence between heads is necessary (is the fact that only $\sim 1$ head gets activated at a time really crucial?). Maybe it’s only useful so that the coefficients in front of $Ax^{(i-k)}$ don’t get crazy.