The Nonlinear Library: Alignment Forum

AF - You can remove GPT2's LayerNorm by fine-tuning for an hour by Stefan Heimersheim



Welcome to The Nonlinear Library, where we use Text-to-Speech software to convert the best writing from the Rationalist and EA communities into audio. This is: You can remove GPT2's LayerNorm by fine-tuning for an hour, published by Stefan Heimersheim on August 8, 2024 on The AI Alignment Forum.
This work was produced at Apollo Research, based on initial research done at MATS.
LayerNorm is annoying for mechanistic interpretability research ("[...] reason #78 for why interpretability researchers hate LayerNorm" - Anthropic, 2023).
Here's a Hugging Face link to a GPT2-small model without any LayerNorm.
The final model is only slightly worse than a GPT2 with LayerNorm[1]:
Dataset                 Original GPT2   Fine-tuned GPT2 with LayerNorm   Fine-tuned GPT2 without LayerNorm
OpenWebText (ce_loss)   3.095           2.989                            3.014 (+0.025)
ThePile (ce_loss)       2.856           2.880                            2.926 (+0.046)
HellaSwag (accuracy)    29.56%          29.82%                           29.54%
I fine-tuned GPT2-small on OpenWebText while slowly removing its LayerNorm layers, waiting for the loss to go back down after each removal.
Introduction
LayerNorm (LN) is a component in Transformer models that normalizes embedding vectors to have constant length; specifically it divides the embeddings by their standard deviation taken over the hidden dimension. It was originally introduced to stabilize and speed up training of models (as a replacement for batch normalization). It is active during training and inference.
The LN equation includes the standard deviation (std) $\sqrt{\mathrm{Var}[x] + \epsilon}$, which makes it a non-linear operation. This hinders interpretability in a variety of ways, causing annoyances and inaccuracies such as
attributing residual stream directions to logit effects (e.g. SAE features, direct logit attribution),[2]
being annoying to deal with in Attribution Patching, or
being difficult to deal with in Apollo's LIB method.
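To make the non-linearity concrete, here is a minimal sketch of LayerNorm on a single embedding vector, written in plain Python rather than taken from any model code. The function name layer_norm and the toy vectors are made up for illustration. It checks two things: LN is not additive, and rescaling the input leaves the output almost unchanged.

```python
import math

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """LayerNorm over one embedding vector (i.e. over the hidden dimension)."""
    n = len(x)
    gamma = gamma if gamma is not None else [1.0] * n
    beta = beta if beta is not None else [0.0] * n
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance, as LayerNorm uses
    std = math.sqrt(var + eps)                 # the only non-linear step
    return [g * (v - mean) / std + b for v, g, b in zip(x, gamma, beta)]

# Toy vectors, chosen arbitrarily for illustration.
a = [1.0, 2.0, 3.0, 4.0]
b = [0.5, -1.0, 2.0, 0.0]

# A linear map would satisfy LN(a + b) == LN(a) + LN(b); LayerNorm does not.
ln_of_sum = layer_norm([p + q for p, q in zip(a, b)])
sum_of_ln = [p + q for p, q in zip(layer_norm(a), layer_norm(b))]
additivity_gap = max(abs(p - q) for p, q in zip(ln_of_sum, sum_of_ln))

# Scaling the input by 10 barely changes the output (only eps breaks exact invariance).
scaled = layer_norm([10.0 * v for v in a])
scale_gap = max(abs(p - q) for p, q in zip(layer_norm(a), scaled))

print("additivity gap:", additivity_gap)
print("scale gap:", scale_gap)
```

Because the output depends on the input direction but (almost) not on its scale, the effect of a residual stream direction on the logits cannot be read off as a fixed linear map.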
In the Docstring circuit analysis we seriously considered whether the model might be using LN in its algorithm. This post even shows that LN can be used as the sole non-linearity to solve non-linear classification problems (see also this related work).
Recently, with progress in Sparse Dictionary Learning, agendas (e.g. this one) imagine decomposing networks into sets of sparsely connected components (SAEs, Transcoders, etc.). A core difficulty in "putting it all together" is that the interactions between different components often route through LayerNorm, whose effect we do not understand.
Motivation
It would be pretty neat to have an LLM that still works (speaks English etc.) with fewer or no LN layers. One option would be to train a model without LN from scratch (done for tiny models, e.g. TinyModel), but this is very hard or impossible for larger models (hearsay is that you need a low learning rate and to be very careful).
Taking an existing model and removing the LN layers, however, seems doable if LN isn't implementing some important computation.[3] That is, LN "does its thing" and the model has learned to "deal with it", but it's not irreplaceable. A reason to be optimistic is that the spread of standard deviations across different samples isn't that large, so maybe replacing the LN-computed standard deviation with a fixed number might kinda work.
Method
I take GPT2-small, fine-tune it on OpenWebText, and remove LNs one-by-one while fine-tuning.
The only non-linear operation in a LN layer is the division by the standard deviation (std) of the embedding vectors; the remaining operations can be absorbed into later weight matrices (see the fold_ln option in TransformerLens; also discussed in this appendix). Thus I mainly focus on the std part here.
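As a rough illustration of what "absorbed into later weight matrices" means, the sketch below folds LN's learned scale and bias into the linear layer that follows it. This is not the TransformerLens implementation; the helper fold_ln_params and all numbers are hypothetical, and it only covers the scale and bias (the mean subtraction is also linear, but is left out here for brevity).

```python
def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def fold_ln_params(W, c, gamma, beta):
    """Return (W_f, c_f) with W_f @ n + c_f == W @ (gamma * n + beta) + c.

    Hypothetical helper: W_f scales each column of W by gamma, and c_f
    picks up the contribution of beta passed through W.
    """
    W_f = [[w * g for w, g in zip(row, gamma)] for row in W]
    c_f = [ci + sum(w * b for w, b in zip(row, beta)) for row, ci in zip(W, c)]
    return W_f, c_f

# Toy values, chosen arbitrarily. n stands for an already-normalized vector,
# i.e. (x - mean) / std.
n = [-1.2, 0.3, 0.9]
gamma = [0.5, 2.0, 1.5]
beta = [0.1, -0.2, 0.0]
W = [[1.0, 0.0, 2.0], [-1.0, 3.0, 0.5]]
c = [0.25, -0.75]

# Original path: apply LN's scale and bias, then the next linear layer.
ln_out = [g * v + b for g, v, b in zip(gamma, n, beta)]
original = [y + ci for y, ci in zip(matvec(W, ln_out), c)]

# Folded path: the linear layer alone reproduces the same output.
W_f, c_f = fold_ln_params(W, c, gamma, beta)
folded = [y + ci for y, ci in zip(matvec(W_f, n), c_f)]

fold_gap = max(abs(p - q) for p, q in zip(original, folded))
print("max difference:", fold_gap)
```

After folding, the only part of LN that cannot be moved into a weight matrix is the division by the std.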
My general strategy is to "remove" an LN layer (this makes the loss go up), and then to train the model for some time (on the original training data) until the loss is back near the baseline. For this "remove" step I do the following:
Calculate the average std on the dataset (I used a quite small sample, 16 prompts), separately for position 0 and position > 0
Replace the std calculatio...
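The first step above can be sketched in plain Python as follows. Everything here is an assumption for illustration: the helper names, the nested-list layout activations[prompt][position], and the toy numbers (which are not measurements from GPT2). The fixed_std_norm function follows the idea from the Motivation section of dividing by a fixed number instead of the per-vector std; since the step description is cut off in this copy, it may differ from what was actually done.

```python
import math

def vector_std(x, eps=1e-5):
    """Std of one embedding vector over the hidden dimension, as LN computes it."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return math.sqrt(var + eps)

def average_stds(activations):
    """Average std separately for position 0 and for positions > 0.

    activations[prompt][position] is one embedding vector (hypothetical layout).
    """
    first = [vector_std(prompt[0]) for prompt in activations]
    rest = [vector_std(vec) for prompt in activations for vec in prompt[1:]]
    return sum(first) / len(first), sum(rest) / len(rest)

def fixed_std_norm(x, fixed_std):
    """Like LN without scale and bias, but dividing by a constant: a linear map."""
    mean = sum(x) / len(x)
    return [(v - mean) / fixed_std for v in x]

# Two toy "prompts" with three positions each; values are made up.
toy = [
    [[30.0, -30.0, 10.0, -10.0], [1.0, -1.0, 2.0, -2.0], [0.5, 1.5, -1.0, -1.0]],
    [[28.0, -32.0, 12.0, -8.0], [2.0, 0.0, -1.0, -1.0], [1.0, 1.0, -1.0, -1.0]],
]

std_pos0, std_rest = average_stds(toy)
normed = fixed_std_norm(toy[0][1], std_rest)

print("average std at position 0:", std_pos0)
print("average std at positions > 0:", std_rest)
```

Keeping two separate averages only matters if the stds at position 0 and at later positions differ noticeably, which is how the toy data was constructed.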

The Nonlinear Library: Alignment Forum, by The Nonlinear Fund

