Audio note: this article contains 48 uses of LaTeX notation, so the narration may be difficult to follow. There's a link to the original text in the episode description.
I've been researching tensor networks as a more interpretable architecture, but whenever I tell people this, they always ask "But is it any good?"
So I trained multiple 500M-parameter LLMs on FineWeb and found that the tensor variant needed ~4% more batches of data to reach the same cross-entropy (CE) loss.
There are a few caveats, so my personal estimate is somewhere between 15% worse and 10% better. Details below.
The Architecture
Replacing MLP w/ a Bilinear Layer
An MLP is a linear encoder, ReLU, then linear decoder.
_MLP(x) = D(ReLU(E(x)))_
A bilinear layer asks "what's better than one encoder? Two!"
_Bilinear(x) = D(Lx ⊙ Rx)_
Where _⊙_ means "element-wise multiply", e.g.
_[1, 2, 3] ⊙ [1, 2, 3] = [1, 4, 9]_
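To make the formula concrete, here is a minimal pure-Python sketch of a bilinear layer. The weight matrices are toy values made up for illustration (they are not from the trained models), and the helper names `matvec` and `hadamard` are my own. The last lines check one consequence of having no ReLU: the layer is a degree-2 polynomial in x, so doubling the input multiplies the output by exactly 4.

```python
def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def hadamard(a, b):
    """Element-wise multiply of two equal-length vectors."""
    return [ai * bi for ai, bi in zip(a, b)]


def bilinear(x, L, R, D):
    """Two linear encoders, multiplied element-wise, then a linear decoder."""
    return matvec(D, hadamard(matvec(L, x), matvec(R, x)))


# Toy weights (assumed for illustration): 2-dim input, 3 hidden units, 2-dim output.
L = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
R = [[1.0, 1.0], [1.0, -1.0], [0.0, 2.0]]
D = [[1.0, 0.0, 1.0], [0.0, 1.0, -1.0]]

x = [1.0, 2.0]
y = bilinear(x, L, R, D)

# Doubling the input: a quadratic map should scale its output by 2**2 = 4.
y_scaled = bilinear([2.0 * xi for xi in x], L, R, D)

print(hadamard([1, 2, 3], [1, 2, 3]))  # the element-wise example from the text
print(y, y_scaled)
```

The scaling check would fail for an MLP with ReLU, which is only piecewise linear; that polynomial structure is what lets the bilinear layer be treated as a tensor network.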
A SwiGLU Layer (Swish Gated Linear Unit) says "Let's add in nonlinearities"
_SwiGLU(x) = D(swish(Lx) ⊙ Rx)_
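For comparison, a minimal pure-Python sketch of SwiGLU under the same toy setup. I'm assuming the standard definition swish(z) = z · sigmoid(z); the weights and the `matvec` helper are illustrative, not taken from the trained models. The only change from the bilinear layer is the swish on the left encoder, and that is enough to break the clean quadratic scaling.

```python
import math


def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def swish(z):
    """Swish / SiLU activation: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))


def swiglu(x, L, R, D):
    """Like the bilinear layer, but with swish applied to the left encoder."""
    gate = [swish(z) for z in matvec(L, x)]
    value = matvec(R, x)
    return matvec(D, [g * v for g, v in zip(gate, value)])


# Toy weights (assumed for illustration): 2-dim input, 3 hidden units, 2-dim output.
L = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
R = [[1.0, 1.0], [1.0, -1.0], [0.0, 2.0]]
D = [[1.0, 0.0, 1.0], [0.0, 1.0, -1.0]]

x = [1.0, 2.0]
y = swiglu(x, L, R, D)

# Unlike the bilinear layer, doubling the input does not give exactly 4x the output.
y_scaled = swiglu([2.0 * xi for xi in x], L, R, D)
is_quadratic = all(abs(a - 4.0 * b) < 1e-9 for a, b in zip(y_scaled, y))

print(y, y_scaled, is_quadratic)
```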
SwiGLU is a SOTA architecture & Bilinear is a tensor network.
Replacing Softmax Attn w/ Bilinear Attn
For a tensor network, we are only allowed polynomial nonlinearities. For attention, this means we need to replace softmax w/ [...]
---
Outline:
(00:48) The Architecture
(00:51) Replacing MLP w/ a Bilinear Layer
(01:45) Replacing Softmax Attn w/ Bilinear Attn
(02:24) Experiment & Results
(03:48) Caveats:
(03:52) (1) Softmax attention ran faster cause it has a CUDA kernel
(04:19) (2) Bilinear Attention can run much faster than Softmax Attn
(05:20) (3) Bilinear Attention has more Parameters
(05:52) (4) This was the 2nd-Dumbest Tensor-Attn Variant
(06:19) Replication & Trained Models
(06:31) Future Work
(06:56) Path to Impact
(07:59) Interp w/ Tensor Networks
(10:29) Appendix A: Noam Shazeer's 2020 paper:
(11:03) Appendix B: Scaling of Bilinear Attention
(13:46) Appendix C: Bilinear Attention Expressivity
(14:22) Appendix D: But what about Flash Attention?
The original text contained 2 footnotes which were omitted from this narration.
---