Welcome to The Nonlinear Library, where we use Text-to-Speech software to convert the best writing from the Rationalist and EA communities into audio. This is: Finding Sparse Linear Connections between Features in LLMs, published by Logan Riggs Smith on December 9, 2023 on The AI Alignment Forum.
TL;DR: We use SGD to find sparse connections between features; additionally, a large fraction of features between the residual stream & MLP can be modeled as linearly computed despite the non-linearity in the MLP. See the linear-features section for examples.
Special thanks to fellow AISST member Adam Kaufman, who originally thought of the idea of learning sparse connections between features, & to Jannik Brinkmann for training these SAEs.
Sparse autoencoders (SAEs) are able to turn the activations of an LLM into interpretable features. To define circuits, we would like to find how these features connect to each other. SAEs allowed us to scalably find interpretable features using SGD, so why not use SGD to find the connections too?
We have a set of features before the MLP, F1, and a set of features after the MLP, F2. These features were learned by training SAEs on the activations at these layers.
Ideally, we learn a linear function such that F2 = W(F1), & W is sparse (i.e. an L1 penalty on the weights of W). Then we could look at a feature in F2 and say "Oh, it's just a sparse linear combination of features of F1, e.g. 0.8*(however feature) + 0.6*(but feature)", which would be quite interpretable!
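As a minimal sketch of what "reading off" such a feature could look like, assuming W is a learned [n_F2, n_F1] weight matrix (the matrix, index names, and sizes here are placeholders, not from the post):

```python
import torch

# Hypothetical learned connection matrix W: rows index F2 features, columns index F1 features.
W = torch.randn(1000, 1000)  # stand-in for trained sparse weights

def top_connections(W: torch.Tensor, f2_idx: int, k: int = 5):
    """Return the k largest-magnitude F1 features feeding into one F2 feature."""
    row = W[f2_idx]
    _, idxs = row.abs().topk(k)
    return [(int(i), float(row[i])) for i in idxs]

# e.g. top_connections(W, f2_idx=42) could surface something like the
# 0.8*(however feature) + 0.6*(but feature) decomposition above.
```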
However, we're trying to replicate an MLP's computation, which surely can't be all linear![1] So, what's the simplest computation from F1 to F2 that gets the lowest loss (ignoring L1 weight sparsity penalty for now)?
Training with only an MSE loss between the predicted and true F2, we plot the MSE throughout training across 5 layers in Pythia-70m-deduped in 4 settings (a minimal sketch of these setups follows the list):
Linear: y = Wx
Nonlinear: y = ReLU(Wx)
MLP: y = W2 ReLU(W1 x)
Two Nonlinear: y = ReLU(W2 ReLU(W1 x))
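Here is a rough sketch of how these four candidates could be set up and fit with gradient descent on MSE alone (the dimensions, optimizer, and training details below are assumptions for illustration, not the post's exact code):

```python
import torch
import torch.nn as nn

d_f1 = d_f2 = d_hidden = 2048  # assumed SAE dictionary sizes

candidates = {
    "linear":        nn.Linear(d_f1, d_f2, bias=False),
    "nonlinear":     nn.Sequential(nn.Linear(d_f1, d_f2, bias=False), nn.ReLU()),
    "mlp":           nn.Sequential(nn.Linear(d_f1, d_hidden, bias=False), nn.ReLU(),
                                   nn.Linear(d_hidden, d_f2, bias=False)),
    "two_nonlinear": nn.Sequential(nn.Linear(d_f1, d_hidden, bias=False), nn.ReLU(),
                                   nn.Linear(d_hidden, d_f2, bias=False), nn.ReLU()),
}

def fit(model, f1_acts, f2_acts, steps=1000, lr=1e-3):
    """Fit F2 = model(F1) with MSE only (no sparsity penalty yet)."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(steps):
        loss = nn.functional.mse_loss(model(f1_acts), f2_acts)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.item()
```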
For all layers, training loss clusters into two groups: (MLP & two nonlinear) and (linear & nonlinear). Since MLP & linear are the simplest members of these two clusters, the rest of the analysis will only look at those two.
[I also looked at bias vs no bias: adding a bias didn't meaningfully improve the loss, so it was excluded.]
Interestingly enough, the relative linear-MLP difference is huge in the last layer (and layer 2). The last layer also has a much larger loss in general, though the L2 norm of the MLP activations in layer 5 is ~52 compared to ~13 in layer 4. That's a 4x increase in norm, which would correspond to a 16x increase in MSE loss, and indeed the losses for the last datapoints (0.059 & 0.0038) differ by roughly 16x.
What Percentage of Features are Linear?
Clearly the MLP is better, but that's on average. What if a percentage of features can be modeled as linearly computed? So for each feature we take the difference in loss (i.e. linear loss - MLP loss), normalize all losses by their layer's respective L2 norm, and plot them.
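A rough sketch of that per-feature comparison (the function and the exact normalization are my reading of "normalize by the respective L2-norm/layer", not code from the post):

```python
import torch

def per_feature_loss_diff(linear_pred, mlp_pred, f2_acts, layer_norm):
    """Per-feature (linear MSE - MLP MSE), normalized by the layer's activation norm."""
    linear_mse = ((linear_pred - f2_acts) ** 2).mean(dim=0)  # one MSE per F2 feature
    mlp_mse = ((mlp_pred - f2_acts) ** 2).mean(dim=0)
    return (linear_mse - mlp_mse) / layer_norm  # large positive values => feature looks non-linear
```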
Uhhh… there are some huge outliers here, meaning these specific features are very non-linear. Just setting a threshold of 0.001 for all layers:
| Layer | Percent of features with loss-difference < 0.001 (i.e. can be represented linearly) |
|---|---|
| 1 | 78% |
| 2 | 96% |
| 3 | 97% |
| 4 | 98% |
| 5 | 99.1% |
Most of the features can be linearly modeled w/ a small difference in loss (some even have a negative loss-diff, meaning linear had a *lower* loss than the MLP; the values are so small that I'd chalk that up to noise). That's very convenient!
[Note: 0.001 is sort of arbitrary. To make this more principled, we could plot the effect of adding varying levels of noise to a layer of an LLM's activations, then pick a threshold that causes a negligible drop in cross-entropy loss?]
Adding in Sparsity
Now, let's train sparse MLP & sparse linear connections. Additionally, we can restrict the linear one to only the features that are well-modeled as linear (and similarly for the MLP). We'll use the loss:
Loss = MSE(F2, F2_hat) + l1_alpha * L1(weights)
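A minimal sketch of this loss for the linear case (the function name and `connection` module are placeholders; the post doesn't spell out the exact implementation):

```python
import torch.nn as nn
import torch.nn.functional as F

def sparse_connection_loss(connection: nn.Linear, f1_acts, f2_acts, l1_alpha: float):
    """MSE reconstruction of F2 from F1, plus an L1 penalty on the connection weights."""
    f2_hat = connection(f1_acts)
    mse = F.mse_loss(f2_hat, f2_acts)
    l1 = connection.weight.abs().sum()
    return mse + l1_alpha * l1
```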
But how do we select l1_alpha? Let's just plot the ...