

This is a classic review of a now-old but still important paper, the original FlashAttention paper. We review it in light of advances in compiler technology.
The June 23, 2022 Stanford paper describes the original **FlashAttention**, an innovative, IO-aware algorithm designed to significantly enhance the efficiency of the attention mechanism in Transformer models by optimizing memory usage and access. Standard attention suffers from complexity that scales **quadratically** ($O(N^2)$) with sequence length ($N$) for both its memory footprint and its accesses to slow High Bandwidth Memory (HBM), which creates a performance bottleneck. FlashAttention overcomes this by employing **tiling and recomputation** within a single fused CUDA kernel, dramatically reducing the memory footprint to scale **linearly** ($O(N)$) and cutting HBM accesses from $\Theta(Nd + N^2)$ to $\Theta(N^2 d^2 / M)$, where $M$ is the on-chip SRAM size. While the algorithm does not reduce the total Floating Point Operations (FLOPs), and even slightly increases them due to recomputation, the massive reduction in slow memory transfers results in substantial **wall-clock runtime speedups** during both training and inference.
Source:
https://arxiv.org/pdf/2205.14135
By mcgrof
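
To make the tiling-plus-online-softmax idea concrete, here is a minimal NumPy sketch, not the paper's CUDA kernel: it streams over key/value blocks while keeping only running max and normalizer statistics per query row, so the full $N \times N$ score matrix is never materialized. All names (`block_attention`, `block`, `row_max`, `row_sum`) are illustrative; the real FlashAttention kernel additionally tiles over the query dimension and keeps the working blocks in SRAM.

```python
# Illustrative sketch of FlashAttention-style tiling with an online softmax.
# Only the key/value dimension is tiled here, to keep the running-statistics
# update easy to read; memory for scores is O(N * block) instead of O(N^2).
import numpy as np

def block_attention(Q, K, V, block=128):
    """Tiled softmax(Q K^T / sqrt(d)) V using running max/normalizer stats."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros((N, d))          # unnormalized running output per query row
    row_max = np.full(N, -np.inf)   # running max of scores seen so far
    row_sum = np.zeros(N)           # running softmax normalizer
    for start in range(0, N, block):
        Kb = K[start:start + block]          # key block, loaded once
        Vb = V[start:start + block]          # value block
        S = (Q @ Kb.T) * scale               # scores for this block only
        new_max = np.maximum(row_max, S.max(axis=1))
        # Rescale previously accumulated output/normalizer to the new max.
        correction = np.exp(row_max - new_max)
        P = np.exp(S - new_max[:, None])     # unnormalized probabilities
        out = out * correction[:, None] + P @ Vb
        row_sum = row_sum * correction + P.sum(axis=1)
        row_max = new_max
    return out / row_sum[:, None]

# Quick check against the naive quadratic-memory reference.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((512, 64)) for _ in range(3))
ref = np.exp((Q @ K.T) / np.sqrt(64))
ref = (ref / ref.sum(axis=1, keepdims=True)) @ V
assert np.allclose(block_attention(Q, K, V), ref, atol=1e-6)
```

The rescaling step is the key trick: whenever a new block raises the running maximum, previously accumulated results are multiplied by `exp(old_max - new_max)`, so the final division by `row_sum` yields exactly the same softmax as the one-shot computation.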