The Nonlinear Library: Alignment Forum

AF - Fact Finding: Do Early Layers Specialise in Local Processing? (Post 5) by Neel Nanda



Welcome to The Nonlinear Library, where we use Text-to-Speech software to convert the best writing from the Rationalist and EA communities into audio. This is: Fact Finding: Do Early Layers Specialise in Local Processing? (Post 5), published by Neel Nanda on December 23, 2023 on The AI Alignment Forum.
This is the fifth post in the Google DeepMind mechanistic interpretability team's investigation into how language models recall facts. This post is a bit tangential to the main sequence, and documents some interesting observations about how, in general, early layers of models somewhat (but not fully) specialise in processing recent tokens. You don't need to believe these results to believe our overall results about facts, but we hope they're interesting! And likewise you don't need to read the rest of the sequence to engage with this.
Introduction
In this sequence we've presented the multi-token embedding hypothesis, that a crucial mechanism behind factual recall is that on the final token of a multi-token entity there forms an "embedding", with linear representations of attributes of that entity. We further noticed that this seemed to be most of what early layers did, and that they didn't seem to respond much to prior context (e.g. adding "Mr Michael Jordan" didn't substantially change the residual).
We hypothesised the stronger claim that early layers (e.g. the first 10-20%), in general, specialise in local processing, and that the prior context (e.g. more than 10 tokens back) is only brought in at early-mid layers.
We note that this is stronger than the multi-token embedding hypothesis in two ways: it's a statement about how early layers behave on all tokens, not just the final tokens of entities about which facts are known; and it's a claim that early layers are not also doing longer range stuff in addition to producing the multi-token embedding (e.g. detecting the language of the text). We find this stronger hypothesis plausible, because tokens are a pretty messy input format, and analysing individual tokens in isolation can be highly misleading, e.g.
We tested this by taking a bunch of arbitrary prompts from the Pile and recording the residual streams on them, then truncating each prompt to its most recent few tokens and recording the residual streams on the truncated prompts, and finally looking at the mean-centred cosine sim between the original and truncated residual streams at different layers.
Our findings:
Early layers do, in general, specialise in local processing, but it's a soft division of labour not a hard split.
There's a gradual transition where more context is brought in across the layers.
Early layers do significant processing on recent tokens, not just the current token - this is not just a trivial result where the residual stream is dominated by the current token and slightly adjusted by each layer.
Early layers do much more long-range processing on common tokens (punctuation, articles, pronouns, etc).
Experiments
The "early layers specialise in local processing" hypothesis concretely predicts that, for a given token X in a long prompt, if we truncate the prompt to just the most recent few tokens before X, the residual stream at X should be very similar at early layers and dissimilar at later layers. We can test this empirically by looking at the cosine sim of the original vs truncated residual streams, as a function of layer and truncated context length. Taking cosine sims of residual streams naively can be misleading, as there's often a significant shared mean across all tokens, so we first subtract the mean residual stream across all tokens, and then take the cosine sim.
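As a rough illustration of the metric described above, here is a minimal NumPy sketch on synthetic vectors rather than real Pythia activations. The function names (`per_layer_mean`, `mean_centred_cosine_sim`, `truncate_context`) are hypothetical and not from the original investigation, and the truncation convention (keep the token at X plus the `n_context` tokens before it) is an assumption made for illustration only.

```python
import numpy as np


def per_layer_mean(resid):
    """Mean residual stream for one layer.

    resid: array of shape (n_tokens, d_model) holding residual streams
    collected at a single layer across many tokens.
    """
    return resid.mean(axis=0)


def mean_centred_cosine_sim(original, truncated, mean_resid):
    """Cosine sim of two residual vectors after subtracting a shared mean.

    original, truncated: arrays of shape (..., d_model)
    mean_resid: array of shape (d_model,), the mean for this layer
    """
    a = original - mean_resid
    b = truncated - mean_resid
    dot = np.sum(a * b, axis=-1)
    norms = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
    return dot / norms


def truncate_context(tokens, position, n_context):
    """Assumed convention: keep the token at `position` plus the
    `n_context` tokens immediately before it."""
    start = max(0, position - n_context)
    return tokens[start:position + 1]


# Synthetic demo: two otherwise unrelated vectors that share a large mean.
rng = np.random.default_rng(0)
d_model = 64
shared_mean = 10.0 * np.ones(d_model)
x = shared_mean + rng.normal(size=d_model)
y = shared_mean + rng.normal(size=d_model)

# Without centring, the shared mean dominates and the vectors look alike.
naive = mean_centred_cosine_sim(x, y, np.zeros(d_model))
# With centring, only the token-specific parts are compared.
centred = mean_centred_cosine_sim(x, y, shared_mean)
print(f"naive cosine sim:        {naive:.3f}")
print(f"mean-centred cosine sim: {centred:.3f}")
```

In a real run, `original` and `truncated` would be the residual streams at token X for the full and truncated prompts at a given layer, and `mean_resid` would come from something like `per_layer_mean` applied to activations over many tokens. The synthetic demo only shows why the centring step matters: a large shared component can make unrelated vectors appear highly similar.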
Set-Up
Model: Pythia 2.8B, as in the rest of our investigation
Dataset: Strings from the Pile, the Pythia pre-training distribution.
Metric: To measure how similar the original and truncated residual streams are we subtract the mean residual stream and then take the cosine sim.
We compute a separate mean per layer, across all tokens in random prompts from the Pile
Truncated context: We vary the number of tokens i...

The Nonlinear Library: Alignment Forum, by The Nonlinear Fund

