Welcome to The Nonlinear Library, where we use Text-to-Speech software to convert the best writing from the Rationalist and EA communities into audio. This is: resolving some neural network mysteries, published by bhauth on June 19, 2023 on LessWrong.
Here are some things about neural networks that I used to find puzzling but now feel that I have adequate explanations for. The theory behind these answers didn't start to be understood until well after the correct things to do were found by chance or blind imitation of brains.
Why is good optimization possible?
Neural networks typically deal with "non-convex" optimization problems. Traditionally, using gradient descent for those was considered impractical, because it would rapidly get stuck in local minima. That was part of the motivation for evolutionary approaches.
Why, then, are neural networks trainable by gradient descent? Because if you add enough extra dimensions, non-convex problems become convex. Empirically, with massive overparameterization, the energy landscape tends to have many saddle points but few local minima. Showing theoretical convergence guarantees for overparameterized networks is a recent and ongoing research topic; see e.g. this.
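One rough way to build intuition for "many saddle points but few local minima" is a toy random-matrix sketch. It assumes, purely for illustration, that the Hessian at a critical point looks like a random symmetric Gaussian matrix, which real loss landscapes need not satisfy. Under that assumption, a critical point is a local minimum only if every eigenvalue is positive, and the chance of that falls off quickly as the dimension grows. The function name and the dimensions tried here are made up for this example.

```python
import numpy as np

def fraction_all_positive(dim, trials=2000, seed=0):
    """Fraction of random symmetric matrices whose eigenvalues are all positive."""
    rng = np.random.default_rng(seed)
    count = 0
    for _ in range(trials):
        g = rng.standard_normal((dim, dim))
        hessian = (g + g.T) / 2.0  # assumption: random symmetric stand-in for a Hessian
        # eigvalsh returns eigenvalues in ascending order, so [0] is the smallest
        if np.linalg.eigvalsh(hessian)[0] > 0:
            count += 1
    return count / trials

# Compare a few small dimensions; real networks have vastly more parameters.
fractions = {dim: fraction_all_positive(dim) for dim in (1, 2, 4, 6)}
for dim, frac in fractions.items():
    print(f"dim={dim}: fraction of critical points that look like minima = {frac:.4f}")
```

In this toy model, about half of the one-dimensional "critical points" are minima, while almost none of the six-dimensional ones are; the rest have at least one downhill direction, so gradient descent has a way out.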
As I previously noted, this is why sparse networks from iterative magnitude pruning have good performance, but sparse networks generally can't be trained from scratch as well as dense networks.
This also explains some "thresholds" of neural network performance vs size: when overparameterization is proportional to problem non-convexity, good training becomes possible and performance improves significantly.
Why is generalization possible?
Adding enough free variables can turn non-convex problems into convex ones. Why didn't people just do that in the past, then? Because long before you got to that point, the extra free variables led to overfitting that reduced test performance. People tried the kind of simple regularization that neural networks use, and it was completely inadequate.
Overparameterized neural networks can learn random data. Why do networks with fairly simple regularization tend to generalize?
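The claim that overparameterized networks can learn random data can be sketched with a deliberately simplified stand-in: a one-hidden-layer model whose hidden weights are fixed at random and whose output weights are fit by least squares. That is a random-features model, not a fully trained network, and the sizes and names below are arbitrary choices for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_inputs = 20, 5
x = rng.standard_normal((n_samples, n_inputs))
y = rng.choice([-1.0, 1.0], size=n_samples)  # labels are pure noise

def train_mse(width):
    """Fit only the output layer on top of `width` fixed random ReLU units."""
    w = rng.standard_normal((n_inputs, width))
    b = rng.standard_normal(width)
    hidden = np.maximum(x @ w + b, 0.0)
    # Least-squares fit of the output weights to the random labels
    out_weights, *_ = np.linalg.lstsq(hidden, y, rcond=None)
    return float(np.mean((hidden @ out_weights - y) ** 2))

narrow_mse = train_mse(width=5)    # fewer hidden units than samples
wide_mse = train_mse(width=200)    # far more hidden units than samples
print(f"narrow training MSE: {narrow_mse:.3g}")
print(f"wide training MSE:   {wide_mse:.3g}")
```

With more hidden units than training points, the wide model can match the noise labels to within numerical precision, while the narrow one cannot. Fitting the training set is therefore not evidence of generalization by itself, which is what makes the question above interesting.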
Distance in neural network latent spaces being meaningful is basically the main useful thing about neural networks. Another phrasing of the above question is: Why is distance in latent spaces meaningful for latent space points not in the training set?
A few years back, some people noticed that neural network activation functions have spectral bias. With the types of activation functions used, low-frequency relationships are fit more quickly than high-frequency ones. That causes latent space relationships to be preferentially fit in such a way that point distances are related to point similarity. This can then be tuned by simple regularization: if you have spectral bias and balance learning rate vs regularization globally, you can control the frequency range learned.
An obvious way to test this theory of neural network generalization is to find some activation functions with relatively low spectral bias, and see how they perform. This paper tries a "hat" activation function, and finds that loss on the training set goes down much faster but test accuracy is much worse. This paper does some relevant tests on spectral bias.
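For readers unfamiliar with the term, here is a minimal sketch of what a "hat" activation could look like. The exact definition is an assumption: the text above does not give one, and the paper it refers to may use a different scaling or support. This version is the triangular bump built from three shifted ReLUs.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def hat(x):
    """Triangular bump: zero outside [0, 2], rising linearly to 1 at x = 1.

    Assumed definition for illustration only.
    """
    return relu(x) - 2.0 * relu(x - 1.0) + relu(x - 2.0)

xs = np.array([-1.0, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0])
hat_values = hat(xs)
print(hat_values)
```

Unlike ReLU, this function is bounded, non-monotonic, and zero outside a finite interval.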
It's known that there is no universal best activation function. The optimal choice varies with:
different problem types
regularization settings
layer depth
I think using different activation functions for different depths is a semi-common technique at large AI labs now. This spectral bias framework can explain variations in relative performance of activation functions as spectral bias matching.
Why not mixed activation functions?
There are some reasons to think mixing activation functions in the same layer would be better:
Neural networks have many equivalent permutations of their variables. By mixing different activation functions in the same layer, fewer permutations would be equivalent, which increases expressive power....
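The permutation point can be checked numerically with a small sketch. The network shape, weights, and the particular mix of ReLU and tanh are arbitrary choices for this example. Swapping two hidden units, together with their incoming and outgoing weights, leaves the output unchanged when every unit shares one activation, but changes it when the two slots use different activations.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 3))    # 8 inputs with 3 features each
w1 = rng.standard_normal((3, 4))   # input -> 4 hidden units
w2 = rng.standard_normal((4, 1))   # hidden -> output

def relu(z):
    return np.maximum(z, 0.0)

def forward(w1, w2, activations):
    """One hidden layer; hidden slot j always applies activations[j]."""
    pre = x @ w1
    hidden = np.stack([act(pre[:, j]) for j, act in enumerate(activations)], axis=1)
    return hidden @ w2

# Swap hidden units 0 and 1, carrying their weights along with them.
perm = [1, 0, 2, 3]
w1_p, w2_p = w1[:, perm], w2[perm, :]

same = [relu, relu, relu, relu]
mixed = [relu, np.tanh, relu, np.tanh]

same_gap = np.max(np.abs(forward(w1, w2, same) - forward(w1_p, w2_p, same)))
mixed_gap = np.max(np.abs(forward(w1, w2, mixed) - forward(w1_p, w2_p, mixed)))
print(f"same activations:  max output change = {same_gap:.3g}")
print(f"mixed activations: max output change = {mixed_gap:.3g}")
```

Units that share an activation are still interchangeable in the mixed layer, so the symmetry is reduced, not removed.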