Welcome to The Nonlinear Library, where we use Text-to-Speech software to convert the best writing from the Rationalist and EA communities into audio. This is: Improving SAE's by Sqrt()-ing L1 & Removing Lowest Activating Features, published by Logan Riggs Smith on March 15, 2024 on The AI Alignment Forum.
TL;DR
We achieve better SAE performance by:
Removing the lowest activating features
Replacing the L1(feature_activations) penalty function with L1(sqrt(feature_activations))
with 'better' meaning: we can reconstruct the original LLM activations w/ lower MSE & with fewer features/datapoint.
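As a rough sketch of the penalty change (assuming a standard SAE loss of reconstruction MSE plus an L1 sparsity term; the function below is illustrative, not the exact training code):

```python
import torch

def sae_loss(x, x_hat, f, l1_coeff, use_sqrt=False):
    """Sketch of a standard SAE loss: reconstruction MSE + L1 sparsity penalty.

    With use_sqrt=True the penalty becomes L1(sqrt(f)): since sqrt grows
    steeply near 0, small activations are penalized relatively more and
    large activations relatively less.
    """
    mse = (x - x_hat).pow(2).mean()
    acts = f.sqrt() if use_sqrt else f  # f is non-negative (post-ReLU)
    penalty = acts.sum(dim=-1).mean()
    return mse + l1_coeff * penalty
```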
As a sneak peek (the graph should make more sense as we build up to it, don't worry!):
Now in more detail:
Sparse Autoencoders (SAEs) reconstruct each datapoint in [layer 3's residual stream activations of Pythia-70M-deduped] using a certain number of features (this is the L0 norm of the SAE's hidden activations). Typically the higher activations are interpretable & the lowest activations are not.
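For concreteness, here's a minimal sketch of that setup (a standard one-hidden-layer SAE with a ReLU; dimensions & names are illustrative):

```python
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """Minimal SAE: reconstructs LLM activations from a sparse hidden code."""
    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.encoder = nn.Linear(d_model, d_hidden)
        self.decoder = nn.Linear(d_hidden, d_model)

    def forward(self, x):
        f = torch.relu(self.encoder(x))  # feature activations (non-negative, mostly 0)
        x_hat = self.decoder(f)          # reconstruction of the residual stream
        return x_hat, f

# L0 norm = number of features active per datapoint
sae = SparseAutoencoder(d_model=512, d_hidden=2048)
x = torch.randn(8, 512)                  # stand-in batch of residual-stream activations
x_hat, f = sae(x)
l0_per_datapoint = (f > 0).sum(dim=-1)
```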
Here is a feature that activates mostly on apostrophes (ablating it also makes the model worse at predicting "s"). The lower activations are conceptually similar, but then there's a huge number of tokens that are something else.
From a datapoint viewpoint, there's a similar story: given a specific datapoint, the top-activating features make a lot of sense, but the lowest ones don't (ie if 20 features activate to reconstruct a specific datapoint, the top ~5 features make a decent amount of sense & the lower 15 make less and less sense).
Are these low-activating features actually important for downstream performance (eg CE)? Or are they modeling noise in the underlying LLM (which would explain the conceptually similar datapoints at lower activation levels)?
Ablating Lowest Features
There are a few different ways to remove the "lowest" feature activations (illustrated in the sketch after this list).
Dataset View:
Lowest-k features per datapoint
Feature View: Features have different activation values. Some are an OOM larger than others on average, so we can set feature-specific thresholds.
Percentage of max activation - Remove all feature activations that are < [10%] of the max activation for that feature
Quantile - Remove all feature activations below the [10th] percentile of activations for each feature
Global Threshold - Let's treat all features the same. Set all feature activations less than [0.1] to 0.
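Here's a sketch of the four variants (the thresholds, percentages, & k are the bracketed placeholder values above; `f` is a [datapoints, features] matrix of activations):

```python
import torch

def ablate_low_activations(f: torch.Tensor, method: str = "global", **kw) -> torch.Tensor:
    """Zero out the 'lowest' feature activations. f: [n_datapoints, n_features].

    Illustrative implementations of the four variants described above;
    the default thresholds/quantiles/k are hypothetical.
    """
    if method == "global":            # activations below a fixed value -> 0
        thresh = kw.get("thresh", 0.1)
        return torch.where(f < thresh, torch.zeros_like(f), f)
    if method == "percent_of_max":    # below [10%] of each feature's max -> 0
        frac = kw.get("frac", 0.10)
        per_feature = frac * f.max(dim=0, keepdim=True).values
        return torch.where(f < per_feature, torch.zeros_like(f), f)
    if method == "quantile":          # below each feature's [10th] percentile -> 0
        q = kw.get("q", 0.10)
        per_feature = torch.quantile(f, q, dim=0, keepdim=True)
        return torch.where(f < per_feature, torch.zeros_like(f), f)
    if method == "lowest_k":          # drop the k smallest *active* features per datapoint
        k = kw.get("k", 5)
        masked = torch.where(f > 0, f, torch.full_like(f, float("inf")))
        idx = masked.topk(k, dim=-1, largest=False).indices
        out = f.clone()
        out.scatter_(-1, idx, 0.0)
        return out
    raise ValueError(f"unknown method: {method}")
```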
It turns out that the simple global threshold performs the best:
[Note: "CE" refers to the CE when you replace [layer 3 residual stream]'s activations with the reconstruction from the SAE. Ultimately we want the original model's CE with the smallest amount of feature's per datapoint (L0 norm).]
You can halve the L0 w/ a small (~0.08) increase in CE. Sadly, there is an increase in both MSE & CE. If MSE were higher & CE stayed the same, that would support the hypothesis that the SAE is modeling noise at lower activations (ie noise that's important for MSE/reconstruction but not for CE/downstream performance). But these lower activations turn out to matter for both MSE & CE similarly.
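(As a sketch of how the CE metric above can be computed: patch the SAE reconstruction into the forward pass with a hook. This assumes a HuggingFace-style model; the exact module to hook & its output structure vary by implementation.)

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def ce_with_sae_patch(model, sae, tokens, layer_module):
    """Next-token CE when layer_module's output is replaced by its SAE reconstruction."""
    def patch_hook(module, inputs, output):
        # Some transformer blocks return tuples (hidden_states, ...);
        # patch only the hidden states.
        if isinstance(output, tuple):
            x_hat, _ = sae(output[0])
            return (x_hat,) + output[1:]
        x_hat, _ = sae(output)
        return x_hat

    handle = layer_module.register_forward_hook(patch_hook)
    try:
        logits = model(tokens).logits
    finally:
        handle.remove()

    # next-token cross-entropy
    return F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        tokens[:, 1:].reshape(-1),
    )
```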
For completeness' sake, here's a messy graph w/ all 4 methods:
[Note: this was run on a different SAE than the other images]
There may be more sophisticated methods that take feature-level information into account (such as whether it's an outlier feature, or its activation frequency), but we'll stick w/ the global threshold for the rest of the post.
Sweeping Across SAEs with Different L0s
You can get wildly different L0s just by sweeping the weight on the L1 penalty term: a lower penalty gives a higher L0, which improves reconstruction at the cost of more, potentially polysemantic, features per datapoint. Does the above phenomenon extend to SAEs w/ different L0s?
Looks like it does, & the models seem to follow a Pareto frontier.
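(The sweep itself is just a loop over L1 coefficients; `train_sae` & `evaluate` below are hypothetical helpers standing in for the actual training & the threshold-then-measure-CE evaluation.)

```python
# Hypothetical sweep: each L1 coefficient yields an SAE with a different L0.
# Combined with the global threshold, each SAE traces out points on an
# (L0, CE) curve; together they appear to share a Pareto frontier.
results = []
for l1_coeff in [1e-4, 3e-4, 1e-3, 3e-3]:             # illustrative values
    sae = train_sae(activations, l1_coeff=l1_coeff)   # hypothetical trainer
    for thresh in [0.0, 0.05, 0.1, 0.2]:              # 0.0 = no ablation
        l0, ce = evaluate(sae, thresh=thresh)         # hypothetical: mean L0 & patched CE
        results.append((l1_coeff, thresh, l0, ce))
```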
Using L1(sqrt(feature_activations))
@Lucia Quirke trained SAEs with L1(sqrt(feature_activations)) (this punishes smaller activations more & larger activations less) and anecdotally noticed fewer of these smaller, uninterpreta...