Review date: 2026-05-24
Review author: Zhongzhu Zhou
Paper reviewed: FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
Paper authors: Tri Dao (Princeton / Stanford)
arXiv: 2307.08691
Status/Venue: arXiv preprint (July 2023); ICLR 2024
Short Answer
FlashAttention-2 is not a new algorithm for attention — it is a better implementation of the same exact-attention computation introduced in FlashAttention-1. FA1 eliminated the HBM traffic bottleneck through tiled online softmax, but still achieved only 25-40% of an A100’s theoretical throughput. The culprit: three separate inefficiencies in how GPU time was spent. FA2 targets each one with a surgical fix: (1) reduce non-matmul FLOPs by deferring the running normalization, (2) add sequence-length parallelism to increase SM occupancy for long sequences, and (3) switch warp layout from split-K to split-Q to eliminate inter-warp shared memory communication. Together these yield a 2× speedup over FA1, reaching 50-73% of A100 peak — nearly matching the efficiency of raw GEMM operations. The result enables 2× longer context windows at the same cost.
Prerequisites: What You Need to Know First
Before diving into the algorithmic changes, it is worth understanding the hardware and software context in detail. Many readers see “2× faster” and think it must involve a clever math trick. It does not — it is pure systems engineering on GPU hardware.
1. The GPU Memory Hierarchy
Modern NVIDIA GPUs have a two-level memory hierarchy that is central to understanding FlashAttention:
High Bandwidth Memory (HBM): The main GPU DRAM. An A100-80GB has 80 GB at ~2 TB/s bandwidth. Every PyTorch tensor you allocate lives here by default.
On-chip SRAM (shared memory): Each streaming multiprocessor (SM) has its own private SRAM. On the A100, each SM has 192 KB of SRAM, and there are 108 SMs total — giving roughly 20 MB of total SRAM across the GPU. The bandwidth to SRAM is approximately 19 TB/s, roughly 10× faster than HBM.
The catch: SRAM is tiny (192 KB per SM) and volatile (data is lost when the kernel finishes). You must explicitly load data from HBM to SRAM, compute, and write results back.
The fundamental constraint behind all FlashAttention variants: minimize round-trips to HBM. Every extra read or write to HBM costs 10× compared to what you could do in SRAM.
graph TD
CPU["Host CPU\n(System RAM ~100 GB/s)"]
subgraph A100["A100 GPU"]
HBM["HBM: 80 GB\nBandwidth: ~2 TB/s\n(PyTorch tensors live here)"]
subgraph SM["Streaming Multiprocessor SM (×108)"]
SRAM["On-chip SRAM: 192 KB per SM\nBandwidth: ~19 TB/s\n(10× faster than HBM!)"]
REG["Registers: 256 KB per SM\n(fastest, private per thread)"]
TC["Tensor Cores\n(matmul: 312 TF/s FP16)\n(non-matmul: 19.5 TF/s FP32)"]
end
end
CPU -->|"PCIe ~64 GB/s"| HBM
HBM -->|"BOTTLENECK: each read/write costs!"| SRAM
SRAM --> REG
REG --> TC
style HBM fill:#ff9999
style SRAM fill:#99ff99
style TC fill:#9999ff
Figure 1: A100 GPU memory hierarchy. HBM is the bottleneck — 10× slower than SRAM. FlashAttention’s design principle is to minimize HBM round-trips by fusing attention into a single HBM pass using tiled computation in SRAM.
2. GPU Execution Model: Threads, Warps, Thread Blocks
GPUs execute code through a hierarchy of thread groups:
- Thread: The smallest unit. Owns registers. Executes one instruction stream.
- Warp: A group of exactly 32 threads that run in SIMD lockstep on the same SM. Threads within a warp can communicate via register-level shuffle instructions (very fast, no sync barrier needed) or via shared memory.
- Thread Block (CTA): Multiple warps (typically 4-8) sharing the same SRAM allocation. Warps communicate by reading/writing SRAM with explicit `__syncthreads()` barriers between writes and reads.
- Grid: All thread blocks. The GPU scheduler assigns each block to an available SM. If you launch 500 thread blocks but only have 108 SMs, the blocks queue up.
Occupancy is the fraction of SM resources actively in use: registers, SRAM, and warp slots. Low occupancy means idle SMs, so hardware utilization is left on the table. FA2’s sequence-length parallelism (Improvement 2 below) is primarily a fix for low occupancy when batch size and head count are small.
graph TD
GRID["GPU Kernel Launch\n(Grid of thread blocks)"]
GRID --> B0["Thread Block 0\nassigned to SM 0"]
GRID --> B1["Thread Block 1\nassigned to SM 1"]
GRID --> BN["... Thread Block N\nassigned to SM K (mod 108)"]
B0 --> W0["Warp 0 (32 threads)\ncommunicate via shuffle (free)"]
B0 --> W1["Warp 1 (32 threads)"]
B0 --> W2["Warp 2 (32 threads)"]
B0 --> W3["Warp 3 (32 threads)"]
W0 -->|"read/write\n+ __syncthreads()"| SMEM["Shared Memory (SRAM)\n192 KB on A100\nshared by all warps in block"]
W1 --> SMEM
W2 --> SMEM
W3 --> SMEM
style SMEM fill:#ffff99
Figure 2: GPU thread hierarchy. Warps communicate via shared memory with synchronization barriers. FA2’s split-Q warp layout eliminates the need for inter-warp communication during the attention forward pass.
3. Standard Attention: The O(N²) Problem
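Given inputs Q (queries), K (keys), V (values), each of shape $N \times d$ where $N$ is sequence length and $d$ is head dimension, standard attention computes, for each head:

$$S = QK^\top \in \mathbb{R}^{N \times N}, \qquad P = \mathrm{softmax}(S) \in \mathbb{R}^{N \times N}, \qquad O = PV \in \mathbb{R}^{N \times d}$$

The problem is the intermediate matrices $S$ and $P$. For $N = 8192$ (8k context) with FP16:

$$8192 \times 8192 \times 2 \text{ bytes} = 128 \text{ MiB per matrix, per head}$$

For a 40-head model with batch size 4: $2 \times 128 \text{ MiB} \times 40 \times 4 = 40 \text{ GiB}$ per layer, just for attention intermediates. This grows past a single A100’s 80 GB for larger batches or larger models.
The standard CUDA implementation is:
- GEMM: $S = QK^\top$ → write $S$ to HBM
- Load $S$ from HBM → compute $P = \mathrm{softmax}(S)$ → write $P$ to HBM
- GEMM: $O = PV$, loading $P$ from HBM
That is 3 separate HBM round-trips for $O(N^2)$ data. The operation is memory-bandwidth-bound, not compute-bound.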
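The memory arithmetic for the materialized intermediates can be checked with a few lines. This is a minimal sketch: the function name `attention_intermediate_bytes` is hypothetical, and it assumes both $S$ and $P$ are stored in FP16 (2 bytes per element) for every head and batch element of one layer.

```python
def attention_intermediate_bytes(seq_len, n_heads, batch, bytes_per_elem=2):
    """Bytes needed to materialize S and P (each N x N) for all heads and batch items."""
    per_matrix = seq_len * seq_len * bytes_per_elem
    return 2 * per_matrix * n_heads * batch  # factor 2: one S and one P

MIB = 1024 ** 2
GIB = 1024 ** 3

# One N x N FP16 matrix at 8k context.
per_head_s = 8192 * 8192 * 2
print(f"one matrix: {per_head_s / MIB:.0f} MiB")

# S and P for 40 heads and batch size 4 (assumed: a single attention layer).
total = attention_intermediate_bytes(8192, 40, 4)
print(f"S and P, 40 heads, batch 4: {total / GIB:.0f} GiB")
```

Because the footprint scales with the square of `seq_len`, doubling the context quadruples it, which is why avoiding the materialization matters more than shrinking the batch.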
4. FlashAttention-1 Recap: IO-Awareness + Online Softmax
FlashAttention (Dao et al., NeurIPS 2022) combined two classical ideas to eliminate the HBM traffic.
Tiling: Process attention in tiles. Load a block of Q rows of size $B_r \times d$ to SRAM. For each such Q block, sweep through all blocks of K and V (each of size $B_c \times d$), computing attention in tiles that fit in SRAM. Never materialize the full $N \times N$ matrix.
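Online softmax (Milakov & Gimelshein, 2018): The challenge is that softmax couples the entire row: you need the global max and sum to normalize. Online softmax maintains a running row max $m$ and a running row sum of exponentials $\ell$ as you process K/V blocks $j = 1, \dots, T_c$, with $S^{(j)} = Q_i K_j^\top$:

$$m^{(j)} = \max\left(m^{(j-1)},\ \mathrm{rowmax}(S^{(j)})\right), \qquad \ell^{(j)} = e^{m^{(j-1)} - m^{(j)}}\, \ell^{(j-1)} + \mathrm{rowsum}\left(e^{S^{(j)} - m^{(j)}}\right)$$

After processing the first two blocks $S^{(1)}$ and $S^{(2)}$:

$$m^{(2)} = \max\left(\mathrm{rowmax}(S^{(1)}),\ \mathrm{rowmax}(S^{(2)})\right), \qquad \ell^{(2)} = \mathrm{rowsum}\left(e^{S^{(1)} - m^{(2)}}\right) + \mathrm{rowsum}\left(e^{S^{(2)} - m^{(2)}}\right)$$

The final output after $T_c$ blocks:

$$O = \mathrm{diag}\left(\ell^{(T_c)}\right)^{-1} \sum_{j=1}^{T_c} e^{S^{(j)} - m^{(T_c)}}\, V^{(j)}$$

This is exact softmax, with no approximation. FA1’s HBM complexity drops from $\Theta(Nd + N^2)$ to $\Theta(N^2 d^2 M^{-1})$, where $M$ is SRAM size.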
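The running-statistics recurrence can be sketched for a single score row in plain Python. This is an illustration only, not the kernel: `online_softmax_stats` and the sample `scores` are hypothetical names and values chosen for the example.

```python
import math

def softmax(row):
    """Reference two-pass softmax over a full row."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def online_softmax_stats(row, block_size):
    """Return the running (max, sum) after one pass over the row in blocks."""
    m = float("-inf")
    l = 0.0
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new max, then add this block's terms.
        l = math.exp(m - m_new) * l + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, l

scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25]
m, l = online_softmax_stats(scores, block_size=2)
probs = [math.exp(x - m) / l for x in scores]
print(probs)
```

The blockwise pass never sees the whole row at once, yet `probs` agrees with the two-pass `softmax(scores)` up to floating-point rounding.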
Why FA1 Still Fell Short: Despite being IO-optimal, FA1 reached only 25-40% of A100 peak TFLOPs/s. The forward pass was 30-50% of peak; the backward pass was even worse at 25-35%. In contrast, optimized GEMM (cuBLAS) reaches 80-90% of peak. The gap was not algorithmic but implementation-level — three distinct inefficiencies in how FA1 translated the algorithm to GPU instructions.
The Non-Matmul FLOP Gap: Root Cause Analysis
Modern GPUs have asymmetric compute units. The A100 provides:
- FP16/BF16 Tensor Cores (matrix multiply): 312 TFLOPs/s
- FP32 scalar units (non-matrix-multiply): 19.5 TFLOPs/s
Each non-matmul FLOP takes 16× as long as a matmul FLOP. To maintain >50% of theoretical peak, the kernel must spend close to 100% of its time on matmul. FA1’s inner loop includes significant non-matmul operations:
- Computing `m_new = max(m_old, rowmax(S_ij))`: scalar max
- Computing `exp(S_ij - m_new)`: element-wise exp
- Rescaling accumulated output: `O = diag(ℓ^{(j-1)}/ℓ^{(j)}) exp(m_old - m_new) O + diag(ℓ^{(j)})^{-1} exp(S_ij - m_new) V_j`: diagonal multiply
These run on the slow scalar units while Tensor Cores sit idle. The more K/V blocks there are (i.e., the longer the sequence), the more times these operations execute.
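A back-of-the-envelope model shows how quickly a small share of non-matmul FLOPs drags down throughput. This sketch assumes the two kinds of FLOPs execute back to back with no overlap, which is a simplification; `effective_tflops` is a hypothetical helper, and the two rates are the A100 figures quoted above.

```python
MATMUL_TFLOPS = 312.0      # FP16/BF16 Tensor Core rate
NON_MATMUL_TFLOPS = 19.5   # FP32 scalar rate

def effective_tflops(non_matmul_fraction):
    """Overall throughput when a fraction of all FLOPs runs on the scalar units.

    Assumes no overlap between Tensor Core and scalar work (simplified model).
    """
    f = non_matmul_fraction
    time_per_flop = (1.0 - f) / MATMUL_TFLOPS + f / NON_MATMUL_TFLOPS
    return 1.0 / time_per_flop

for f in (0.0, 0.01, 0.05, 0.10):
    print(f"{f:.0%} non-matmul FLOPs -> {effective_tflops(f):.0f} TFLOPs/s")
```

Under this model, when just 1 in 17 FLOPs is non-matmul, the kernel spends as much time on scalar work as on Tensor Core work, which is why FA2 targets these operations first.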
flowchart TB
subgraph "FA1 Inner Loop (per K/V block j)"
direction TB
L1["Load K_j, V_j from HBM → SRAM"]
L2["MATMUL: S = Q_i × K_j^T\n(Tensor Cores, 312 TF/s)"]
L3["NON-MATMUL: m_new = max(m_old, rowmax(S))\n(FP32 scalar, 19.5 TF/s = 16× slower!)"]
L4["NON-MATMUL: P̃ = exp(S - m_new)\nℓ_new = exp(m - m_new) × ℓ + rowsum(P̃)\n(element-wise exp + reduction)"]
L5["MIXED: O_i = diag(ℓ_new)^{-1} × [diag(ℓ × exp(m - m_new)) O_old + P̃ × V_j]\n(diagonal multiply = NON-MATMUL, then MATMUL P̃V)"]
L6["Update m ← m_new; ℓ ← ℓ_new"]
L1 --> L2 --> L3 --> L4 --> L5 --> L6
end
style L3 fill:#ff9999
style L4 fill:#ff9999
style L5 fill:#ffcccc
Figure 3: FA1 inner loop. Red nodes are non-matmul operations running on slow FP32 scalar units. Each inner loop iteration applies two diagonal rescalings — one for the max shift and one for the sum normalization. FA2 removes the per-step sum normalization.
The Three Core Improvements in FlashAttention-2
Improvement 1: Reduce Non-Matmul FLOPs
1a. Deferred Normalization (Forward Pass)
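FA1 normalizes the accumulated output by $\mathrm{diag}(\ell^{(j)})^{-1}$ at every inner-loop step. FA2 observes that you only need the final normalization once. Maintain an un-normalized accumulator $\tilde{O}$ throughout the inner loop, and divide by $\ell^{(T_c)}$ once at the very end of the outer-loop iteration.
FA1 update (applied at each inner step $j$, with $S^{(j)} = Q_i K_j^\top$):

$$O^{(j)} = \mathrm{diag}\left(\ell^{(j)}\right)^{-1}\left(\mathrm{diag}\left(\ell^{(j-1)}\right) e^{m^{(j-1)} - m^{(j)}}\, O^{(j-1)} + e^{S^{(j)} - m^{(j)}}\, V^{(j)}\right)$$

This applies two rescalings: the max-shift correction $e^{m^{(j-1)} - m^{(j)}}$ and the sum normalization $\mathrm{diag}(\ell^{(j)})^{-1}$ (which first has to undo the previous step’s normalization by multiplying with $\ell^{(j-1)}$).
FA2 update (deferred normalization):

$$\tilde{O}^{(j)} = \mathrm{diag}\left(e^{m^{(j-1)} - m^{(j)}}\right) \tilde{O}^{(j-1)} + e^{S^{(j)} - m^{(j)}}\, V^{(j)}$$

Only the max-shift correction is applied inside the loop. The sum normalization is deferred:

$$O = \mathrm{diag}\left(\ell^{(T_c)}\right)^{-1} \tilde{O}^{(T_c)}$$

Correctness proof for the two-block case: Starting from $m^{(1)} = \mathrm{rowmax}(S^{(1)})$:

$$\ell^{(1)} = \mathrm{rowsum}\left(e^{S^{(1)} - m^{(1)}}\right), \qquad \tilde{O}^{(1)} = e^{S^{(1)} - m^{(1)}}\, V^{(1)}$$

After block 2, with $m^{(2)} = \max\left(m^{(1)}, \mathrm{rowmax}(S^{(2)})\right)$:

$$\ell^{(2)} = e^{m^{(1)} - m^{(2)}}\, \ell^{(1)} + \mathrm{rowsum}\left(e^{S^{(2)} - m^{(2)}}\right) = \mathrm{rowsum}\left(e^{S^{(1)} - m^{(2)}}\right) + \mathrm{rowsum}\left(e^{S^{(2)} - m^{(2)}}\right)$$

$$\tilde{O}^{(2)} = e^{m^{(1)} - m^{(2)}}\, \tilde{O}^{(1)} + e^{S^{(2)} - m^{(2)}}\, V^{(2)} = e^{S^{(1)} - m^{(2)}}\, V^{(1)} + e^{S^{(2)} - m^{(2)}}\, V^{(2)}$$

Final output:

$$O = \mathrm{diag}\left(\ell^{(2)}\right)^{-1} \tilde{O}^{(2)} = \mathrm{softmax}\left(\left[S^{(1)}\ \ S^{(2)}\right]\right) \begin{bmatrix} V^{(1)} \\ V^{(2)} \end{bmatrix}$$

This matches standard softmax attention exactly. The normalization defers correctly because the max-shift correction preserves relative magnitudes of the numerator terms while $\ell$ accumulates the correct denominator.
FLOP savings: FA1 applies $\mathrm{diag}(\ell)^{-1}$ at each of $T_c$ inner steps, that is, $T_c$ diagonal multiplications of non-matmul type. FA2 does this once. For $N = 8192$, $B_c = 64$, we have $T_c = 128$: FA2 eliminates 127 extra per-step normalizations per Q row block. Each elimination saves a $B_r \times d$ element-wise operation on scalar units.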
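The equivalence of the two update rules can be checked numerically for one query row. This is a minimal pure-Python sketch, not the CUDA kernel: `fa1_style`, `fa2_style`, `reference`, and the sample data are hypothetical names and values introduced for illustration.

```python
import math

def block_pv(p, v_blk, dim):
    """Weighted sum of value rows: the P-tilde times V product for one block."""
    return [sum(pi * v[k] for pi, v in zip(p, v_blk)) for k in range(dim)]

def reference(scores, values):
    """Plain softmax(scores) @ values for a single query row."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return [x / sum(w) for x in block_pv(w, values, len(values[0]))]

def fa1_style(scores, values, block):
    """Keep the output normalized after every block (normalize at each step)."""
    dim = len(values[0])
    m, l, out = float("-inf"), 0.0, [0.0] * dim
    for s0 in range(0, len(scores), block):
        s_blk, v_blk = scores[s0:s0 + block], values[s0:s0 + block]
        m_new = max(m, max(s_blk))
        p = [math.exp(s - m_new) for s in s_blk]
        l_new = math.exp(m - m_new) * l + sum(p)
        pv = block_pv(p, v_blk, dim)
        # Undo the old normalization, shift to the new max, add, re-normalize.
        out = [(l * math.exp(m - m_new) * o + x) / l_new for o, x in zip(out, pv)]
        m, l = m_new, l_new
    return out

def fa2_style(scores, values, block):
    """Accumulate un-normalized output; divide by the row sum once at the end."""
    dim = len(values[0])
    m, l, acc = float("-inf"), 0.0, [0.0] * dim
    for s0 in range(0, len(scores), block):
        s_blk, v_blk = scores[s0:s0 + block], values[s0:s0 + block]
        m_new = max(m, max(s_blk))
        p = [math.exp(s - m_new) for s in s_blk]
        scale = math.exp(m - m_new)  # max-shift correction only
        l = scale * l + sum(p)
        acc = [scale * a + x for a, x in zip(acc, block_pv(p, v_blk, dim))]
        m = m_new
    return [a / l for a in acc]

scores = [0.1, 2.3, -0.7, 1.9, 4.2, 0.0]
values = [[1.0, 0.0], [0.5, -1.0], [2.0, 2.0], [-1.0, 0.5], [0.0, 3.0], [1.5, 1.5]]
ref = reference(scores, values)
out1 = fa1_style(scores, values, block=2)
out2 = fa2_style(scores, values, block=2)
print(ref, out1, out2)
```

All three results agree up to rounding; the only difference is that `fa2_style` performs the division by `l` once instead of once per block.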
1b. Logsumexp Compression for the Backward Pass
FA1 stored two per-row statistics for the backward pass: the running max $m$ and the running sum $\ell$. FA2 collapses these to a single scalar per row:

$$L_i = m_i + \log(\ell_i)$$

This is the logsumexp of row $i$. In the backward pass, the attention probability $P_{ij}$ is needed for gradient computation. With $L$ stored:

$$P_{ij} = \exp(S_{ij} - L_i)$$

Storing $L$ instead of $(m, \ell)$ halves the number of backward-pass statistics written to HBM (from $2N$ scalars to $N$ scalars per head). More importantly, it removes the two-step recomputation of softmax from $(m, \ell)$ and replaces it with a single-step recovery from $L$, reducing non-matmul instructions in the backward pass.
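The single-statistic recovery can be illustrated on one row. This is a small sketch with made-up scores; the variable names are chosen for the example and are not taken from the FA2 code.

```python
import math

scores = [1.5, -0.5, 4.0, 2.25]

# Two statistics, as FA1 stored them.
m = max(scores)
l = sum(math.exp(s - m) for s in scores)

# One statistic, as FA2 stores it.
L = m + math.log(l)

p_from_two_stats = [math.exp(s - m) / l for s in scores]   # exp, then divide
p_from_logsumexp = [math.exp(s - L) for s in scores]       # a single exp
print(p_from_two_stats, p_from_logsumexp)
```

Both recoveries produce the same probabilities up to rounding, but the second needs one stored scalar per row and no division.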
Full FA2 Forward Pass Algorithm:
Algorithm 1: FlashAttention-2 Forward Pass
─────────────────────────────────────────────────────────────────────────────
Input: Q, K, V ∈ ℝ^{N×d} (in HBM); block sizes B_r, B_c
Output: O ∈ ℝ^{N×d} (in HBM); L ∈ ℝ^N (logsumexp, in HBM)
─────────────────────────────────────────────────────────────────────────────
1. Divide Q into T_r = ⌈N/B_r⌉ row-blocks Q_1 … Q_{T_r} (each B_r × d)
Divide K, V into T_c = ⌈N/B_c⌉ col-blocks K_1 … K_{T_c}, V_1 … V_{T_c}
2. for i = 1 … T_r do ← OUTER LOOP [embarrassingly parallel in FA2]
3. Load Q_i : HBM → SRAM
4. Initialize Õ_i = 0^{B_r×d}, ℓ_i = 0^{B_r}, m_i = −∞^{B_r}
5. for j = 1 … T_c do ← INNER LOOP
6. Load K_j, V_j : HBM → SRAM
7. S_ij = Q_i K_j^T ← MATMUL (Tensor Core)
8. m_i_new = max(m_i, rowmax(S_ij)) ← non-matmul (max)
9. P̃_ij = exp(S_ij − m_i_new) ← non-matmul (exp)
10. ℓ_i_new = exp(m_i − m_i_new)·ℓ_i + rowsum(P̃_ij) ← non-matmul
11. Õ_i = diag(exp(m_i − m_i_new)) Õ_i + P̃_ij V_j ← scale + MATMUL
─── FA2 KEY: no diag(ℓ)^{-1} here! ───
12. m_i ← m_i_new; ℓ_i ← ℓ_i_new
13. end for
14. O_i = diag(ℓ_i)^{-1} Õ_i ← ONE normalization per outer step
15. L_i = m_i + log(ℓ_i) ← Logsumexp for backward pass
16. Write O_i → HBM; Write L_i → HBM
17. end for
─────────────────────────────────────────────────────────────────────────────
Line 14 is the critical difference from FA1: the normalization by $\ell_i$ happens once per outer iteration instead of once per inner iteration.
FA2 Backward Pass Algorithm
Algorithm 2: FlashAttention-2 Backward Pass
─────────────────────────────────────────────────────────────────────────────
Input: Q, K, V, O, dO ∈ ℝ^{N×d} (HBM); L ∈ ℝ^N (HBM)
Output: dQ, dK, dV ∈ ℝ^{N×d} (HBM)
─────────────────────────────────────────────────────────────────────────────
Precompute: D = rowsum(dO ◦ O) ∈ ℝ^N
(needed for dS; D_i = ∑_k dO_ik · O_ik = the "scale" from softmax Jacobian)
1. for j = 1 … T_c do ← OUTER LOOP (over K/V column blocks)
2. Load K_j, V_j : HBM → SRAM
3. Initialize dK_j = 0, dV_j = 0
4. for i = 1 … T_r do ← INNER LOOP
5. Load Q_i, O_i, dO_i, L_i, D_i : HBM → SRAM
6. S_ij = Q_i K_j^T ← MATMUL
7. P_ij = exp(S_ij − L_i) ← recompute softmax from logsumexp L
(no need to have stored P; recompute from L in SRAM)
8. dV_j += P_ij^T · dO_i ← MATMUL
9. dP_ij = dO_i · V_j^T ← MATMUL
10. dS_ij = P_ij ◦ (dP_ij − D_i) ← non-matmul: softmax Jacobian
11. dQ_i += dS_ij · K_j ← MATMUL (atomic add to HBM)
12. dK_j += dS_ij^T · Q_i ← MATMUL
13. end for
14. Write dK_j, dV_j → HBM
15. end for
─────────────────────────────────────────────────────────────────────────────
Line 7 shows the role of logsumexp: we recompute $P_{ij}$ on-chip from the stored scalar $L_i$, without ever having stored the full $N \times N$ matrix $P$. Line 11 uses atomic adds because $dQ_i$ receives contributions from all column blocks $j$, which is why the backward pass requires synchronization across thread blocks.
The softmax Jacobian at line 10 uses the identity: if $P = \mathrm{softmax}(S)$ row-wise, then $dS = P \circ (dP - D)$ where $D_i = \sum_j P_{ij}\, dP_{ij} = \sum_k dO_{ik}\, O_{ik}$ (pre-computed once in the vector $D$).
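The Jacobian identity can be sanity-checked against finite differences for a single query row. This is a hedged sketch: the helper names, the sample `s`, `V`, `dO`, and the scalar loss `sum(dO * O)` are all assumptions introduced for the check, not part of the paper's algorithm.

```python
import math

def softmax(s):
    m = max(s)
    e = [math.exp(x - m) for x in s]
    t = sum(e)
    return [x / t for x in e]

def output(s, V):
    """O = softmax(s) @ V for one query row."""
    p = softmax(s)
    return [sum(p[j] * V[j][k] for j in range(len(s))) for k in range(len(V[0]))]

def loss(s, V, dO):
    """Scalar whose gradient w.r.t. O is exactly dO."""
    return sum(g * o for g, o in zip(dO, output(s, V)))

s = [0.3, -1.2, 2.0]
V = [[1.0, 0.5], [-0.5, 2.0], [0.25, -1.0]]
dO = [0.7, -0.4]  # upstream gradient for this output row

p = softmax(s)
o = output(s, V)
D = sum(g * x for g, x in zip(dO, o))                             # D = rowsum(dO ∘ O)
dP = [sum(dO[k] * V[j][k] for k in range(2)) for j in range(3)]   # dP = dO V^T
dS = [p[j] * (dP[j] - D) for j in range(3)]                       # dS = P ∘ (dP − D)

# Central finite differences as an independent check.
eps = 1e-6
numeric = []
for j in range(3):
    hi = s[:]
    hi[j] += eps
    lo = s[:]
    lo[j] -= eps
    numeric.append((loss(hi, V, dO) - loss(lo, V, dO)) / (2 * eps))
print(dS, numeric)
```

The analytic `dS` matches the numeric gradient, and its entries sum to zero, as expected for a gradient that passes through a row-wise softmax.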
Improvement 2: Sequence-Length Dimension Parallelism
FA1 launched $B \times H$ thread blocks (batch size × heads). For long-context inference with small batches:

$$B \times H = 1 \times 32 = 32 \ll 108 \text{ SMs}$$

FA2 adds a third parallelism dimension: the $T_r = \lceil N / B_r \rceil$ outer-loop row blocks of Q. Since outer iterations are independent, FA2 assigns each row block to a separate thread block.
For $N = 8192$, $B_r = 64$: $T_r = 128$. This multiplies the number of thread blocks by 128, dramatically increasing SM occupancy.
graph LR
subgraph FA1_par["FA1: 2D Parallelism"]
direction TB
TB1["Thread Block 0\n= Head 0, Batch 0\nProcesses ALL T_r=128 row iterations\n(one block per head)"]
TB2["Thread Block 1\n= Head 1, Batch 0"]
TBN["Thread Block 31\n= Head 31, Batch 0\n→ TOTAL: 32 blocks for 108 SMs"]
end
subgraph FA2_par["FA2: 3D Parallelism"]
direction TB
TB2_0["Thread Block 0\n= Head 0, Batch 0, Row block 0\n(processes 1 row tile vs ALL K/V)"]
TB2_1["Thread Block 1\n= Head 0, Batch 0, Row block 1"]
TB2_N["Thread Block 4095\n= Head 31, Batch 0, Row block 127\n→ TOTAL: 32×128=4096 blocks"]
end
FA1_par -->|"Long-context: B×H small,\nmany SMs idle"| LOWOC["Low occupancy"]
FA2_par -->|"T_r factor fills all SMs"| HIGHOC["High occupancy"]
style LOWOC fill:#ff9999
style HIGHOC fill:#99ff99
Figure 4: FA2 adds sequence-length parallelism. FA1’s 32 thread blocks leave 76 of 108 A100 SMs idle for a batch-size-1, 32-head model. FA2’s additional row-block dimension creates 4096 thread blocks, saturating all SMs.
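The grid sizes in Figure 4 follow from simple counting. This is a minimal sketch with hypothetical helper names; it only counts launched thread blocks and does not model how many blocks fit on one SM at a time.

```python
import math

A100_SMS = 108

def fa1_blocks(batch, heads):
    """FA1 grid: one thread block per (batch element, head)."""
    return batch * heads

def fa2_blocks(batch, heads, seq_len, block_rows):
    """FA2 grid: one thread block per (batch element, head, Q row block)."""
    return batch * heads * math.ceil(seq_len / block_rows)

b1 = fa1_blocks(batch=1, heads=32)
b2 = fa2_blocks(batch=1, heads=32, seq_len=8192, block_rows=64)
idle = max(0, A100_SMS - b1)

print(f"FA1: {b1} blocks, {idle} of {A100_SMS} SMs idle")
print(f"FA2: {b2} blocks, about {b2 / A100_SMS:.0f} blocks queued per SM")
```

With far more blocks than SMs, the scheduler can keep every SM busy and hide the tail effect of the last partially filled wave.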
Forward pass: The outer loop over row blocks is embarrassingly parallel. Each thread block loads its tile and independently processes all K/V column blocks. No communication between row-block thread blocks is needed.
Backward pass: Each outer iteration processes one column block of K and V, computing contributions to $dK_j$ and $dV_j$ from all Q row blocks. This also parallelizes over column blocks.
The complication: $dQ_i$ receives contributions from all column blocks $j$. FA2 handles this with atomic adds:

$$dQ_i \mathrel{+}= dS_{ij}\, K_j \quad \text{(atomic add, once per column block } j\text{)}$$
Atomic adds serialize concurrent writes from different thread blocks to the same location. On modern GPUs, atomics to different cache lines are pipelined and efficient, so the overhead is modest.
This idea of sequence-length parallelism was first implemented by Phil Tillet at OpenAI in the Triton implementation of FlashAttention; FA2 formalizes and extends it to the CUDA kernel.
Improvement 3: Split-Q Warp Partitioning
The most subtle but impactful change in FA2 concerns how work is divided among the 4 warps within a single thread block. This addresses shared-memory traffic, not HBM traffic.
FA1’s Split-K Scheme
FA1 split K and V across warps while keeping Q accessible to all warps, so each warp owned a slice of the key (column) dimension of the tile rather than of the head dimension. With 4 warps and a column block of $B_c = 128$ keys:
- Warp 0 computes $Q_i K_j[0{:}32]^\top$ (first 32 columns of $K_j^\top$), gets a partial $S_{ij}$ (columns 0-32)
- Warp 1 computes $Q_i K_j[32{:}64]^\top$, gets a partial $S_{ij}$ (columns 32-64)
- Warps 2, 3 similarly for columns 64-96 and 96-128
Each warp then multiplies its partial $S_{ij}$ (after exponentiation) with the corresponding rows of $V_j$ to get a partial output. The partial outputs must be summed across warps to get the final $O_i$.
Problem: The sum requires inter-warp communication through shared memory:
- Each warp writes its partial $O_i$ to SRAM.
- All warps synchronize (`__syncthreads()`).
- One designated warp (or all warps) reads all partial outputs and sums.
- Another sync before the result is consumed.
This is the “split-K” pattern, borrowed from standard GEMM. For standard matrix multiplication, split-K is fine because outputs don’t have softmax coupling. For attention with online softmax, it introduces an additional problem: each warp sees only a column slice of $S$, but softmax is computed row-wise over the full row. So the warps cannot independently compute correct softmax. They must first share their partial $S$ slices, or share $m$ and $\ell$ statistics, adding more shared-memory communication.
FA2’s Split-Q Scheme
FA2 instead splits the query row dimension across warps. With 4 warps and row block size $B_r = 64$:
- Warp 0 processes Q rows 0-15 → computes $Q_i[0{:}16]\, K_j^\top$, softmax, and output
- Warp 1 processes Q rows 16-31 → computes $Q_i[16{:}32]\, K_j^\top$, softmax, and output
- Warps 2, 3 process rows 32-47 and 48-63
All warps access the same K and V tiles (loaded once into SRAM by the thread block), but each warp independently computes softmax for its own subset of query rows. Since softmax is row-wise, and each warp owns complete rows, there is no cross-warp dependency. No shared-memory writes, no sync barriers, no inter-warp communication.
graph TD
subgraph fa1warp["FA1: Split-K Warp Layout"]
direction LR
subgraph W0_K["Warp 0"]
qw0["Q_i (full B_r rows)"]
kw0["K_j[0:B_c/4] (key slice)"]
partial0["column slice of QK^T: [0:B_c/4]"]
end
subgraph W1_K["Warp 1"]
qw1["Q_i (same)"]
kw1["K_j[B_c/4:B_c/2]"]
partial1["column slice of QK^T: [B_c/4:B_c/2]"]
end
SMEM_K["SRAM: all 4 warps write partial\noutputs here + syncthreads\n(BOTTLENECK)"]
O_K["Final O_i (after sync + sum)"]
partial0 --> SMEM_K
partial1 --> SMEM_K
SMEM_K --> O_K
end
subgraph fa2warp["FA2: Split-Q Warp Layout"]
direction LR
subgraph W0_Q["Warp 0"]
qw0q["Q_i[0:B_r/4] (row slice)"]
kvall0["K_j, V_j (full, in SRAM)"]
out0["Full O_i[0:B_r/4]"]
end
subgraph W1_Q["Warp 1"]
qw1q["Q_i[B_r/4:B_r/2]"]
kvall1["K_j, V_j (same)"]
out1["Full O_i[B_r/4:B_r/2]"]
end
O_Q["Final O_i\n(no sync needed!)"]
kvall0 -.-> qw0q
kvall1 -.-> qw1q
out0 --> O_Q
out1 --> O_Q
end
style SMEM_K fill:#ff9999
style O_Q fill:#99ff99
Figure 5: FA1 split-K (top) vs FA2 split-Q (bottom). Split-K requires all warps to write partial results to shared memory and synchronize before summing. Split-Q assigns each warp complete rows — zero inter-warp communication. K/V are loaded once into shared memory and read by all warps.
Why split-Q works: Softmax is row-wise. If a warp owns complete rows of $S$, it can compute independent online softmax for those rows by iterating over K/V column blocks, and no information from other warps is needed. The warp’s $m$, $\ell$, and $\tilde{O}$ are private registers, so no shared memory is needed except for loading K and V (which are read, not written, so no sync between warps is needed during the computation).
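The contrast between the two layouts can be sketched in a few lines. This is an illustration under stated assumptions: the helper names are hypothetical, warps are modeled as plain list slices, and `merge_stats` stands in for the cross-warp exchange that a split-K layout would need through shared memory.

```python
import math

def warp_row_ranges(block_rows, n_warps):
    """Split-Q: each warp owns a contiguous range of complete query rows."""
    rows_per_warp = block_rows // n_warps
    return [(w * rows_per_warp, (w + 1) * rows_per_warp) for w in range(n_warps)]

def slice_stats(scores):
    """(max, sum of exponentials) for one slice of a score row."""
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)

def merge_stats(stats):
    """Split-K: per-warp statistics must be combined across warps."""
    m = max(mi for mi, _ in stats)
    return m, sum(math.exp(mi - m) * li for mi, li in stats)

print(warp_row_ranges(block_rows=64, n_warps=4))

# One row of S_ij, cut into 4 column slices as a split-K layout would do.
row = [0.2, 1.7, -0.3, 2.9, 0.0, 1.1, -2.0, 0.6]
n_warps = 4
cols = len(row) // n_warps
per_warp = [slice_stats(row[w * cols:(w + 1) * cols]) for w in range(n_warps)]
merged = merge_stats(per_warp)   # extra communication step
full = slice_stats(row)          # what a split-Q warp computes on its own
print(merged, full)
```

In the split-Q case each warp calls the equivalent of `slice_stats` on whole rows and never needs `merge_stats`; in the split-K case the merge is mandatory before any row can be normalized.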
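Block size tuning: FA2 manually tunes $B_r$ and $B_c$ for each head dimension and each device. The constraint:

$$2\,(B_r d + 2 B_c d + B_r B_c) \text{ bytes} \le M_{\text{SRAM}}$$

(SRAM needed for $Q$, $K$/$V$ tiles, and the $S$ tile.) Typical values: $B_r, B_c \in \{64, 128\}$. Future work: auto-tuning via hardware performance counters.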
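The SRAM budget can be explored with a small calculator. This sketch follows the simplified accounting used in this review (FP16 tiles for Q, K, V, and S all resident in shared memory); the real kernel keeps some of these in registers, so treat `tile_bytes` as a hypothetical upper-level estimate rather than the kernel's actual allocation.

```python
def tile_bytes(block_rows, block_cols, head_dim, bytes_per_elem=2):
    """Approximate SRAM footprint of one Q tile, one K tile, one V tile, one S tile."""
    q = block_rows * head_dim
    kv = 2 * block_cols * head_dim
    s = block_rows * block_cols
    return bytes_per_elem * (q + kv + s)

SRAM_BYTES = 192 * 1024  # per-SM figure used throughout this review

for br, bc in [(64, 64), (128, 64), (128, 128), (256, 256)]:
    need = tile_bytes(br, bc, head_dim=128)
    verdict = "fits" if need <= SRAM_BYTES else "too large"
    print(f"B_r={br:3d} B_c={bc:3d}: {need / 1024:6.0f} KiB ({verdict})")
```

Smaller footprints also leave room for several thread blocks per SM, which is the occupancy side of the tradeoff.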
Causal Masking Optimization
Autoregressive attention requires that token $i$ only attends to positions $j \le i$. In the block-tiled formulation, two cases arise:
Case 1 (fully masked block): If all column indices in block $j$ exceed all row indices in row block $i$ (i.e., the entire block is strictly above the diagonal), then all entries become $-\infty$ after masking. After softmax, the entire block contributes zero to the output. FA2 skips the GEMM and HBM load for this block entirely.
For large $N$, approximately half the blocks are fully masked. Skipping them gives ~1.7-1.8× speedup versus unmasked attention.
Case 2 (boundary block): The diagonal block where some entries are valid and some are masked. Only this one block per row requires explicit mask application. All blocks below the diagonal are fully unmasked.
graph LR
subgraph mask["N×N Causal Attention Matrix (tiled)"]
direction TB
ABOVE["Upper triangle blocks\n(j > i always masked)\n→ SKIP: no GEMM, no HBM load\n≈50% of all blocks"]
DIAG["Diagonal blocks\n(j ≈ i, partial mask)\n→ Apply causal mask explicitly\n1 per row block"]
BELOW["Lower triangle blocks\n(j < i, all unmasked)\n→ Full attention, no mask needed"]
end
ABOVE -->|"~50% skipped"| SAVE["~1.7-1.8× speedup\nvs non-causal attention"]
style ABOVE fill:#cccccc
style SAVE fill:#99ff99
style DIAG fill:#ffff99
style BELOW fill:#ccffcc
Figure 6: Causal masking in FA2’s tiled framework. Upper-triangular blocks are fully masked and skipped. Only diagonal (boundary) blocks require masking logic. Lower-triangular blocks run at full throughput.
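The three block categories can be reproduced with a short classifier. This is a sketch under the assumption that `seq_len` is a multiple of both block sizes; `classify_block` and `count_blocks` are hypothetical names.

```python
def classify_block(i, j, block_rows, block_cols):
    """Classify tile (row block i, column block j) under a causal mask."""
    row_lo, row_hi = i * block_rows, (i + 1) * block_rows - 1
    col_lo, col_hi = j * block_cols, (j + 1) * block_cols - 1
    if col_lo > row_hi:
        return "skip"   # every key lies in the future of every query
    if col_hi <= row_lo:
        return "full"   # every key is visible to every query
    return "mask"       # boundary tile: apply the mask element-wise

def count_blocks(seq_len, block_rows, block_cols):
    t_r = seq_len // block_rows
    t_c = seq_len // block_cols
    counts = {"skip": 0, "mask": 0, "full": 0}
    for i in range(t_r):
        for j in range(t_c):
            counts[classify_block(i, j, block_rows, block_cols)] += 1
    return counts

counts = count_blocks(seq_len=8192, block_rows=64, block_cols=64)
total = sum(counts.values())
print(counts, f"skipped fraction: {counts['skip'] / total:.3f}")
```

Just under half of the tiles are skipped, and only one tile per row block needs masking logic; the measured speedup stays below 2× because the remaining work and fixed overheads do not shrink proportionally.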
Multi-Query and Grouped-Query Attention
In standard multi-head attention (MHA), every head has its own Q, K, V projections. For KV-cache-heavy inference, two variants reduce memory:
Multi-Query Attention (MQA, Shazeer 2019): All query heads share a single K and V head. The K/V cache shrinks by factor $H$ (number of heads).
Grouped-Query Attention (GQA, Ainslie et al. 2023): Query heads are organized into $G$ groups; each group shares one K/V head. $G$ interpolates between MQA ($G = 1$) and MHA ($G = H$).
FA2 supports both by manipulating tensor index strides rather than duplicating K/V data. For the $h$-th query head in group $g$, FA2 indexes into K at position $g = \lfloor h / (H/G) \rfloor$ (not $h$), effectively broadcasting the shared K head to all Q heads in the group, with no memory duplication.
In the backward pass, gradient contributions from all query heads in a group must be reduced into $dK$ and $dV$ for the single shared K/V head. FA2 applies this reduction via an extra summation after the backward kernel.
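The head-index mapping behind this broadcasting is a single integer division. The sketch below uses a hypothetical helper `kv_head_for`; it shows only the index arithmetic, not FA2's actual stride manipulation.

```python
def kv_head_for(query_head, n_query_heads, n_kv_heads):
    """Map a query head to the K/V head it shares (no K/V data is copied)."""
    if n_query_heads % n_kv_heads != 0:
        raise ValueError("n_query_heads must be a multiple of n_kv_heads")
    group_size = n_query_heads // n_kv_heads
    return query_head // group_size

H = 8
mha = [kv_head_for(h, H, 8) for h in range(H)]  # one K/V head per query head
gqa = [kv_head_for(h, H, 2) for h in range(H)]  # two groups of four query heads
mqa = [kv_head_for(h, H, 1) for h in range(H)]  # all query heads share one K/V head
print(mha, gqa, mqa)
```

The same mapping, read in reverse, tells the backward pass which query heads' gradients must be summed into each shared K/V head.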
Empirical Results
4.1 Attention Kernel Throughput
All benchmarks on A100 80GB SXM4 GPU. Sequence length varies 512-16k; batch size fixed so total tokens = 16k; hidden dim = 2048.
Forward + backward speed, no causal mask, head dim 64 (A100):
| Method | 512 | 1k | 2k | 4k | 8k | 16k |
|---|---|---|---|---|---|---|
| PyTorch (standard) | 68 | 91 | 98 | 95 | 83 | OOM |
| FlashAttention-1 | 91 | 90 | 104 | 98 | 92 | 76 |
| FA Triton | 102 | 92 | 110 | 110 | 108 | 100 |
| FlashAttention-2 | 132 | 162 | 171 | 176 | 175 | 173 |
Units: TFLOPs/s. Peak A100 FP16: 312 TFLOPs/s.
Forward + backward speed, causal mask, head dim 128:
| Method | 512 | 1k | 2k | 4k | 8k | 16k |
|---|---|---|---|---|---|---|
| FlashAttention-1 | 35 | 49 | 59 | 65 | 68 | 71 |
| FlashAttention-2 | 127 | 148 | 192 | 193 | 196 | 203 |
FA2 with causal masking reaches 203 TFLOPs/s at 16k sequence length = 65% of A100 peak. The forward-pass-only peak reaches 224-227 TFLOPs/s = 72-73% of peak.
For reference, the best GEMM (cuBLAS) reaches ~80-90% of peak. FA2 is now within 10-20 percentage points of pure matmul efficiency.
H100 results (same code, no H100-specific features):
Without any H100 optimizations, FA2 already reaches 294-338 TFLOPs/s (H100 SXM5 peak = 989 TFLOPs/s FP16). This is “free” speedup from the H100’s faster hardware (more SMs, faster Tensor Cores, and higher memory bandwidth than the A100), not from any code change. With H100-specific features (TMA, FP8), the authors expect 1.5-2× additional gain, bringing FA to ~500-600 TFLOPs/s on H100.
4.2 End-to-End Training Throughput
GPT-style models on 8× A100 80GB SXM4 GPUs:
| Model | Seq | No FlashAttn (TFLOPs/s per GPU) | FA1 (TFLOPs/s per GPU) | FA2 (TFLOPs/s per GPU) | Speedup (FA2 vs baseline) |
|---|---|---|---|---|---|
| GPT3-1.3B | 2k | 142 | 189 | 196 | 1.38× |
| GPT3-1.3B | 8k | 72 | 170 | 220 | 3.06× |
| GPT3-2.7B | 2k | 149 | 189 | 205 | 1.38× |
| GPT3-2.7B | 8k | 80 | 175 | 225 | 2.81× |
The 8k context case shows the largest gain because at long sequences the batch size shrinks (memory constraints), reducing batch × heads parallelism — precisely the scenario where FA2’s sequence-length parallelism and higher occupancy make the biggest difference.
At 225 TFLOPs/s for GPT3-2.7B with 8k context:

$$\frac{225}{312} \approx 72\% \text{ model FLOPs utilization}$$
This is exceptionally high for a real training workload — close to the hardware ceiling when all overheads (data loading, optimizer steps, gradient communication) are factored in.
Design Analysis: Why / Alternative / Boundary Conditions
Why Defer Normalization Rather Than Removing It Entirely?
The question: Could we avoid all per-step normalization by processing all blocks, accumulating raw unnormalized sums, and normalizing once at the very end?
The answer: No, because numerical overflow prevents this. Without the online max subtraction, $e^{S_{ij}}$ overflows FP16/FP32 for large values of $S_{ij}$. The online max trick ($e^{S_{ij} - m} \le 1$ always) is essential for numerical stability. FA2 keeps the per-step max correction ($e^{m^{(j-1)} - m^{(j)}}$, which is always $\le 1$ by construction) but removes the per-step sum normalization ($\mathrm{diag}(\ell^{(j)})^{-1}$). The max correction is cheap (scalar multiply); the sum normalization is what FA2 eliminates.
Boundary: The max correction step still executes at every inner loop iteration. For numerically flat inputs (all attention logits near-zero), the correction factor $e^{m^{(j-1)} - m^{(j)}} \approx 1$ and could be skipped, but this requires a runtime branch that might hurt performance from branch divergence.
Why Split-Q Instead of Splitting Along Another Dimension?
The question: What if you split along the head dimension at the warp level instead of the row dimension?
The answer: Head-dimension split within a thread block doesn’t make sense — the thread block itself handles one attention head for one batch element. You can’t further subdivide the head dimension within that block’s computation without creating partial GEMM tiles that are too small for efficient Tensor Core utilization.
The only meaningful within-block splits are rows (Q dimension) or columns (K/V dimension) of the attention tile. Split-K (column split) is natural for standard GEMM but fails for attention because of row-wise softmax coupling: warps need to communicate to aggregate $m$ and $\ell$ statistics across columns. Split-Q (row split) gives each warp independent rows for which softmax is self-contained.
Alternative considered in the paper: An “inter-block split” for GQA/MQA where one thread block handles multiple Q heads sharing the same K/V. This adds another parallelism axis and could improve occupancy for models where H is very small. FA2 leaves this to future work.
Why Isn’t FA2 Used for All Sequence Lengths Uniformly?
Observation: The throughput numbers show FA2’s advantage is larger at longer sequences. At very short sequences (512 tokens), FA2 is ~1.45× faster than FA1 (132 vs 91 TFLOPs/s), compared with ~2.3× at 16k (173 vs 76 TFLOPs/s). The absolute throughput at 512 tokens still falls well short of peak because at short sequences, the overhead of kernel launch, block scheduling, and SRAM initialization is a larger fraction of total time.
Boundary condition: For sequences shorter than ~256 tokens, the standard attention implementation (PyTorch’s F.scaled_dot_product_attention with cuDNN backend) may match or beat FA2 because the matrices are small enough to fit in L2 cache, eliminating the HBM bandwidth bottleneck that FA2 is designed to address.
Relationship to FA1, FA3, and xformers
FA1 to FA2
Both use identical mathematical algorithms for tiled online softmax attention. FA2 changes only the implementation:
- Same: Tiling structure, online softmax recurrence, HBM IO complexity
- Different: Deferred normalization; sequence-length parallelism; split-Q warp layout
FA1 and FA2 produce identical numerical outputs (up to floating-point rounding).
FA2 to FA3 (Shah et al., NeurIPS 2024)
FA3 targets Hopper (H100/H200) GPUs, adding:
- TMA (Tensor Memory Accelerator): Hardware-assisted async tile loads, overlapping memory and compute
- Warp specialization: Some warps dedicated to loading (producers), others to computing (consumers)
- FP8 support: 2× throughput vs FP16
- WGMMA instructions: 4th-generation Tensor Core instructions for better utilization
FA3 reaches up to ~740 TFLOPs/s in FP16 (~75% of H100 theoretical peak) and close to 1.2 PFLOPs/s in FP8, a 1.5-2.0× improvement over FA2 on H100.
FA2 vs xformers
Facebook’s xformers library (cutlass-based) was competitive with FA1 but consistently lags behind FA2 by ~10-20% in the benchmarks. The gap is due to xformers not adopting the split-Q warp layout.
Arithmetic Intensity and the Roofline Model
A principled way to understand where attention sits on the GPU’s performance landscape is the roofline model. For a GPU with compute throughput $\pi$ (FLOPs/s) and memory bandwidth $\beta$ (bytes/s), the maximum achievable performance for an operation is:

$$\text{Perf} = \min(\pi,\ \beta \cdot I)$$

where $I$ is the arithmetic intensity (FLOPs per byte of memory transferred). The “roofline” has two regimes:
- Memory-bound ($I < \pi/\beta$): performance limited by bandwidth
- Compute-bound ($I > \pi/\beta$): performance limited by peak FLOPs
For A100: $\pi = 312$ TFLOPs/s (FP16 matmul), $\beta = 2$ TB/s, so the ridge point is:

$$\pi / \beta = 312 / 2 = 156 \text{ FLOPs/byte}$$

Standard attention arithmetic intensity:
- FLOPs: $4N^2 d$ (two GEMMs for $QK^\top$ and $PV$)
- Memory: $\approx 8N^2$ bytes (reading/writing $S$ and $P$ matrices to HBM)
- Intensity: $4N^2 d / 8N^2 = d/2$ FLOPs/byte (for FP16, with $d = 128$: 128/2 = 64 FLOPs/byte)
At $d = 128$, standard attention has intensity ~64 FLOPs/byte, below the ridge point of 156. Standard attention is memory-bound.
FlashAttention arithmetic intensity (FA1/FA2):
- FLOPs: same $4N^2 d$
- Memory: $\approx 8Nd$ bytes of HBM accesses (Q, K, V, O tiles in FP16, no intermediates; an idealized count that charges each tensor once and ignores K/V re-reads across row blocks)
- Effective intensity: $4N^2 d / 8Nd = N/2$ FLOPs/byte
For $N = 8192$: intensity = 4096 FLOPs/byte, far above the ridge point. FA1/FA2 are compute-bound, not memory-bound. This is why they can approach peak Tensor Core throughput.
graph LR
subgraph roofline["Roofline Model (A100 FP16)"]
direction TB
MEM["Memory-bound region\n(slope = bandwidth 2 TB/s)\nStandard attention lives here\n(intensity ~64 FLOPs/byte)"]
RIDGE["Ridge point\n(156 FLOPs/byte)\nTransition point"]
COMPUTE["Compute-bound region\n(ceiling = 312 TF/s)\nFA1/FA2 live here\n(intensity ≈ N/2 ≫ 156)"]
end
MEM -->|"FA1/FA2 shift\nattention here!"| COMPUTE
style MEM fill:#ff9999
style COMPUTE fill:#99ff99
style RIDGE fill:#ffff99
Figure 7: Roofline model for A100. Standard attention is memory-bound (intensity < 156 FLOPs/byte). FA1/FA2 eliminate the N×N HBM traffic, making attention compute-bound. The remaining gap from 312 TF/s peak is due to non-matmul FLOPs (FA2’s fix) and warp inefficiency (FA2’s split-Q fix).
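The roofline placement of the two implementations can be computed directly. This is a minimal sketch using the A100 figures quoted in this review; the function names are hypothetical, and the FlashAttention intensity is the idealized $N/2$ estimate that counts each of Q, K, V, O once.

```python
PEAK_TFLOPS = 312.0   # A100 FP16 Tensor Core peak
HBM_TB_PER_S = 2.0    # A100-80GB HBM bandwidth

def attainable_tflops(intensity):
    """Roofline: min(compute ceiling, bandwidth x FLOPs-per-byte)."""
    return min(PEAK_TFLOPS, HBM_TB_PER_S * intensity)

def standard_attention_intensity(head_dim):
    """4*N^2*d FLOPs over about 8*N^2 bytes of S/P traffic (FP16)."""
    return head_dim / 2

def flash_attention_intensity(seq_len):
    """4*N^2*d FLOPs over about 8*N*d bytes (idealized, FP16)."""
    return seq_len / 2

ridge = PEAK_TFLOPS / HBM_TB_PER_S
std = standard_attention_intensity(128)
flash = flash_attention_intensity(8192)

print(f"ridge point: {ridge:.0f} FLOPs/byte")
print(f"standard attention: I={std:.0f}, ceiling {attainable_tflops(std):.0f} TFLOPs/s")
print(f"FlashAttention:     I={flash:.0f}, ceiling {attainable_tflops(flash):.0f} TFLOPs/s")
```

Under these assumptions the bandwidth roof caps standard attention well below the compute ceiling, while the fused kernel is limited only by how efficiently it uses the Tensor Cores.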
The residual gap between FA2 and the 312 TF/s theoretical peak is not memory bandwidth anymore — it’s the ~27-50% overhead from non-matmul ops (exp, max, reduction) and scheduling inefficiency. FA2 reduces this gap by 2× vs FA1 by attacking exactly these non-matmul operations and the warp layout.
Practical Implications for LLM Training and Inference
Context Length Scaling
The most direct practical impact: FlashAttention-2 makes long-context training tractable.
Before FA2, training a GPT-style model at 8k context cost ~3× more compute per token than at 2k context (due to the attention FLOP growth). With FA2’s sequence-length parallelism fully utilizing SM capacity:
| Context Length | Baseline GPU time (FLOPs) | FA2 real throughput | FA2 effective cost |
|---|---|---|---|
| 2k | 1× | ~196 TF/s | 1× |
| 8k | 16× (attention only) | ~220 TF/s (+12%) | ~1.1× per-token cost |
| 16k | 64× (attention only) | ~224 TF/s | similar |
The reason FA2 scales better than naively expected: at longer sequences, the batch size per GPU must decrease (memory pressure), and FA2’s sequence-length parallelism fills the resulting SM vacancy. The attention kernel doesn’t slow down proportional to $N^2$.
Inference and KV Cache Impact
For autoregressive inference, the “decode” step attends one new token’s query against the full KV cache of length $N$:
- Query shape: $1 \times d$ (single new token)
- KV cache: $N \times d$
With a single row in the query dimension, standard metrics differ: this is not matrix-matrix multiplication but matrix-vector multiplication. FA2’s tiling and Tensor Core utilization matter less here. For decode-phase inference:
- FA2 helps in the prefill phase (processing the prompt, $N$ tokens in the query)
- For decode phase, the bottleneck shifts to KV cache bandwidth — KV cache compression techniques (MQA, GQA, quantized KV) are more important
FA2 directly supports MQA and GQA (see Multi-Query and Grouped-Query Attention above), enabling the smaller KV caches these variants provide while maintaining full-speed attention computation.
Integration with PyTorch
torch.nn.functional.scaled_dot_product_attention has shipped a FlashAttention backend since PyTorch 2.0, and PyTorch 2.2 upgraded that backend to FA2. It dispatches to the flash kernel automatically when:
- Running on CUDA (Ampere+)
- Input is FP16 or BF16
- Sequence length is not too short (>128 tokens)
This makes FA2 effectively the default for all modern LLM training in PyTorch without any code changes. Users who need the latest optimizations (H100 FA3, custom MQA layouts, FP8) should use the flash-attn package directly.
Implementation Details and Block Size Tuning
Choosing Block Sizes $B_r$ and $B_c$
The block sizes $B_r$ (Q rows per tile) and $B_c$ (K/V columns per tile) are critical hyperparameters that must be tuned per GPU and head dimension. The constraints are:
- SRAM capacity: All four tiles must fit simultaneously: $2\,(B_r d + 2 B_c d + B_r B_c) \le M_{\text{SRAM}}$ bytes (factor 2 for FP16; Q tile + K tile + V tile + S tile)
- Tensor Core alignment: Both dimensions of each matmul tile should be multiples of 16 (or 8 for certain precisions).
- Register pressure: Larger $B_r$ means more rows of $O$, $m$, $\ell$ to keep in registers. If registers spill to local memory (off-chip, cached in L1/L2), performance degrades severely.
- Occupancy vs tile size tradeoff: Larger blocks → higher arithmetic intensity within the block (fewer HBM accesses per FLOP) but larger SRAM footprint → fewer thread blocks can reside on the SM simultaneously (lower occupancy).
For an A100 with 192 KB of SRAM per SM, $B_r$ and $B_c$ are typically chosen from $\{64, 128\}$, with the best pair depending on the head dimension $d$. FA2 uses manual lookup tables indexed by (head_dim, GPU architecture). FA3 and the Triton implementation add auto-tuning.
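The SRAM and alignment constraints above can be turned into a quick feasibility check. The sketch below is an illustration under stated assumptions: it counts one Q tile (B_r × d), one K tile and one V tile (B_c × d each), and one S tile (B_r × B_c) in FP16, and it treats the full 192 KB as the budget. The helper names are hypothetical and this is not how the flash-attn library selects its block sizes.

```python
from itertools import product

SRAM_BYTES = 192 * 1024   # per-SM on-chip memory quoted above for A100
BYTES_PER_ELEM = 2        # FP16


def tile_bytes(b_r: int, b_c: int, d: int) -> int:
    """Bytes for one Q tile, one K tile, one V tile, and the S tile."""
    elems = b_r * d + 2 * b_c * d + b_r * b_c
    return BYTES_PER_ELEM * elems


def feasible_blocks(d: int, candidates=(64, 128), budget=SRAM_BYTES):
    """List (B_r, B_c, bytes) triples that are 16-aligned and fit the budget."""
    out = []
    for b_r, b_c in product(candidates, repeat=2):
        aligned = b_r % 16 == 0 and b_c % 16 == 0
        size = tile_bytes(b_r, b_c, d)
        if aligned and size <= budget:
            out.append((b_r, b_c, size))
    return out


for b_r, b_c, size in feasible_blocks(d=128):
    print(f"B_r={b_r:3d} B_c={b_c:3d} -> {size // 1024} KiB")
```

Fitting in SRAM is necessary but not sufficient: the usable shared memory per thread block is smaller than the full per-SM figure, and register pressure and occupancy still decide which feasible pair is fastest, which is why the values are tuned empirically.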
Logsumexp and Numerical Stability
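An important subtlety: the logsumexp stored by FA2 is computed as:
$$L_i = m_i + \log \ell_i$$
where $m_i$ is the final running max for row $i$, and $\ell_i = \sum_j e^{S_{ij} - m_i}$ is the final running sum-of-exponentials. This is numerically stable because:
- $\ell_i$ is a sum of terms all in $[0, 1]$, at least one of which equals $1$, so $1 \le \ell_i \le N$
- $0 \le \log \ell_i \le \log N$: bounded, no overflow
- $L_i = \log \sum_j e^{S_{ij}}$ is the standard log-sum-exp, matching what you’d compute in FP64
In the backward pass, recomputing $P_{ij} = e^{S_{ij} - L_i}$ never overflows because $S_{ij} - L_i \le 0$ (since $L_i \ge m_i \ge S_{ij}$).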
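The stability argument can be checked numerically. The sketch below scans one row of scores block by block, keeping a running max m and a running sum l, and stores only L = m + log(l). It is plain Python for illustration. The scores are made-up values chosen so that a naive exp(score) would overflow.

```python
import math


def online_logsumexp(scores, block=4):
    """Scan one row block by block, tracking running max m and running sum l."""
    m, l = -math.inf, 0.0
    for start in range(0, len(scores), block):
        chunk = scores[start:start + block]
        m_new = max(m, max(chunk))
        # Rescale the old sum to the new max, then add this block's terms.
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in chunk)
        m = m_new
    return m + math.log(l), m, l


scores = [1000.0, 999.0, 1001.5, -3.0, 2.0, 1000.5, 7.0]
L, m, l = online_logsumexp(scores)

# Backward-pass style recomputation: every exponent is <= 0, so no overflow.
probs = [math.exp(s - L) for s in scores]
print(f"m={m}, l={l:.4f}, L={L:.4f}, sum(probs)={sum(probs):.12f}")
```

Only the single scalar L per row needs to be kept for the backward pass, since the probabilities are recovered directly from the scores and L.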
FA2 Algorithm vs FA1: Complete Comparison
| Property | FlashAttention-1 | FlashAttention-2 |
|---|---|---|
| HBM complexity | $O(N^2 d^2 / M)$ | Same (unchanged) |
| Per-step normalize | Yes, at every inner step | No — deferred to outer-step end |
| Backward statistics | Store $m_i$ and $\ell_i$: 2 scalars/row | Store $L_i = m_i + \log \ell_i$: 1 scalar/row |
| Sequence-length parallelism | No — outer loop sequential | Yes — each row block is a thread block |
| Warp layout (forward) | Split-K: warps own K/V column slices | Split-Q: warps own Q row slices |
| Inter-warp communication (fwd) | Required — shared memory write/sync | None needed |
| Backward dQ update | Sequential within thread block | Atomic adds across thread blocks |
| MQA/GQA support | No (external handling) | Yes — native index stride manipulation |
| Causal mask optimization | Skip fully-masked blocks | Same — unchanged |
| Throughput on A100 (fwd+bwd) | 25-40% peak | 50-73% peak |
| End-to-end GPT-2.7B 8k | 175 TF/s | 225 TF/s |
The key insight from this comparison: FA2 makes no changes to the outer mathematical structure — same tiling, same online softmax recurrence, same HBM complexity. Every improvement is in the implementation mapping from algorithm to GPU hardware.
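The "per-step normalize" row is the easiest difference to demonstrate in isolation. The sketch below processes one query row against K/V blocks, keeps an unnormalized output accumulator, and divides by the running sum exactly once at the end. It takes precomputed scores and uses plain Python lists, so it illustrates the recurrence only, not the CUDA kernel.

```python
import math


def attention_row_reference(scores, values):
    """softmax(scores) @ values for one query row, computed in one shot."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    d = len(values[0])
    return [sum(w[j] * values[j][k] for j in range(len(w))) / total
            for k in range(d)]


def attention_row_deferred(scores, values, block=2):
    """Online softmax over K/V blocks with one normalization at the end."""
    d = len(values[0])
    m, l = -math.inf, 0.0
    acc = [0.0] * d                      # unnormalized output accumulator
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)      # rescale old state to the new max
        p = [math.exp(s - m_new) for s in s_blk]
        l = l * scale + sum(p)
        acc = [scale * acc[k] + sum(p[j] * v_blk[j][k] for j in range(len(p)))
               for k in range(d)]
        m = m_new
    return [a / l for a in acc]          # the only division, once per row


scores = [0.5, -1.0, 2.0, 0.1, 3.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.3, -0.7]]
print(attention_row_reference(scores, values))
print(attention_row_deferred(scores, values))
```

Both functions return the same row up to floating-point rounding. The deferred version still rescales by exp(m - m_new) at each step, which is required for correctness; what it avoids is the extra division by the running sum inside the loop.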
Summary
FlashAttention-2 is a case study in precise systems engineering. Tri Dao identified three specific bottlenecks in FA1 — each with a measurable cost — and fixed each independently with minimal code complexity:
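- Deferred normalization: removes the per-step $\mathrm{diag}(\ell)^{-1}$ rescaling of the output, leaving a single normalization per row block
- Sequence-length parallelism: fills idle SMs when batch size × number of heads is smaller than the SM count (108 on A100), multiplying the number of thread blocks by up to $T_r = \lceil N / B_r \rceil$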
- Split-Q warp layout: eliminates inter-warp shared-memory writes and sync barriers in the forward pass
The combined effect: 2× speedup over FA1, 50-73% of A100 peak, 225 TFLOPs/s for GPT-2.7B at 8k context. The implication: training 16k-context models costs the same as 8k-context training used to.
For practitioners: FA2 is available as pip install flash-attn and, from PyTorch 2.2 onward, is the flash backend behind PyTorch’s torch.nn.functional.scaled_dot_product_attention on CUDA. For H100 GPUs, FA3 is worth switching to. For CPU or non-CUDA backends, standard attention or xformers is the right choice.
The mental model: GPU attention performance depends on three separate dimensions — IO traffic to HBM (fixed by FA1), compute utilization / occupancy (fixed by FA2’s parallelism), and warp-level efficiency (fixed by FA2’s split-Q). FA1 solved dimension one; FA2 solves dimensions two and three.
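The occupancy dimension of this mental model comes down to counting thread blocks. The sketch below compares the number of independent forward thread blocks with and without parallelism over query row blocks. The row-block size of 128 and the example batch shape are assumptions for illustration, and the "fill" metric ignores that an SM can host more than one block.

```python
import math

NUM_SMS = 108  # A100


def thread_blocks(batch: int, heads: int, seq_len: int,
                  b_r: int = 128, seq_parallel: bool = True) -> int:
    """Number of independent forward thread blocks launched."""
    row_blocks = math.ceil(seq_len / b_r) if seq_parallel else 1
    return batch * heads * row_blocks


def sm_fill(blocks: int, num_sms: int = NUM_SMS) -> float:
    """Fraction of SMs that receive at least one thread block."""
    return min(1.0, blocks / num_sms)


# Long-context setting: batch 1, 16 heads, 8k tokens.
fa1_style = thread_blocks(1, 16, 8192, seq_parallel=False)
fa2_style = thread_blocks(1, 16, 8192, seq_parallel=True)
print(f"batch*heads only: {fa1_style} blocks, fill {sm_fill(fa1_style):.0%}")
print(f"with row blocks:  {fa2_style} blocks, fill {sm_fill(fa2_style):.0%}")
```

With only batch × heads blocks, most of the 108 SMs sit idle in this setting; adding the row-block dimension gives the scheduler far more blocks than SMs.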
Reproducibility Notes
The paper reports benchmark numbers on A100 80GB SXM4, with CUDA 11.8, cuDNN 8.7, PyTorch 2.0. Key implementation variables:
- Block sizes used: head dim 64 and head dim 128 each use a manually tuned pair $(B_r, B_c)$ drawn from $\{64, 128\} \times \{64, 128\}$
- Reference implementation: github.com/Dao-AILab/flash-attention (BSD 3-Clause license)
- PyTorch integration: Available via pip install flash-attn==2.x.x; also merged into PyTorch, where torch.backends.cuda.flash_sdp_enabled() reports whether the flash path is enabled
- FLOP counting convention: The paper counts attention FLOPs as $4 N^2 d h$ (forward) and $10 N^2 d h$ (backward, 2.5× forward), where $h$ is the number of heads. Note: causal attention divides forward FLOPs by 2 in the per-second efficiency calculations but NOT in the end-to-end model FLOPs (follows Megatron-LM convention for consistency)
- MFU calculation: $\text{MFU} = \text{achieved TFLOPs/s} / \text{peak TFLOPs/s}$; for GPT-2.7B at 8k context, FA2 reports 72% MFU on A100 (225/312)
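The FLOP-counting and MFU conventions listed above can be sketched as a small helper for anyone reproducing the throughput numbers. The formula 4 · N² · d · heads for the forward pass and the 2.5× backward factor follow the paper's convention. The function names and the example shape are illustrative.

```python
def attention_flops(seq_len: int, head_dim: int, n_heads: int,
                    causal: bool = False, backward: bool = False) -> int:
    """Attention FLOPs per sequence: 4 * N^2 * d * heads for the forward pass."""
    flops = 4 * seq_len ** 2 * head_dim * n_heads
    if causal:
        flops //= 2              # roughly half the score matrix is masked out
    if backward:
        flops = flops * 5 // 2   # backward counted as 2.5x forward
    return flops


def mfu(achieved_tflops: float, peak_tflops: float = 312.0) -> float:
    """Model FLOPs utilization: achieved throughput over hardware peak."""
    return achieved_tflops / peak_tflops


fwd = attention_flops(2048, 64, 16)
bwd = attention_flops(2048, 64, 16, backward=True)
print(f"forward: {fwd / 1e9:.1f} GFLOPs, backward: {bwd / 1e9:.1f} GFLOPs")
print(f"MFU at 225 TF/s on A100: {mfu(225.0):.0%}")
```

To turn a measured kernel time into TFLOPs/s, divide the appropriate FLOP count (times batch size) by the elapsed seconds, taking care to apply the causal halving consistently on both sides of any comparison.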
The key result — ~2× speedup over FA1 — is robust across head dimensions, sequence lengths, and batch sizes. The absolute throughput numbers (TFLOPs/s) are sensitive to the specific A100 variant (SXM4 vs PCIe), driver version, and whether the GPU is at thermal throttle. Expect ±5-10% variation in practice.