Welcome to The Nonlinear Library, where we use Text-to-Speech software to convert the best writing from the Rationalist and EA communities into audio. This is: Explaining grokking through circuit efficiency, published by Vikrant Varma on September 8, 2023 on The AI Alignment Forum.
This is a linkpost for our paper "Explaining grokking through circuit efficiency", which provides a general theory explaining when and why grokking (aka delayed generalisation) occurs, and makes several interesting and novel predictions which we experimentally confirm (introduction copied below). You might also enjoy our explainer on X/Twitter.
Abstract
One of the most surprising puzzles in neural network generalisation is grokking: a network with perfect training accuracy but poor generalisation will, upon further training, transition to perfect generalisation. We propose that grokking occurs when the task admits a generalising solution and a memorising solution, where the generalising solution is slower to learn but more efficient, producing larger logits with the same parameter norm. We hypothesise that memorising circuits become more inefficient with larger training datasets while generalising circuits do not, suggesting there is a critical dataset size at which memorisation and generalisation are equally efficient. We make and confirm four novel predictions about grokking, providing significant evidence in favour of our explanation.
Most strikingly, we demonstrate two novel and surprising behaviours: ungrokking, in which a network regresses from perfect to low test accuracy, and semi-grokking, in which a network shows delayed generalisation to partial rather than perfect test accuracy.
Introduction
When training a neural network, we expect that once training loss converges to a low value, the network will no longer change much. Power et al. (2021) discovered a phenomenon dubbed grokking that drastically violates this expectation. The network first "memorises" the data, achieving low and stable training loss with poor generalisation, but with further training transitions to perfect generalisation. We are left with the question: why does the network's test performance improve dramatically upon continued training, having already achieved nearly perfect training performance?
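Not the authors' code, but a minimal sketch of the kind of setup in which grokking is commonly reported: a small network trained with weight decay on modular addition, where only a fraction of the input pairs are used for training. All architecture choices and hyperparameters below are illustrative assumptions, not values from the paper.

```python
import torch
import torch.nn as nn

P = 113                      # modulus for the modular-addition task
FRACTION_TRAIN = 0.3         # fraction of (a, b) pairs used for training

# Full dataset of pairs (a, b) -> (a + b) mod P, one-hot encoded.
pairs = torch.cartesian_prod(torch.arange(P), torch.arange(P))
labels = (pairs[:, 0] + pairs[:, 1]) % P
inputs = torch.cat([nn.functional.one_hot(pairs[:, 0], P),
                    nn.functional.one_hot(pairs[:, 1], P)], dim=1).float()

perm = torch.randperm(len(pairs))
n_train = int(FRACTION_TRAIN * len(pairs))
train_idx, test_idx = perm[:n_train], perm[n_train:]

model = nn.Sequential(nn.Linear(2 * P, 256), nn.ReLU(), nn.Linear(256, P))
# Weight decay matters here: it is what eventually penalises the
# high-norm memorising solution relative to the generalising one.
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
loss_fn = nn.CrossEntropyLoss()

for step in range(50_000):
    opt.zero_grad()
    loss = loss_fn(model(inputs[train_idx]), labels[train_idx])
    loss.backward()
    opt.step()
    if step % 1000 == 0:
        with torch.no_grad():
            test_acc = (model(inputs[test_idx]).argmax(-1)
                        == labels[test_idx]).float().mean()
        print(f"step {step}: train loss {loss.item():.4f}, test acc {test_acc:.3f}")
```

In runs that grok, train loss falls early while test accuracy stays near chance, and only much later does test accuracy climb.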
Recent answers to this question vary widely, including the difficulty of representation learning (Liu et al., 2022), the scale of parameters at initialisation (Liu et al., 2023), spikes in loss ("slingshots") (Thilak et al., 2022), random walks among optimal solutions (Millidge et al., 2022), and the simplicity of the generalising solution (Nanda et al., 2023, Appendix E). In this paper, we argue that the last explanation is correct by stating a specific theory in this genre, deriving novel predictions from the theory, and confirming the predictions empirically.
We analyse the interplay between the internal mechanisms that the neural network uses to calculate the outputs, which we loosely call "circuits" (Olah et al., 2020). We hypothesise that there are two families of circuits that both achieve good training performance: one which generalises well (Cgen) and one which memorises the training dataset (Cmem). The key insight is that when there are multiple circuits that achieve strong training performance, weight decay prefers circuits with high "efficiency", that is, circuits that require less parameter norm to produce a given logit value.
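To make "efficiency" concrete, here is a toy illustration with made-up numbers (not from the paper): two weight vectors produce exactly the same logit for an input, but the one that spreads the computation across more weights has a smaller L2 norm, so it pays a smaller weight-decay penalty for the same contribution to the output.

```python
import torch

x = torch.ones(4)                                    # a fixed input

w_efficient = torch.full((4,), 0.5)                  # work spread across weights
w_inefficient = torch.tensor([2.0, 0.0, 0.0, 0.0])   # one large weight

print(w_efficient @ x, w_inefficient @ x)            # same logit: 2.0 and 2.0
print(w_efficient.norm(), w_inefficient.norm())      # norms: 1.0 vs 2.0
# Weight decay adds a penalty proportional to ||w||^2, so the efficient
# circuit produces the same logit at a quarter of the penalty here.
```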
Efficiency answers our question above: if Cgen is more efficient than Cmem, gradient descent can push the already near-zero training loss even lower by strengthening Cgen while weakening Cmem, which then leads to a transition in test performance. With this understanding, we demonstrate in Section 3 that three key properties are sufficient for grokking: (1) Cgen generalises well while Cmem does not, (2) Cgen is more efficient than Cmem, and (3) Cgen is learned more slowly than Cmem.
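A toy calculation of that mechanism, again with illustrative numbers of my own rather than anything from the paper: hold the correct-class logit fixed and compare the cost of producing it with circuits of different efficiency. The cross-entropy term is identical in every case, but the weight-decay term shrinks as the more efficient circuit takes over, so total loss keeps falling even after training accuracy is perfect.

```python
import math

TARGET_LOGIT = 10.0          # logit on the correct class; other classes at 0
NUM_CLASSES = 113
WEIGHT_DECAY = 1e-3

EFF_GEN = 2.0                # logit produced per unit of parameter norm (Cgen)
EFF_MEM = 1.0                # logit produced per unit of parameter norm (Cmem)

def total_loss(frac_from_gen: float) -> float:
    """Cross-entropy plus weight-decay penalty when a fraction of the
    target logit comes from Cgen and the rest from Cmem."""
    norm_gen = (frac_from_gen * TARGET_LOGIT) / EFF_GEN
    norm_mem = ((1 - frac_from_gen) * TARGET_LOGIT) / EFF_MEM
    cross_entropy = -math.log(math.exp(TARGET_LOGIT) /
                              (math.exp(TARGET_LOGIT) + (NUM_CLASSES - 1)))
    penalty = WEIGHT_DECAY * (norm_gen ** 2 + norm_mem ** 2)
    return cross_entropy + penalty

for frac in (0.0, 0.5, 1.0):
    print(f"fraction of logit from Cgen = {frac:.1f}: "
          f"total loss = {total_loss(frac):.5f}")
# The cross-entropy term is the same in all three cases (same logits), but
# the weight-decay term shrinks as the efficient circuit takes over.
```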
Since Cgen generalises well, it automatically works ...