Learning GenAI via SOTA Papers

EP041: FlashAttention Smashes the AI Memory Wall

The paper introduces FlashAttention, a new algorithm designed to address the slow, memory-intensive behavior of Transformer models on long sequences. Standard self-attention has time and memory costs that scale quadratically with sequence length, largely due to the overhead of reading and writing the large intermediate attention matrix to the GPU's relatively slow High Bandwidth Memory (HBM). While prior approximate attention methods reduced compute requirements, they often failed to deliver actual wall-clock speedups because they ignored these memory-access (IO) overheads.

To solve this, FlashAttention implements an IO-aware exact attention algorithm that drastically minimizes HBM accesses using two key techniques:

  • Tiling: The algorithm splits the inputs (Queries, Keys, and Values) into blocks and loads them from slow HBM to the much faster, but smaller, on-chip SRAM. It then incrementally performs the softmax reduction block by block, avoiding the need to materialize the entire attention matrix on the HBM.
  • Recomputation: Instead of saving the massive $N \times N$ intermediate attention matrix for the backward pass (gradient calculation), FlashAttention stores only the softmax normalization statistics. During the backward pass, it uses these statistics to quickly recompute the attention matrix on-chip, which is much faster than reading it from HBM.
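The tiling idea above can be sketched in NumPy. This is a minimal, illustrative single-head forward pass, not the paper's fused CUDA kernel: it streams over key/value blocks while maintaining running softmax statistics (row max `m` and normalizer `l`), so the full $N \times N$ score matrix is never materialized. The function name and block size are illustrative.

```python
import numpy as np

def flash_attention_forward(Q, K, V, block=64):
    """Tiled exact attention (illustrative sketch of FlashAttention's idea).

    Processes K/V in blocks, keeping only O(N) running statistics
    instead of the full N x N attention matrix.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))
    m = np.full(N, -np.inf)   # running row-wise max of scores
    l = np.zeros(N)           # running softmax normalizer

    for j in range(0, K.shape[0], block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        S = (Q @ Kj.T) * scale                 # scores for this block only
        m_new = np.maximum(m, S.max(axis=1))   # updated row max
        p = np.exp(S - m_new[:, None])         # block softmax numerator
        corr = np.exp(m - m_new)               # rescale previous accumulators
        l = l * corr + p.sum(axis=1)
        O = O * corr[:, None] + p @ Vj
        m = m_new

    return O / l[:, None]
```

Because the running max and normalizer are updated incrementally, the result is numerically identical to standard softmax attention; in the backward pass, the paper stores only these per-row statistics and recomputes attention blocks on-chip.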

The authors also propose block-sparse FlashAttention, an extension that skips zero blocks in a sparse attention mask. This approximate attention algorithm further improves speed and reduces IO complexity by a factor proportional to the sparsity ratio.
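A hedged sketch of the block-sparse idea, reusing the same running-statistics trick: a hypothetical boolean `block_mask` marks which (query block, key block) pairs are active, and masked-out blocks are skipped entirely, saving compute and memory traffic in proportion to the sparsity.

```python
import numpy as np

def block_sparse_attention(Q, K, V, block_mask, block=64):
    """Illustrative block-sparse attention: skip zero blocks of the mask.

    block_mask[i, j] is True if query block i attends to key block j.
    Assumes every query block has at least one active key block.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))

    for qi, i in enumerate(range(0, N, block)):
        Qi = Q[i:i + block]
        m = np.full(Qi.shape[0], -np.inf)  # running row max
        l = np.zeros(Qi.shape[0])          # running normalizer
        Oi = np.zeros_like(Qi)
        for kj, j in enumerate(range(0, K.shape[0], block)):
            if not block_mask[qi, kj]:
                continue                   # zero block: no work, no IO
            S = (Qi @ K[j:j + block].T) * scale
            m_new = np.maximum(m, S.max(axis=1))
            p = np.exp(S - m_new[:, None])
            corr = np.exp(m - m_new)
            l = l * corr + p.sum(axis=1)
            Oi = Oi * corr[:, None] + p @ V[j:j + block]
            m = m_new
        O[i:i + block] = Oi / l[:, None]

    return O
```

With an all-True mask this reduces to exact attention; as the fraction of active blocks shrinks, the work done shrinks with it.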

Key Results & Impact:

  • Faster Wall-Clock Training: FlashAttention trains models significantly faster than standard implementations. It trained BERT-large 15% faster than the MLPerf 1.1 speed record and trained GPT-2 up to 3x faster than HuggingFace and Megatron-LM baselines.
  • Memory Efficiency & Longer Contexts: Because FlashAttention's memory footprint scales linearly with sequence length rather than quadratically, it uses significantly less memory (up to 20x more efficient than exact attention baselines).
  • Higher Quality Models: The ability to scale to longer contexts allows for improved model capabilities. It yielded better perplexity for GPT-2, a 6.4-point lift on long-document classification tasks, and enabled Transformers to achieve the first better-than-random performance on the extreme long-context Path-X (16K sequence length) and Path-256 (64K sequence length) challenges.

Learning GenAI via SOTA Papers, by Yun Wu