The Nonlinear Library: Alignment Forum

AF - The positional embedding matrix and previous-token heads: how do they actually work? by Adam Yedidia




Welcome to The Nonlinear Library, where we use Text-to-Speech software to convert the best writing from the Rationalist and EA communities into audio. This is: The positional embedding matrix and previous-token heads: how do they actually work?, published by Adam Yedidia on August 10, 2023 on The AI Alignment Forum.
tl;dr: This post starts with a mystery about positional embeddings in GPT2-small, and from there explains how they relate to previous-token heads, i.e. attention heads whose role is to attend to the previous token. I tried to make the post relatively accessible even if you're not already very familiar with concepts from LLM interpretability.
Introduction: A mystery about previous-token heads
In the context of transformer language models, a previous-token head is an attention head that attends primarily or exclusively to the previous token. (Attention heads are the parts of the network that determine, at each token position, which previous token position to read information from. They're the way transformers move information from one part of the prompt to another, and in GPT2-small there are 12 of them at each of the 12 layers, for a total of 144 heads.)
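As a rough illustration of what "deciding which earlier position to read from" means, here is a minimal sketch of a single causal attention head's attention pattern, assuming standard scaled dot-product attention on hypothetical query and key vectors (real heads compute these from the residual stream with learned weight matrices, which are omitted here). Each position can attend to itself and to earlier positions, and its weights sum to 1.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def causal_attention_pattern(queries, keys):
    """Return pattern[i][j]: how much position i reads from position j.

    Position i may only attend to positions j <= i (causal mask), so entries
    with j > i are 0 and each row sums to 1.
    """
    d_head = len(queries[0])
    scale = 1.0 / math.sqrt(d_head)
    n = len(queries)
    pattern = []
    for i in range(n):
        scores = [dot(queries[i], keys[j]) * scale for j in range(i + 1)]
        weights = softmax(scores)
        pattern.append(weights + [0.0] * (n - i - 1))  # masked future positions
    return pattern

# Hypothetical 2-D queries/keys for a 3-token prompt.
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
pattern = causal_attention_pattern(q, q)
```

A previous-token head is then simply a head whose pattern puts nearly all of row i's weight on column i - 1.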
Previous-token heads are performing a really basic function in a transformer model: for each token, figuring out what the token that precedes it is, and copying information from that preceding token to the token after it. This is a key piece of almost any text-based task you can imagine - it's hard to read a sentence if you can't tell which words in the sentence come before which other words, or which sub-words in a word come first. But how do these previous-token heads actually work? How do they know which token comes previously?
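One common way to quantify "attends primarily to the previous token" is to average, over positions, the attention weight each position places on the position just before it. The sketch below assumes an attention pattern given as a list of rows (row i is position i's weights over positions 0..n-1); this particular score is a generic diagnostic, not necessarily the exact metric used in the post.

```python
def previous_token_score(pattern):
    """Average attention weight that position i places on position i - 1.

    pattern[i][j] is the attention from position i to position j, with each
    row summing to 1. Position 0 has no previous token, so it is skipped.
    A score near 1 means the head behaves like a previous-token head.
    """
    n = len(pattern)
    if n < 2:
        raise ValueError("need at least two positions")
    return sum(pattern[i][i - 1] for i in range(1, n)) / (n - 1)

# An idealized previous-token head (position 0 can only attend to itself).
ideal = [
    [1.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
]
# A head that only ever attends to the current token.
self_only = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
```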
The easy answer to this is "positional embeddings" - for each possible token position in the prompt, a particular vector is added to the residual stream at that position, encoding the "meaning" of that position in vector space. This is a confusing concept - basically there's a vector that means "this is the first token in the prompt" and another vector that means "this is the second token in the prompt" and so on. In GPT2-small, these vectors are learned, there are 1024 of them (one per position the model can handle), and they form a helix.
Above: the positional embeddings of GPT2-small, scatter-plotted along their three most important PCA directions.
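To make "added to the residual stream" concrete, here is a minimal sketch of how a GPT2-style model builds its initial residual stream: the token embedding of each token plus the learned positional embedding for that token's position. The matrices below are tiny hypothetical stand-ins; in GPT2-small the positional embedding matrix has 1024 rows of width 768.

```python
def embed(token_ids, W_E, W_pos):
    """Initial residual stream: token embedding plus positional embedding.

    W_E[t] is the learned vector for token id t; W_pos[i] is the learned
    vector for position i. Row i of the result is what every later layer
    sees at position i before any attention or MLP writes to it.
    """
    if len(token_ids) > len(W_pos):
        raise ValueError("prompt longer than the number of learned positions")
    return [
        [e + p for e, p in zip(W_E[t], W_pos[i])]
        for i, t in enumerate(token_ids)
    ]

# Hypothetical 2-token vocabulary, 3 positions, width 2.
W_E = [[1.0, 0.0], [0.0, 1.0]]
W_pos = [[10.0, 10.0], [20.0, 20.0], [30.0, 30.0]]
resid = embed([1, 0, 1], W_E, W_pos)
```

Note that the same token (id 1) ends up with different vectors at positions 0 and 2; the positional component is the only thing distinguishing them at this stage.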
So if the network wanted to form a previous-token head, it could, for a given token position, look up which of these 1024 vectors corresponds to the current token position, figure out what the previous token position would be, and attend to the token whose positional embedding corresponds to the previous token position. At least in principle.
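The "in principle" mechanism can be illustrated with a toy model. The sketch below assumes hypothetical 2-D positional embeddings placed on a circle (a crude stand-in for one turn of the helix, not GPT2-small's actual learned embeddings), with a fixed angle `THETA` between consecutive positions. The head's query is the current position's embedding rotated back by one step, so it lines up with the previous position's key, and the causal argmax picks position i - 1.

```python
import math

THETA = 0.1  # hypothetical angle between consecutive positions

def pos_embedding(i):
    """Hypothetical 2-D positional embedding: position i sits at angle i*THETA."""
    return (math.cos(i * THETA), math.sin(i * THETA))

def rotate(v, angle):
    """Rotate a 2-D vector counterclockwise by `angle` radians."""
    x, y = v
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y)

def attended_position(i):
    """Which position j <= i a purely positional head at position i picks.

    The query is position i's embedding rotated back by THETA, which equals
    the embedding of position i - 1; keys are the positional embeddings.
    """
    query = rotate(pos_embedding(i), -THETA)
    keys = [pos_embedding(j) for j in range(i + 1)]  # causal: only j <= i
    scores = [query[0] * k[0] + query[1] * k[1] for k in keys]
    return max(range(i + 1), key=lambda j: scores[j])

picks = [attended_position(i) for i in range(1, 20)]
```

In this toy, the mechanism depends entirely on each position having a distinguishable embedding, which is exactly what the averaging experiment in the next paragraph removes.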
In practice, though, that does not seem to be what GPT2-small is doing. We can easily verify this by, for example, modifying the positional embedding matrix by averaging its rows together in groups of 5 (or 10, or 20 - performance starts to break down around 50 or 100). In other words, with groups of 5, token positions k through k+4 have identical positional embeddings. When we do that, GPT2's output is still pretty much unchanged!
Red: The same helix of GPT2-small positional embeddings as above, zoomed-in. Blue: the averaged-together positional embeddings, averaged in groups of 10. Each blue dot represents 10 positional embeddings.
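The row-averaging ablation described above can be sketched as follows, with the positional embedding matrix represented as a list of rows. How the post handles a final partial group is not stated (1024 is not a multiple of 5 or 10); this sketch assumes the last group is simply shorter and averaged on its own. Applying it to the real model would mean overwriting GPT2-small's positional embedding weights with the result.

```python
def average_in_groups(W_pos, group_size):
    """Replace each row of W_pos with the mean of its group of rows.

    Rows 0..group_size-1 share one averaged vector, the next group_size rows
    share another, and so on; the final group may be shorter (an assumption).
    """
    if group_size < 1:
        raise ValueError("group_size must be positive")
    result = []
    for start in range(0, len(W_pos), group_size):
        group = W_pos[start:start + group_size]
        mean = [sum(col) / len(group) for col in zip(*group)]
        result.extend([list(mean) for _ in group])
    return result

# Hypothetical 12-position, width-2 matrix averaged in groups of 5.
W = [[float(i), 2.0 * i] for i in range(12)]
W_avg = average_in_groups(W, 5)
```

After this transformation, positions 0 through 4 are indistinguishable by positional embedding alone, so any head that still tracks the previous token precisely must be getting that information from somewhere else.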
For instance, with a prompt like:
the model is 99+% sure that apple or a variant of apple should come next. When we put orange apple in the last quote instead, it's 99+% sure that orange should come next. This remains true even when we doctor the positional embedding matrix as previously described. If the model were relying exclusively on positional embeddings to figure out where each token was, it should have been unsure about which of apple or orange would come next, since it wouldn't know where exactly in the sentence each of them was.
Zooming in on some actual previous-token heads, we can also see that the main previous-token heads in the network (L2H2, L3H3, ...
The Nonlinear Library: Alignment Forum, by The Nonlinear Fund