FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

Review date: 2026-05-24
Review author: Zhongzhu Zhou
Paper reviewed: FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
Paper authors: Tri Dao (Princeton / Stanford)
arXiv: 2307.08691
Status/Venue: arXiv preprint (July 2023); ICLR 2024

Short Answer

FlashAttention-2 is not a new algorithm for attention; it is a better implementation of the same exact-attention computation introduced in FlashAttention-1 (FA1). FA1 eliminated the $O(N^2)$ HBM traffic bottleneck through tiled online softmax, but still achieved only 25-40% of an A100’s theoretical throughput. The culprit: three separate inefficiencies in how GPU time was spent. FA2 targets each one with a surgical fix: (1) reduce non-matmul FLOPs by deferring the running normalization, (2) add sequence-length parallelism to increase SM occupancy for long sequences, and (3) switch the warp layout from split-K to split-Q to eliminate inter-warp shared-memory communication. Together these yield about a 2× speedup over FA1, reaching 50-73% of A100 peak and approaching the efficiency of raw GEMM operations. The result enables 2× longer context windows at the same cost.

Prerequisites: What You Need to Know First

Before diving into the algorithmic changes, it is worth understanding the hardware and software context in detail. Many readers see “2× faster” and think it must involve a clever math trick. It does not — it is pure systems engineering on GPU hardware.

1. The GPU Memory Hierarchy

Modern NVIDIA GPUs have a two-level memory hierarchy that is central to understanding FlashAttention:

High Bandwidth Memory (HBM): The main GPU DRAM. An A100-80GB has 80 GB at ~2 TB/s bandwidth. Every PyTorch tensor you allocate lives here by default.

On-chip SRAM (shared memory): Each streaming multiprocessor (SM) has its own private SRAM. On the A100, each SM has 192 KB of SRAM, and there are 108 SMs total — giving roughly 20 MB of total SRAM across the GPU. The bandwidth to SRAM is approximately 19 TB/s, roughly 10× faster than HBM.

The catch: SRAM is tiny (192 KB per SM) and volatile (data is lost when the kernel finishes). You must explicitly load data from HBM to SRAM, compute, and write results back.

$$\text{SRAM bandwidth: } 19\ \text{TB/s} \approx 10 \times \text{HBM bandwidth } (2\ \text{TB/s}) \tag{1}$$

The fundamental constraint behind all FlashAttention variants: minimize round-trips to HBM. Every extra read or write to HBM costs 10× compared to what you could do in SRAM.
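The bandwidth and capacity figures above can be checked with a little arithmetic. This is an idealized sketch: it assumes data streams at peak bandwidth and ignores latency, caching, and compute/transfer overlap, so treat the times as order-of-magnitude estimates only.

```python
NUM_SMS = 108                   # A100 streaming multiprocessors
SRAM_PER_SM_BYTES = 192 * 1024  # on-chip SRAM per SM
HBM_BW = 2.0e12                 # bytes/s, approximate A100-80GB HBM bandwidth
SRAM_BW = 19.0e12               # bytes/s, approximate aggregate SRAM bandwidth

total_sram_mib = NUM_SMS * SRAM_PER_SM_BYTES / 2**20
ratio = SRAM_BW / HBM_BW


def stream_time_ms(num_bytes, bandwidth):
    """Idealized time (ms) to stream num_bytes at a fixed bandwidth."""
    return num_bytes / bandwidth * 1e3


one_gb = 1e9
print(f"total SRAM ~ {total_sram_mib:.2f} MiB")              # -> 20.25
print(f"SRAM/HBM bandwidth ratio ~ {ratio:.1f}x")              # -> 9.5
print(f"1 GB via HBM : {stream_time_ms(one_gb, HBM_BW):.3f} ms")   # -> 0.500
print(f"1 GB via SRAM: {stream_time_ms(one_gb, SRAM_BW):.3f} ms")  # -> 0.053
```

The ratio comes out at 9.5, which is where the "roughly 10×" rule of thumb comes from.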

graph TD
    CPU["Host CPU\n(System RAM ~100 GB/s)"]
    subgraph A100["A100 GPU"]
        HBM["HBM: 80 GB\nBandwidth: ~2 TB/s\n(PyTorch tensors live here)"]
        subgraph SM["Streaming Multiprocessor SM (×108)"]
            SRAM["On-chip SRAM: 192 KB per SM\nBandwidth: ~19 TB/s\n(10× faster than HBM!)"]
            REG["Registers: 256 KB per SM\n(fastest, private per thread)"]
            TC["Compute units\n(Tensor Cores, matmul: 312 TF/s FP16)\n(FP32 units, non-matmul: 19.5 TF/s)"]
        end
    end
    CPU -->|"PCIe ~64 GB/s"| HBM
    HBM -->|"BOTTLENECK: each read/write costs!"| SRAM
    SRAM --> REG
    REG --> TC
    style HBM fill:#ff9999
    style SRAM fill:#99ff99
    style TC fill:#9999ff

Figure 1: A100 GPU memory hierarchy. HBM is the bottleneck — 10× slower than SRAM. FlashAttention’s design principle is to minimize HBM round-trips by fusing attention into a single HBM pass using tiled computation in SRAM.

2. GPU Execution Model: Threads, Warps, Thread Blocks

GPUs execute code through a hierarchy of thread groups:

  • Thread: The smallest unit. Owns registers. Executes one instruction stream.
  • Warp: A group of exactly 32 threads that execute in SIMT fashion (one instruction issued for all 32 lanes) on the same SM. Threads within a warp can exchange data via register-level shuffle instructions (very fast, no block-wide barrier needed) or via shared memory.
  • Thread Block (CTA): Multiple warps (typically 4-8) sharing the same SRAM allocation. Warps communicate by reading/writing SRAM with explicit __syncthreads() barriers between writes and reads.
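  • Grid: All thread blocks. The GPU scheduler assigns each block to an SM with free resources; an SM can host several blocks at once if registers and SRAM allow. If you launch 500 thread blocks but the 108 SMs can only hold a fraction of them at a time, the remaining blocks queue up and run in later waves.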

Occupancy is the fraction of SM resources actively in use: registers, SRAM, and warp slots. Low occupancy means idle hardware, and in the extreme, entire SMs with no work at all. FA2’s sequence-length parallelism (Improvement 2 below) is primarily a fix for low occupancy when batch size and head count are small.

graph TD
    GRID["GPU Kernel Launch\n(Grid of thread blocks)"]
    GRID --> B0["Thread Block 0\nassigned to SM 0"]
    GRID --> B1["Thread Block 1\nassigned to SM 1"]
    GRID --> BN["... Thread Block N\nassigned to SM K (mod 108)"]
    B0 --> W0["Warp 0 (32 threads)\ncommunicate via shuffle (free)"]
    B0 --> W1["Warp 1 (32 threads)"]
    B0 --> W2["Warp 2 (32 threads)"]
    B0 --> W3["Warp 3 (32 threads)"]
    W0 -->|"read/write\n+ __syncthreads()"| SMEM["Shared Memory (SRAM)\n192 KB on A100\nshared by all warps in block"]
    W1 --> SMEM
    W2 --> SMEM
    W3 --> SMEM
    style SMEM fill:#ffff99

Figure 2: GPU thread hierarchy. Warps communicate via shared memory with synchronization barriers. FA2’s split-Q warp layout eliminates the need for inter-warp communication during the attention forward pass.

3. Standard Attention: The O(N²) Problem

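Given inputs Q (queries), K (keys), and V (values), each of shape $N \times d$ where $N$ is the sequence length and $d$ is the head dimension, standard attention computes the following for each head. Softmax is applied row-wise, and the usual $1/\sqrt{d}$ scaling of $S$ is omitted for brevity.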

$$S = QK^\top \in \mathbb{R}^{N \times N} \tag{2}$$

$$P = \text{softmax}(S) \in \mathbb{R}^{N \times N} \tag{3}$$

$$O = PV \in \mathbb{R}^{N \times d} \tag{4}$$

The problem is the intermediate $N \times N$ matrices $S$ and $P$. For $N = 8192$ (8k context) with FP16:

$$|S| = |P| = 8192^2 \times 2\ \text{bytes} = 128\ \text{MB per head per sample} \tag{5}$$

For a 40-head model with batch size 4, that is $128\ \text{MB} \times 40 \times 4 = 20\ \text{GB}$ for each $N \times N$ intermediate (40 GB if both $S$ and $P$ are kept), just for attention. Larger batches, more heads, or longer contexts quickly push this past a single A100’s 80 GB.
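Here is the same footprint arithmetic as a small helper, so other shapes can be tried. It assumes FP16 (2 bytes per element) and reports binary units (MiB/GiB), which is what the text’s “MB/GB” figures correspond to. The 32k row is an extra illustration, not a figure from the text.

```python
def nxn_bytes(seq_len, bytes_per_elem=2):
    """Bytes for one N x N attention matrix (S or P), one head, one sample."""
    return seq_len * seq_len * bytes_per_elem


def intermediate_gib(seq_len, num_heads, batch, num_matrices=1, bytes_per_elem=2):
    """GiB for `num_matrices` materialized N x N intermediates across heads and batch."""
    total = nxn_bytes(seq_len, bytes_per_elem) * num_heads * batch * num_matrices
    return total / 2**30


per_head_mib = nxn_bytes(8192) / 2**20
print(f"one N x N matrix, one head: {per_head_mib:.0f} MiB")                    # -> 128
print(f"S only, 40 heads, batch 4: {intermediate_gib(8192, 40, 4):.0f} GiB")    # -> 20
print(f"S and P: {intermediate_gib(8192, 40, 4, num_matrices=2):.0f} GiB")      # -> 40
print(f"S and P at 32k: {intermediate_gib(32768, 40, 4, num_matrices=2):.0f} GiB")  # -> 640
```

Because the footprint is quadratic in $N$, a 4× longer context costs 16× more memory. That is why materializing $S$ and $P$ stops being an option at long context.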

The standard CUDA implementation is:

  1. GEMM: $S = QK^\top$ → write $S$ to HBM
  2. Load $S$ from HBM → compute softmax → write $P$ to HBM
  3. GEMM: $O = PV$, loading $P$ from HBM

That is 3 separate HBM round-trips for $O(N^2)$ data. The operation is memory-bandwidth-bound, not compute-bound.
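A plain-Python reference for Eqs. (2)-(4) makes the three steps explicit. This is a CPU sketch with lists, not a GPU kernel. The comments mark where a GPU implementation would write to or read from HBM. Max subtraction is included in the softmax because that is standard practice.

```python
import math


def matmul(a, b):
    """Plain-Python matrix product; matrices are lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def softmax_rows(s):
    """Row-wise softmax with max subtraction for numerical stability."""
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out


def standard_attention(q, k, v):
    """Eqs. (2)-(4): materializes the full N x N matrices S and P."""
    s = matmul(q, transpose(k))  # step 1: S = Q K^T (a GPU would write S to HBM)
    p = softmax_rows(s)          # step 2: read S, write P (another N x N HBM round-trip)
    return matmul(p, v)          # step 3: read P back, O = P V


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
O = standard_attention(Q, K, V)
print(O)
```

The third query scores all three keys equally (each dot product is 1), so its output is the plain average of V’s rows, approximately [3.0, 4.0].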

4. FlashAttention-1 Recap: IO-Awareness + Online Softmax

FlashAttention (Dao et al., NeurIPS 2022) introduced two classical ideas to eliminate the $O(N^2)$ HBM traffic.

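Tiling: Process attention in tiles. Load a block of $B_r$ rows of Q into SRAM. For each such Q block, sweep through all blocks of K and V (each of $B_c$ rows), computing attention in tiles that fit in SRAM. Never materialize the full $N \times N$ matrix. (FA1’s published algorithm actually nests these loops the other way around, with K/V blocks in the outer loop; FA2 uses the Q-outer order described here.)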

Online softmax (Milakov & Gimelshein, 2018): The challenge: softmax couples the entire row — you need the global max and sum to normalize. Online softmax maintains running statistics as you process K/V blocks:

After processing the first two blocks $S^{(1)}$ and $S^{(2)}$:

$$m^{(1)} = \text{rowmax}\big(S^{(1)}\big) \tag{6}$$

$$\ell^{(1)} = \text{rowsum}\big(e^{S^{(1)} - m^{(1)}}\big) \tag{7}$$

$$m^{(2)} = \max\big(m^{(1)}, \text{rowmax}(S^{(2)})\big) \tag{8}$$

$$\ell^{(2)} = e^{m^{(1)} - m^{(2)}} \ell^{(1)} + \text{rowsum}\big(e^{S^{(2)} - m^{(2)}}\big) \tag{9}$$

The final output after $T_c$ blocks:

$$O = \text{diag}\big(\ell^{(T_c)}\big)^{-1} \sum_{j=1}^{T_c} e^{S^{(j)} - m^{(T_c)}} V^{(j)} \tag{10}$$

This is exact softmax, with no approximation. The number of HBM accesses drops from $\Theta(Nd + N^2)$ for standard attention to $\Theta(N^2 d^2 / M)$ for FA1, where $M$ is the SRAM size.
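The running-statistics recurrence of Eqs. (6)-(9) can be checked on a single row. This sketch processes one row of scores in blocks of a chosen size and compares the result with a two-pass max-then-sum reference. Starting from $m = -\infty$, $\ell = 0$ makes the first block reduce to Eqs. (6)-(7).

```python
import math


def online_softmax_stats(scores, block_size):
    """Running max m and running sum l over one row, processed block by block (Eqs. 6-9)."""
    m, l = float("-inf"), 0.0
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))                          # Eq. (8)
        # Eq. (9): rescale the old sum to the new max, then add this block's terms.
        # On the first block exp(-inf) = 0, so this reduces to Eq. (7).
        l = math.exp(m - m_new) * l + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, l


scores = [0.3, 2.5, -1.0, 4.0, 0.0, 3.9, -2.2]
m_ref = max(scores)                                  # two-pass reference: max first,
l_ref = sum(math.exp(x - m_ref) for x in scores)     # then the shifted sum

for bc in (1, 2, 3, 7):
    m, l = online_softmax_stats(scores, bc)
    print(bc, m, abs(l - l_ref) < 1e-12)
```

Every block size yields the same $(m, \ell)$ up to rounding. This is why the tiles can be streamed without ever seeing the whole row at once.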

Why FA1 Still Fell Short: Despite being IO-optimal, FA1 reached only 25-40% of A100 peak TFLOPs/s. The forward pass was 30-50% of peak; the backward pass was even worse at 25-35%. In contrast, optimized GEMM (cuBLAS) reaches 80-90% of peak. The gap was not algorithmic but implementation-level — three distinct inefficiencies in how FA1 translated the algorithm to GPU instructions.

The Non-Matmul FLOP Gap: Root Cause Analysis

Modern GPUs have asymmetric compute units. The A100 provides:

  • FP16/BF16 Tensor Cores (matrix multiply): 312 TFLOPs/s
  • FP32 scalar units (non-matrix-multiply): 19.5 TFLOPs/s

$$\text{Effective cost ratio: } \frac{312}{19.5} = 16\times \text{ cheaper for matmul than non-matmul} \tag{11}$$

Each non-matmul FLOP therefore takes about 16× as long as a matmul FLOP. To maintain high throughput (say, more than 50% of theoretical peak), the kernel must spend as much of its time as possible on matmul FLOPs. FA1’s inner loop includes significant non-matmul operations:

  • Computing m_new = max(m_old, rowmax(S_ij)) — scalar max
  • Computing exp(S_ij - m_new) — element-wise exp
  • Rescaling the accumulated output: O ← diag(ℓ_new)^{-1} [diag(exp(m_old − m_new) · ℓ_old) O + exp(S_ij − m_new) V_j], two diagonal multiplies per step

These run on the slow scalar units while Tensor Cores sit idle. The more K/V blocks there are (i.e., the longer the sequence), the more times these operations execute.
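A toy time model shows how little non-matmul work it takes to drag throughput down. It assumes each kind of work runs at its own peak rate and that the two do not overlap (a pessimistic simplification). The FLOP shares below are hypothetical, not measurements of FA1.

```python
MATMUL_TFLOPS = 312.0     # A100 FP16/BF16 Tensor Core peak
NON_MATMUL_TFLOPS = 19.5  # A100 FP32 non-matmul peak


def achieved_fraction(matmul_flops, non_matmul_flops):
    """Fraction of Tensor Core peak achieved if matmul and non-matmul work run back to back."""
    busy = matmul_flops / MATMUL_TFLOPS                 # time Tensor Cores do useful work
    total = busy + non_matmul_flops / NON_MATMUL_TFLOPS  # plus time spent on scalar units
    return busy / total


for share in (0.0, 0.01, 0.03, 0.06):
    frac = achieved_fraction(1.0 - share, share)
    print(f"{share:.0%} non-matmul FLOPs -> {frac:.0%} of matmul peak")
```

Under these assumptions, 1%, 3%, and 6% non-matmul FLOPs give about 86%, 67%, and 49% of peak. That is the 16× ratio at work, and it is why FA2 goes after even small per-step scalar operations.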

flowchart TB
    subgraph "FA1 Inner Loop (per K/V block j)"
        direction TB
        L1["Load K_j, V_j from HBM → SRAM"]
        L2["MATMUL: S = Q_i × K_j^T\n(Tensor Cores, 312 TF/s)"]
        L3["NON-MATMUL: m_new = max(m_old, rowmax(S))\n(FP32 scalar, 19.5 TF/s = 16× slower!)"]
        L4["NON-MATMUL: P̃ = exp(S - m_new)\nℓ_new = exp(m - m_new) × ℓ + rowsum(P̃)\n(element-wise exp + reduction)"]
        L5["MIXED: O_i = diag(ℓ_new)^{-1} × [diag(ℓ × exp(m - m_new)) O_old + P̃ × V_j]\n(diagonal multiply = NON-MATMUL, then MATMUL P̃V)"]
        L6["Update m ← m_new; ℓ ← ℓ_new"]
        L1 --> L2 --> L3 --> L4 --> L5 --> L6
    end
    style L3 fill:#ff9999
    style L4 fill:#ff9999
    style L5 fill:#ffcccc

Figure 3: FA1 inner loop. Red nodes are non-matmul operations running on slow FP32 scalar units. Each inner loop iteration applies two diagonal rescalings — one for the max shift and one for the sum normalization. FA2 removes the per-step sum normalization.

The Three Core Improvements in FlashAttention-2

Improvement 1: Reduce Non-Matmul FLOPs

1a. Deferred Normalization (Forward Pass)

FA1 normalizes the accumulated output $O$ by $\text{diag}(\ell^{(j)})^{-1}$ at every inner-loop step. FA2 observes that the final normalization is needed only once: maintain an un-normalized accumulator $\tilde{O}$ throughout the inner loop, and divide by $\ell$ once after the inner loop finishes (i.e., once per outer-loop iteration).

FA1 update (applied at each inner step $j$):

$$O^{(j)} = \text{diag}\big(\ell^{(j)}\big)^{-1} \left[ e^{m^{(j-1)} - m^{(j)}} \ell^{(j-1)} O^{(j-1)} + e^{S^{(j)} - m^{(j)}} V^{(j)} \right] \tag{12}$$

This applies two rescalings: the max-shift correction $e^{m^{(j-1)} - m^{(j)}}$ (together with the old sum $\ell^{(j-1)}$) and the sum normalization $\text{diag}(\ell^{(j)})^{-1}$.

FA2 update (deferred normalization):

$$\tilde{O}^{(j)} = \text{diag}\big(e^{m^{(j-1)} - m^{(j)}}\big) \tilde{O}^{(j-1)} + e^{S^{(j)} - m^{(j)}} V^{(j)} \tag{13}$$

Only the max-shift correction is applied inside the loop. The sum normalization is deferred:

$$O = \text{diag}\big(\ell^{(T_c)}\big)^{-1} \tilde{O}^{(T_c)} \quad \text{(once, after the inner loop)} \tag{14}$$

Correctness proof for the two-block case, starting from $\tilde{O}^{(0)} = 0$:

$$\tilde{O}^{(1)} = e^{S^{(1)} - m^{(1)}} V^{(1)} \tag{15}$$

After block 2:

$$\tilde{O}^{(2)} = \text{diag}\big(e^{m^{(1)} - m^{(2)}}\big) \tilde{O}^{(1)} + e^{S^{(2)} - m^{(2)}} V^{(2)} = e^{S^{(1)} - m^{(2)}} V^{(1)} + e^{S^{(2)} - m^{(2)}} V^{(2)} \tag{16}$$

Final output (the division is row-wise):

$$O = \frac{\tilde{O}^{(2)}}{\ell^{(2)}} = \frac{e^{S^{(1)} - m^{(2)}} V^{(1)} + e^{S^{(2)} - m^{(2)}} V^{(2)}}{e^{S^{(1)} - m^{(2)}} \mathbf{1} + e^{S^{(2)} - m^{(2)}} \mathbf{1}} = \text{softmax}\!\left(\begin{bmatrix} S^{(1)} & S^{(2)} \end{bmatrix}\right) \begin{bmatrix} V^{(1)} \\ V^{(2)} \end{bmatrix} \quad \checkmark \tag{17}$$

This matches standard softmax attention exactly. Deferral is valid because $\text{diag}(\ell)^{-1}$ is a per-row scale factor: it can be pulled out of the running sum and applied once at the end. Meanwhile, the max-shift correction keeps every term of $\tilde{O}$ expressed relative to the same current max $m^{(j)}$, so $\ell$ accumulates exactly the matching denominator.

FLOP savings: FA1 applies $\text{diag}(\ell)^{-1}$ at each of the $T_c$ inner steps, i.e., $T_c$ non-matmul diagonal rescalings of a $B_r \times d$ tile. FA2 does this once. For $N = 8192$ and $B_c = 64$, we have $T_c = 128$, so FA2 eliminates 127 per-step normalizations per Q row block. Each one would have cost $B_r \times d$ scalar multiplies on the slow non-matmul units.
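Both update rules can be run side by side on one Q row block to confirm they agree while FA2 does fewer rescaling multiplies. This is a plain-Python sketch. `fa1_row_block` and `fa2_row_block` are illustrative names, not functions from the FlashAttention code, and `scale_ops` counts only the diagonal-rescaling multiplies/divides on the accumulator.

```python
import math


def row_scores(q_row, k_block):
    """One row of S_ij = Q_i K_j^T."""
    return [sum(a * b for a, b in zip(q_row, k_row)) for k_row in k_block]


def weighted_sum(p, v_block):
    """One row of P~_ij V_j."""
    return [sum(pc * v_row[c] for pc, v_row in zip(p, v_block)) for c in range(len(v_block[0]))]


def fa1_row_block(q, k, v, bc):
    """FA1-style update, Eq. (12): the accumulator is kept normalized after every step."""
    d = len(v[0])
    o, m, l = [[0.0] * d for _ in q], [float("-inf")] * len(q), [0.0] * len(q)
    scale_ops = 0
    for j0 in range(0, len(k), bc):
        kj, vj = k[j0:j0 + bc], v[j0:j0 + bc]
        for r, q_row in enumerate(q):
            s = row_scores(q_row, kj)
            m_new = max(m[r], max(s))
            p = [math.exp(x - m_new) for x in s]
            old = math.exp(m[r] - m_new) * l[r]      # max shift times old sum
            l_new = old + sum(p)
            pv = weighted_sum(p, vj)
            # two diagonal rescalings of a length-d row: by `old`, then by 1 / l_new
            o[r] = [(old * o[r][c] + pv[c]) / l_new for c in range(d)]
            scale_ops += 2 * d
            m[r], l[r] = m_new, l_new
    return o, scale_ops


def fa2_row_block(q, k, v, bc):
    """FA2 update, Eqs. (13)-(14): max-shift rescale per step, one division at the end."""
    d = len(v[0])
    o, m, l = [[0.0] * d for _ in q], [float("-inf")] * len(q), [0.0] * len(q)
    scale_ops = 0
    for j0 in range(0, len(k), bc):
        kj, vj = k[j0:j0 + bc], v[j0:j0 + bc]
        for r, q_row in enumerate(q):
            s = row_scores(q_row, kj)
            m_new = max(m[r], max(s))
            p = [math.exp(x - m_new) for x in s]
            alpha = math.exp(m[r] - m_new)           # max-shift correction, always <= 1
            l[r] = alpha * l[r] + sum(p)
            pv = weighted_sum(p, vj)
            o[r] = [alpha * o[r][c] + pv[c] for c in range(d)]   # one rescaling only
            scale_ops += d
            m[r] = m_new
    o = [[x / l[r] for x in o[r]] for r in range(len(q))]        # deferred normalization
    scale_ops += len(q) * d
    return o, scale_ops


Q = [[0.1 * (i + 1), -0.2 * i] for i in range(3)]
K = [[math.sin(j), math.cos(j)] for j in range(8)]
V = [[float(j), j * j / 10] for j in range(8)]

o1, ops1 = fa1_row_block(Q, K, V, bc=2)
o2, ops2 = fa2_row_block(Q, K, V, bc=2)
print(ops1, ops2)  # -> 48 30
```

With $T_c = 4$ blocks, FA1 spends $2 \cdot d \cdot B_r \cdot T_c = 48$ rescaling operations and FA2 spends $d \cdot B_r \cdot T_c + B_r \cdot d = 30$. The gap widens as $T_c$ grows.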

1b. Logsumexp Compression for the Backward Pass

FA1 stored two per-row statistics for the backward pass: the running max $m$ and the running sum $\ell$. FA2 collapses these into a single scalar per row:

$$L_i = m_i + \log(\ell_i) = \log\Big(\sum_j e^{S_{ij}}\Big) \tag{18}$$

This is the logsumexp of row $i$. In the backward pass, the attention probability $P_{ij}$ is needed for gradient computation. With $L_i$ stored:

$$P_{ij} = e^{S_{ij} - L_i} = \frac{e^{S_{ij}}}{\sum_k e^{S_{ik}}} \tag{19}$$

Storing $L$ instead of $(m, \ell)$ halves the number of backward-pass statistics written to HBM (from $2N$ scalars to $N$ scalars per head). More importantly, it removes the two-step recomputation of softmax from $(m, \ell)$ and replaces it with a single-step recovery from $L$, reducing non-matmul instructions in the backward pass.
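A short check of Eqs. (18)-(19) in plain Python: recovering probabilities from the single scalar $L$ gives the same result as the two-step recovery from $(m, \ell)$. Computing $L$ as $m + \log \ell$ also stays finite where a naive $\log\sum e^{x}$ would overflow.

```python
import math


def logsumexp(row):
    """L = m + log(l), Eq. (18); stable even when exp(x) itself would overflow."""
    m = max(row)
    l = sum(math.exp(x - m) for x in row)
    return m + math.log(l)


scores = [1.5, -0.3, 2.0, 0.7]
m = max(scores)
l = sum(math.exp(x - m) for x in scores)
L = logsumexp(scores)

p_two_step = [math.exp(x - m) / l for x in scores]  # FA1-style: from the pair (m, l)
p_one_step = [math.exp(x - L) for x in scores]      # FA2-style: from L alone, Eq. (19)

print(max(abs(a - b) for a, b in zip(p_two_step, p_one_step)))  # rounding-level difference
print(logsumexp([1000.0, 999.0]))  # about 1000.3133; math.exp(1000.0) alone would overflow
```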

Full FA2 Forward Pass Algorithm:

Algorithm 1: FlashAttention-2 Forward Pass
─────────────────────────────────────────────────────────────────────────────
Input:  Q, K, V ∈ ℝ^{N×d} (in HBM); block sizes B_r, B_c
Output: O ∈ ℝ^{N×d} (in HBM); L ∈ ℝ^N (logsumexp, in HBM)
─────────────────────────────────────────────────────────────────────────────
1.  Divide Q into T_r = ⌈N/B_r⌉ row-blocks Q_1 … Q_{T_r}  (each B_r × d)
    Divide K, V into T_c = ⌈N/B_c⌉ col-blocks K_1 … K_{T_c}, V_1 … V_{T_c}

2.  for i = 1 … T_r do                    ← OUTER LOOP  [embarrassingly parallel in FA2]
3.      Load Q_i : HBM → SRAM
4.      Initialize Õ_i = 0^{B_r×d},  ℓ_i = 0^{B_r},  m_i = −∞^{B_r}

5.      for j = 1 … T_c do                ← INNER LOOP
6.          Load K_j, V_j : HBM → SRAM
7.          S_ij = Q_i K_j^T              ← MATMUL  (Tensor Core)
8.          m_i_new = max(m_i, rowmax(S_ij))    ← non-matmul (max)
9.          P̃_ij  = exp(S_ij − m_i_new)         ← non-matmul (exp)
10.         ℓ_i_new = exp(m_i − m_i_new)·ℓ_i + rowsum(P̃_ij)  ← non-matmul
11.         Õ_i = diag(exp(m_i − m_i_new)) Õ_i + P̃_ij V_j  ← scale + MATMUL
            ─── FA2 KEY: no diag(ℓ)^{-1} here! ───
12.         m_i ← m_i_new;  ℓ_i ← ℓ_i_new
13.     end for

14.     O_i = diag(ℓ_i)^{-1} Õ_i        ← ONE normalization per outer step
15.     L_i = m_i + log(ℓ_i)             ← Logsumexp for backward pass
16.     Write O_i → HBM;  Write L_i → HBM
17. end for
─────────────────────────────────────────────────────────────────────────────

Line 14 is the critical difference from FA1: the normalization by $\text{diag}(\ell)^{-1}$ happens once per outer iteration instead of once per inner iteration.
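Algorithm 1 can be transcribed into plain Python and checked against a direct computation. This is a CPU sketch under obvious assumptions: lists stand in for SRAM tiles, the outer loop runs sequentially rather than one thread block per row block, and `flash2_forward` is an illustrative name. Deliberately ragged block sizes exercise the edge tiles.

```python
import math


def matmul(a, b):
    return [[sum(x * y for x, y in zip(r, c)) for c in zip(*b)] for r in a]


def transpose(a):
    return [list(c) for c in zip(*a)]


def flash2_forward(q, k, v, br, bc):
    """Sketch of Algorithm 1: outer loop over Q row blocks, inner loop over K/V blocks."""
    d = len(v[0])
    out, lse = [], []
    for i0 in range(0, len(q), br):                  # line 2 (one thread block each on a GPU)
        qi = q[i0:i0 + br]                           # line 3
        o_acc = [[0.0] * d for _ in qi]              # line 4: un-normalized accumulator
        m = [float("-inf")] * len(qi)
        l = [0.0] * len(qi)
        for j0 in range(0, len(k), bc):              # line 5
            kj, vj = k[j0:j0 + bc], v[j0:j0 + bc]    # line 6
            s = matmul(qi, transpose(kj))            # line 7
            for r in range(len(qi)):
                m_new = max(m[r], max(s[r]))         # line 8
                p = [math.exp(x - m_new) for x in s[r]]  # line 9
                alpha = math.exp(m[r] - m_new)
                l[r] = alpha * l[r] + sum(p)         # line 10
                pv = [sum(pc * vrow[c] for pc, vrow in zip(p, vj)) for c in range(d)]
                o_acc[r] = [alpha * o_acc[r][c] + pv[c] for c in range(d)]  # line 11
                m[r] = m_new                         # line 12
        for r in range(len(qi)):
            out.append([x / l[r] for x in o_acc[r]])   # line 14: one normalization
            lse.append(m[r] + math.log(l[r]))          # line 15: logsumexp
    return out, lse


def reference_attention(q, k, v):
    """Direct computation: full S, row softmax, P V, plus per-row logsumexp."""
    out, lse = [], []
    for row in matmul(q, transpose(k)):
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([sum(ei * vrow[c] for ei, vrow in zip(e, v)) / z for c in range(len(v[0]))])
        lse.append(m + math.log(z))
    return out, lse


N, d = 10, 3
Q = [[math.sin(0.5 * i + c) for c in range(d)] for i in range(N)]
K = [[math.cos(0.3 * i - c) for c in range(d)] for i in range(N)]
V = [[0.1 * i + c for c in range(d)] for i in range(N)]

O, L = flash2_forward(Q, K, V, br=4, bc=3)
O_ref, L_ref = reference_attention(Q, K, V)
print(max(abs(a - b) for ro, rr in zip(O, O_ref) for a, b in zip(ro, rr)))  # rounding-level
```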

FA2 Backward Pass Algorithm

Algorithm 2: FlashAttention-2 Backward Pass
─────────────────────────────────────────────────────────────────────────────
Input:  Q, K, V, O, dO ∈ ℝ^{N×d} (HBM); L ∈ ℝ^N (HBM)
Output: dQ, dK, dV ∈ ℝ^{N×d} (HBM)
─────────────────────────────────────────────────────────────────────────────
Precompute:  D = rowsum(dO ◦ O) ∈ ℝ^N
  (needed for dS; D_i = ∑_k dO_ik · O_ik = the "scale" from softmax Jacobian)

1.  for j = 1 … T_c do                    ← OUTER LOOP (over K/V column blocks)
2.      Load K_j, V_j : HBM → SRAM
3.      Initialize dK_j = 0, dV_j = 0

4.      for i = 1 … T_r do                ← INNER LOOP
5.          Load Q_i, O_i, dO_i, L_i, D_i : HBM → SRAM
6.          S_ij = Q_i K_j^T                ← MATMUL
7.          P_ij = exp(S_ij − L_i)          ← recompute softmax from logsumexp L
            (no need to have stored P; recompute from L in SRAM)
8.          dV_j += P_ij^T · dO_i           ← MATMUL
9.          dP_ij = dO_i · V_j^T            ← MATMUL
10.         dS_ij = P_ij ◦ (dP_ij − D_i)   ← non-matmul: softmax Jacobian
11.         dQ_i += dS_ij · K_j             ← MATMUL  (atomic add to HBM)
12.         dK_j += dS_ij^T · Q_i           ← MATMUL
13.     end for

14.     Write dK_j, dV_j → HBM
15. end for
─────────────────────────────────────────────────────────────────────────────

Line 7 shows the role of logsumexp: we recompute $P_{ij} = \exp(S_{ij} - L_i)$ on-chip from the stored scalar $L_i$, without ever having stored the full $N \times N$ matrix $P$. Line 11 uses atomic adds because $dQ_i$ receives contributions from all column blocks; this is why the backward pass requires synchronization across thread blocks.
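Below, Algorithm 2 is transcribed into plain Python and checked against central finite differences of the scalar $\sum (dO \circ O)$, whose gradient with respect to $O$ is exactly $dO$. This is a CPU sketch. `add_rows` is a hypothetical stand-in for the HBM accumulate (an atomic add for $dQ$ on a GPU). Here $dK_j$ and $dV_j$ are accumulated directly into the global lists instead of in SRAM, which gives the same sums.

```python
import math


def matmul(a, b):
    return [[sum(x * y for x, y in zip(r, c)) for c in zip(*b)] for r in a]


def transpose(a):
    return [list(c) for c in zip(*a)]


def add_rows(acc, offset, rows):
    """acc[offset + t] += rows[t] (stands in for an HBM accumulate / atomic add)."""
    for t, row in enumerate(rows):
        acc[offset + t] = [a + b for a, b in zip(acc[offset + t], row)]


def forward_with_lse(q, k, v):
    """Reference forward pass: O = softmax(Q K^T) V plus per-row logsumexp L."""
    s = matmul(q, transpose(k))
    lse = []
    for row in s:
        m = max(row)
        lse.append(m + math.log(sum(math.exp(x - m) for x in row)))
    p = [[math.exp(x - li) for x in row] for row, li in zip(s, lse)]
    return matmul(p, v), lse


def flash2_backward(q, k, v, o, do, lse, br, bc):
    """Sketch of Algorithm 2: outer loop over K/V blocks j, inner loop over Q blocks i."""
    dq = [[0.0] * len(q[0]) for _ in q]
    dk = [[0.0] * len(k[0]) for _ in k]
    dv = [[0.0] * len(v[0]) for _ in v]
    D = [sum(a * b for a, b in zip(dr, orow)) for dr, orow in zip(do, o)]  # rowsum(dO ∘ O)
    for j0 in range(0, len(k), bc):
        kj, vj = k[j0:j0 + bc], v[j0:j0 + bc]
        for i0 in range(0, len(q), br):
            qi, doi = q[i0:i0 + br], do[i0:i0 + br]
            s = matmul(qi, transpose(kj))                                              # line 6
            p = [[math.exp(x - lse[i0 + r]) for x in row] for r, row in enumerate(s)]  # line 7
            add_rows(dv, j0, matmul(transpose(p), doi))                                # line 8
            dp = matmul(doi, transpose(vj))                                            # line 9
            ds = [[p[r][c] * (dp[r][c] - D[i0 + r]) for c in range(len(kj))]
                  for r in range(len(qi))]                                             # line 10
            add_rows(dq, i0, matmul(ds, kj))               # line 11 (atomic add on a GPU)
            add_rows(dk, j0, matmul(transpose(ds), qi))                                # line 12
    return dq, dk, dv


def loss(q, k, v, do):
    """Scalar sum(dO ∘ O); its gradient w.r.t. O is dO."""
    o, _ = forward_with_lse(q, k, v)
    return sum(a * b for orow, drow in zip(o, do) for a, b in zip(orow, drow))


def numerical_grad(mats, name, do, eps=1e-6):
    """Central differences of `loss` w.r.t. mats[name] (entries are restored after use)."""
    target = mats[name]
    grad = [[0.0] * len(target[0]) for _ in target]
    for r in range(len(target)):
        for c in range(len(target[0])):
            orig = target[r][c]
            target[r][c] = orig + eps
            plus = loss(mats["q"], mats["k"], mats["v"], do)
            target[r][c] = orig - eps
            minus = loss(mats["q"], mats["k"], mats["v"], do)
            target[r][c] = orig
            grad[r][c] = (plus - minus) / (2 * eps)
    return grad


Q = [[math.sin(i + 0.1 * c) for c in range(2)] for i in range(5)]
K = [[math.cos(0.7 * i + c) for c in range(2)] for i in range(5)]
V = [[0.3 * i - c for c in range(2)] for i in range(5)]
dO = [[math.sin(1.3 * (i + 1) * (c + 1)) for c in range(2)] for i in range(5)]

O, L = forward_with_lse(Q, K, V)
dQ, dK, dV = flash2_backward(Q, K, V, O, dO, L, br=2, bc=3)
num = {name: numerical_grad({"q": Q, "k": K, "v": V}, name, dO) for name in ("q", "k", "v")}
```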

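The softmax Jacobian at line 10 uses the identity: if $p = \text{softmax}(s)$, then $ds = p \circ \big(dp - \sum_k p_k\, dp_k\big) = p \circ (dp - D)$, where $D = \langle p, dp \rangle$. This inner product can be computed without $P$. Since $dP = dO\, V^\top$ and $O = PV$, we have $\sum_j P_{ij}\, dP_{ij} = \sum_k dO_{ik}\, O_{ik}$, which is exactly the precomputed $D_i = \text{rowsum}(dO \circ O)_i$.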
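Both identities can be confirmed numerically: (a) $\langle p, dp \rangle$ equals $\text{rowsum}(dO \circ O)$, and (b) $p \circ (dp - D)$ equals the full softmax Jacobian applied to $dp$. The sketch uses random data with a fixed seed. Each row of `P` is built as a softmax so that the rows are valid probability vectors.

```python
import math
import random

random.seed(0)
n, d = 4, 3
P = []
for _ in range(n):  # each row is the softmax of random logits
    e = [math.exp(random.gauss(0, 1)) for _ in range(n)]
    z = sum(e)
    P.append([x / z for x in e])
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
dO = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]

O = [[sum(P[i][j] * V[j][c] for j in range(n)) for c in range(d)] for i in range(n)]    # O = P V
dP = [[sum(dO[i][c] * V[j][c] for c in range(d)) for j in range(n)] for i in range(n)]  # dP = dO V^T

D_from_P = [sum(P[i][j] * dP[i][j] for j in range(n)) for i in range(n)]  # <p, dp> per row
D_from_O = [sum(dO[i][c] * O[i][c] for c in range(d)) for i in range(n)]  # rowsum(dO ∘ O)

# (b) for row 0: shortcut vs explicit Jacobian J[j][k] = p_j * (delta_jk - p_k)
p, dp = P[0], dP[0]
ds_shortcut = [p[j] * (dp[j] - D_from_O[0]) for j in range(n)]
ds_jacobian = [sum(((p[j] if j == k else 0.0) - p[j] * p[k]) * dp[k] for k in range(n))
               for j in range(n)]
```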

Improvement 2: Sequence-Length Dimension Parallelism

FA1 launched $B \times H$ thread blocks (batch size × heads). For long-context inference with small batches:

$$B = 1,\; H = 32 \implies 32 \text{ thread blocks} \ll 108 \text{ SMs} \implies 76 \text{ SMs idle} \tag{20}$$

FA2 adds a third parallelism dimension: the outer-loop row blocks of Q. Since outer iterations are independent, FA2 assigns each row block $Q_i$ to a separate thread block.

$$\text{FA2 thread blocks: } B \times H \times T_r \quad \text{where} \quad T_r = \lceil N / B_r \rceil \tag{21}$$

For $N = 8192$ and $B_r = 64$, $T_r = 128$. This multiplies the number of thread blocks by 128, dramatically increasing SM occupancy.
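The grid-size arithmetic of Eqs. (20)-(21) is easy to reproduce. The "busy SMs" and "waves" helpers assume one resident thread block per SM, which is a simplification: real residency depends on per-block register and SRAM usage.

```python
import math

NUM_SMS = 108  # A100


def fa1_grid(batch, heads):
    """FA1: one thread block per (batch, head)."""
    return batch * heads


def fa2_grid(batch, heads, seq_len, block_rows):
    """FA2: one thread block per (batch, head, Q row block), Eq. (21)."""
    return batch * heads * math.ceil(seq_len / block_rows)


def busy_sms(num_blocks, num_sms=NUM_SMS):
    """SMs with work in the first wave, assuming one resident block per SM."""
    return min(num_blocks, num_sms)


b1 = fa1_grid(batch=1, heads=32)
b2 = fa2_grid(batch=1, heads=32, seq_len=8192, block_rows=64)
print(b1, NUM_SMS - busy_sms(b1))       # -> 32 76   (blocks, idle SMs)
print(b2, math.ceil(b2 / NUM_SMS))      # -> 4096 38 (blocks, waves)
```

Having many more blocks than SMs also helps with load balancing: blocks that finish early free their SM for the next block in the queue.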

graph LR
    subgraph FA1_par["FA1: 2D Parallelism"]
        direction TB
        TB1["Thread Block 0\n= Head 0, Batch 0\nProcesses ALL T_r=128 row iterations\n(one block per head)"]
        TB2["Thread Block 1\n= Head 1, Batch 0"]
        TBN["Thread Block 31\n= Head 31, Batch 0\n→ TOTAL: 32 blocks for 108 SMs"]
    end
    subgraph FA2_par["FA2: 3D Parallelism"]
        direction TB
        TB2_0["Thread Block 0\n= Head 0, Batch 0, Row block 0\n(processes 1 row tile vs ALL K/V)"]
        TB2_1["Thread Block 1\n= Head 0, Batch 0, Row block 1"]
        TB2_N["Thread Block 4095\n= Head 31, Batch 0, Row block 127\n→ TOTAL: 32×128=4096 blocks"]
    end
    FA1_par -->|"Long-context: B×H small,\nmany SMs idle"| LOWOC["Low occupancy"]
    FA2_par -->|"T_r factor fills all SMs"| HIGHOC["High occupancy"]
    style LOWOC fill:#ff9999
    style HIGHOC fill:#99ff99

Figure 4: FA2 adds sequence-length parallelism. FA1’s 32 thread blocks leave 76 of 108 A100 SMs idle for a batch-size-1, 32-head model. FA2’s additional row-block dimension creates 4096 thread blocks, saturating all SMs.

Forward pass: The outer loop over row blocks is embarrassingly parallel. Each thread block loads its $Q_i$ tile and independently processes all $T_c$ K/V column blocks. No communication between row-block thread blocks is needed.

Backward pass: Each outer iteration processes one column block of K and V, computing contributions to $dK_j$ and $dV_j$ from all Q row blocks. This also parallelizes over the $T_c$ column blocks.

The complication: $dQ_i$ receives contributions from all column blocks $j = 1 \ldots T_c$. FA2 handles this with atomic adds:

$$dQ_i \mathrel{+}= dS_{ij} K_j \quad \forall j \tag{22}$$

Atomic adds serialize concurrent writes from different thread blocks to the same $dQ$ location. On modern GPUs, atomics to different cache lines are pipelined and efficient, so the overhead is modest.

This idea of sequence-length parallelism was first implemented by Phil Tillet at OpenAI in the Triton implementation of FlashAttention; FA2 formalizes and extends it to the CUDA kernel.

Improvement 3: Split-Q Warp Partitioning

The most subtle but impactful change in FA2 concerns how work is divided among the warps (typically 4 or 8) within a single thread block. This addresses shared-memory traffic, not HBM traffic.

FA1’s Split-K Scheme

FA1 split the K and V tiles across warps along the key (sequence) dimension, while keeping $Q_i$ accessible to all warps. With 4 warps and a K/V block of $B_c = 128$ keys:

  • Warp 0 computes $Q_i K_{j,\,0:32}^\top$ (the first 32 keys of $K_j$) and gets the column slice $S_{ij,\,\text{cols } 0:32}$
  • Warp 1 computes $Q_i K_{j,\,32:64}^\top$ and gets $S_{ij,\,\text{cols } 32:64}$
  • Warps 2 and 3 do the same for keys 64-96 and 96-128

Each warp then multiplies its slice $P_{\text{slice}}$ with the matching rows $V_{\text{slice}}$ to get a partial output. The partial outputs must be summed across warps to get the final $O$.

Problem: The sum requires inter-warp communication through shared memory:

  1. Each warp writes its partial $O_\text{partial}$ to SRAM.
  2. All warps synchronize (__syncthreads()).
  3. One designated warp (or all warps) reads all partial outputs and sums.
  4. Another sync before the result is consumed.

This is the “split-K” pattern, borrowed from standard GEMM. For standard matrix multiplication, split-K is fine because the outputs have no softmax coupling. For attention with online softmax, it introduces an additional problem: each warp sees only a column slice of $S_{ij}$, but softmax is computed row-wise over the full row. So the warps cannot independently compute a correct softmax; they must first share their partial $S$ slices, or share $m$ and $\ell$ statistics, which adds more shared-memory communication.
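The sketch below shows what split-K warps would have to exchange for one query row. Each "warp" can only produce a local $(m, \ell, \tilde{o})$ for its slice of keys, and a merge step must rescale every partial to the global row max before summing. The function names are hypothetical. This illustrates the reduction's math, not FA1's kernel code; on a GPU the merge is the step that goes through shared memory plus a barrier.

```python
import math


def partial_state(scores, values):
    """One split-K warp's view: local max, local sum, and un-normalized output for its keys."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    o = [sum(wi * v[c] for wi, v in zip(w, values)) for c in range(len(values[0]))]
    return m, sum(w), o


def merge_partials(states):
    """Cross-warp reduction: rescale each partial to the global row max, then sum and normalize."""
    m = max(mi for mi, _, _ in states)
    l = sum(math.exp(mi - m) * li for mi, li, _ in states)
    d = len(states[0][2])
    o = [sum(math.exp(mi - m) * oi[c] for mi, _, oi in states) for c in range(d)]
    return [x / l for x in o]


def full_row(scores, values):
    """Reference: softmax over the whole row, then weighted sum of values."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return [sum(wi * v[c] for wi, v in zip(w, values)) / sum(w) for c in range(len(values[0]))]


scores = [0.2, 1.7, -0.4, 3.1, 0.9, 2.2, -1.5, 0.0]  # one query row against B_c = 8 keys
values = [[float(j), 1.0 - j] for j in range(8)]
num_warps = 4
width = len(scores) // num_warps
states = [partial_state(scores[w * width:(w + 1) * width], values[w * width:(w + 1) * width])
          for w in range(num_warps)]
merged = merge_partials(states)
reference = full_row(scores, values)
```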

FA2’s Split-Q Scheme

FA2 instead splits the query row dimension across warps. With 4 warps and row block size $B_r = 64$:

  • Warp 0 processes Q rows 0-15 → computes $S^{(0)}_{ij}$, softmax, and output $O^{(0)}_i$
  • Warp 1 processes Q rows 16-31 → computes $S^{(1)}_{ij}$, softmax, and output $O^{(1)}_i$
  • Warps 2, 3 process rows 32-47 and 48-63

All warps read the same K and V tiles (loaded once into SRAM by the thread block), but each warp independently computes softmax for its own subset of query rows. Since softmax is row-wise and each warp owns complete rows, there is no cross-warp dependency. No partial results are written to shared memory, there is no cross-warp reduction, and the only barrier is the one that makes each freshly loaded K/V tile visible to all warps.
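By contrast, under split-Q each "warp" needs only its own rows of Q plus read-only K and V. Concatenating the warps' outputs reproduces full attention exactly, with no merge step. This is a plain-Python sketch with hypothetical names. A loop iteration stands in for a warp, and each iteration shares no mutable state with the others.

```python
import math


def attend_rows(q_rows, k, v):
    """Everything one split-Q warp needs: its own Q rows plus read-only K and V."""
    out = []
    for q in q_rows:
        s = [sum(a * b for a, b in zip(q, k_row)) for k_row in k]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        l = sum(w)
        out.append([sum(wi * v_row[c] for wi, v_row in zip(w, v)) / l for c in range(len(v[0]))])
    return out


def split_q(q, k, v, num_warps):
    """Partition Q rows across warps; outputs are simply concatenated (no reduction)."""
    rows_per_warp = math.ceil(len(q) / num_warps)
    result = []
    for w in range(num_warps):
        result.extend(attend_rows(q[w * rows_per_warp:(w + 1) * rows_per_warp], k, v))
    return result


Q = [[math.sin(i + c) for c in range(2)] for i in range(6)]
K = [[math.cos(0.4 * j - c) for c in range(2)] for j in range(5)]
V = [[float(j), -0.5 * j] for j in range(5)]
```

Each row goes through identical arithmetic in both paths, so `split_q(Q, K, V, 4) == attend_rows(Q, K, V)` holds exactly, not just up to rounding.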

graph TD
    subgraph fa1warp["FA1: Split-K Warp Layout"]
        direction LR
        subgraph W0_K["Warp 0"]
            qw0["Q_i (full B_r rows)"]
            kw0["K_j[0:B_c/4] (key slice)"]
            partial0["S[:, 0:B_c/4] → partial O"]
        end
        subgraph W1_K["Warp 1"]
            qw1["Q_i (same)"]
            kw1["K_j[B_c/4:B_c/2]"]
            partial1["S[:, B_c/4:B_c/2] → partial O"]
        end
        SMEM_K["SRAM: all 4 warps write partial\noutputs here + syncthreads\n(BOTTLENECK)"]
        O_K["Final O_i (after sync + sum)"]
        partial0 --> SMEM_K
        partial1 --> SMEM_K
        SMEM_K --> O_K
    end
    subgraph fa2warp["FA2: Split-Q Warp Layout"]
        direction LR
        subgraph W0_Q["Warp 0"]
            qw0q["Q_i[0:B_r/4] (row slice)"]
            kvall0["K_j, V_j (full, in SRAM)"]
            out0["Full O_i[0:B_r/4]"]
        end
        subgraph W1_Q["Warp 1"]
            qw1q["Q_i[B_r/4:B_r/2]"]
            kvall1["K_j, V_j (same)"]
            out1["Full O_i[B_r/4:B_r/2]"]
        end
        O_Q["Final O_i\n(no sync needed!)"]
        kvall0 -.-> qw0q
        kvall1 -.-> qw1q
        out0 --> O_Q
        out1 --> O_Q
    end
    style SMEM_K fill:#ff9999
    style O_Q fill:#99ff99

Figure 5: FA1 split-K (top) vs FA2 split-Q (bottom). Split-K requires all warps to write partial results to shared memory and synchronize before summing. Split-Q assigns each warp complete rows — zero inter-warp communication. K/V are loaded once into shared memory and read by all warps.

Why split-Q works: Softmax is row-wise. If a warp owns complete rows of $Q_i$, it can run an independent online softmax for those rows by iterating over the K/V column blocks, with no information needed from other warps. The warp’s $\tilde{O}_\text{rows}$, $m_\text{rows}$, and $\ell_\text{rows}$ live in its private registers. Shared memory holds only the K and V tiles, which warps read but never write, so the only synchronization is the barrier after each tile load.

Block size tuning: FA2 manually tunes $B_r$ and $B_c$ for each head dimension (e.g., $d \in \{64, 128\}$) and each device. A rough shared-memory budget, counting elements of size $s$ bytes ($s = 2$ for FP16), is:

$$\big(B_r d + 2 B_c d + B_r B_c\big) \cdot s \leq M_\text{SRAM} = 192\ \text{KB per SM} \tag{23}$$

(This covers the $Q_i$ tile, the $K_j$ and $V_j$ tiles, and an $S_{ij}$ tile. Larger blocks also need more registers, and past a certain size register spilling or the SRAM limit causes a slowdown.) Typical values: $B_r \in \{64, 128\}$, $B_c \in \{64, 128\}$. The paper leaves auto-tuning of block sizes to future work.
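Eq. (23) can be evaluated for the candidate block sizes. This is a rough sketch. It uses the text’s 192 KB per-SM figure as the budget, although the usable shared memory per thread block is lower in practice, and it ignores double buffering and register pressure. $d = 256$ is included only to show where the budget starts to bind.

```python
BUDGET_BYTES = 192 * 1024  # per-SM SRAM figure used in Eq. (23)


def tile_bytes(br, bc, d, elem_bytes=2):
    """Rough Eq. (23) footprint: Q_i (br x d) + K_j, V_j (bc x d each) + S_ij (br x bc)."""
    return (br * d + 2 * bc * d + br * bc) * elem_bytes


for d in (64, 128, 256):
    for br in (64, 128):
        for bc in (64, 128):
            need = tile_bytes(br, bc, d)
            verdict = "fits" if need <= BUDGET_BYTES else "exceeds budget"
            print(f"d={d:3d} Br={br:3d} Bc={bc:3d}: {need // 1024:3d} KiB {verdict}")
```

Every $\{64, 128\}^2$ combination fits for $d \le 128$ (at most 128 KiB). At $d = 256$ the $128 \times 128$ configuration needs 224 KiB, so the block sizes must shrink as the head dimension grows.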

Causal Masking Optimization

Autoregressive attention requires that token $i$ only attends to positions $j \leq i$. In the block-tiled formulation, two cases arise:

Case 1 — Fully masked block: If all column indices in block $j$ exceed all row indices in row block $i$ (i.e., the entire block is strictly above the diagonal), then all entries become $-\infty$ after masking. After softmax, the entire block contributes zero to the output. FA2 skips the GEMM and HBM load for this block entirely.

For large $N$, approximately half the blocks are fully masked. Skipping them gives a ~1.7-1.8× speedup versus unmasked attention.

Case 2 — Boundary block: The diagonal block, where some entries are valid and some are masked. With square blocks ($B_r = B_c$), this is the only block per row block that requires explicit mask application. All blocks below the diagonal are fully unmasked.
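The three tile categories can be counted directly. This sketch classifies each $(i, j)$ tile by comparing the row and column index ranges. The non-square case shows why "one boundary block per row block" assumes $B_r = B_c$.

```python
def classify_blocks(n, br, bc):
    """Count causal tiles that are skipped, need explicit masking, or run unmasked."""
    skip = boundary = full = 0
    for i0 in range(0, n, br):
        row_lo, row_hi = i0, min(i0 + br, n) - 1
        for j0 in range(0, n, bc):
            col_lo, col_hi = j0, min(j0 + bc, n) - 1
            if col_lo > row_hi:        # every key comes after every query: fully masked
                skip += 1
            elif col_hi <= row_lo:     # every key is at or before every query: no mask
                full += 1
            else:                      # tile straddles the diagonal: apply the mask
                boundary += 1
    return skip, boundary, full


print(classify_blocks(8192, 128, 128))  # -> (2016, 64, 2016)
print(classify_blocks(8192, 128, 64))   # -> (4032, 128, 4032)
```

With square $128 \times 128$ tiles, 2016 of 4096 tiles (about 49%) are skipped, and there is exactly one boundary tile per row block. With $B_c = B_r / 2$ there are two.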

graph LR
    subgraph mask["N×N Causal Attention Matrix (tiled)"]
        direction TB
        ABOVE["Upper triangle blocks\n(j > i always masked)\n→ SKIP: no GEMM, no HBM load\n≈50% of all blocks"]
        DIAG["Diagonal blocks\n(j ≈ i, partial mask)\n→ Apply causal mask explicitly\n1 per row block"]
        BELOW["Lower triangle blocks\n(j < i, all unmasked)\n→ Full attention, no mask needed"]
    end
    ABOVE -->|"~50% skipped"| SAVE["~1.7-1.8× speedup\nvs non-causal attention"]
    style ABOVE fill:#cccccc
    style SAVE fill:#99ff99
    style DIAG fill:#ffff99
    style BELOW fill:#ccffcc

Figure 6: Causal masking in FA2’s tiled framework. Upper-triangular blocks are fully masked and skipped. Only diagonal (boundary) blocks require masking logic. Lower-triangular blocks run at full throughput.

Multi-Query and Grouped-Query Attention

In standard multi-head attention (MHA), every head has its own Q, K, V projections. For KV-cache-heavy inference, two variants reduce memory:

Multi-Query Attention (MQA, Shazeer 2019): All query heads share a single K and V head. The K/V cache shrinks by a factor of $H$ (the number of heads).

Grouped-Query Attention (GQA, Ainslie et al. 2023): Query heads are organized into $G$ groups; each group shares one K/V head. $G$ interpolates between MQA ($G = 1$) and MHA ($G = H$).

FA2 supports both by manipulating tensor index strides rather than duplicating K/V data. For the $i$-th query head in group $g$, FA2 indexes into K at position $g$ (not $i$), effectively broadcasting the shared K head to all Q heads in the group, with no memory duplication.

In the backward pass, gradient contributions from all query heads in a group must be reduced into $dK$ and $dV$ for the single shared K/V head. FA2 applies this reduction via an extra summation after the backward kernel.
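The index mapping and the gradient reduction can be sketched in a few lines. Assumption: query heads are grouped contiguously (heads $0 \ldots H/G - 1$ form group 0, and so on). The helper names are hypothetical, not FA2’s API. Each per-head gradient is a flat list standing in for a $dK$ tensor.

```python
def kv_head_for(q_head, num_q_heads, num_kv_heads):
    """Which shared K/V head a query head reads (contiguous grouping assumed)."""
    assert num_q_heads % num_kv_heads == 0
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size


def reduce_kv_grads(per_q_head_grads, num_kv_heads):
    """Sum per-query-head dK (or dV) into one gradient per shared K/V head."""
    num_q_heads = len(per_q_head_grads)
    out = [[0.0] * len(per_q_head_grads[0]) for _ in range(num_kv_heads)]
    for h, g in enumerate(per_q_head_grads):
        kv = kv_head_for(h, num_q_heads, num_kv_heads)
        out[kv] = [a + b for a, b in zip(out[kv], g)]
    return out


print([kv_head_for(h, 8, 2) for h in range(8)])  # GQA, G=2 -> [0, 0, 0, 0, 1, 1, 1, 1]
print([kv_head_for(h, 8, 1) for h in range(8)])  # MQA -> all zeros
per_head_dk = [[float(h)] * 2 for h in range(8)]
print(reduce_kv_grads(per_head_dk, 2))           # -> [[6.0, 6.0], [22.0, 22.0]]
```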

Empirical Results

4.1 Attention Kernel Throughput

All benchmarks on A100 80GB SXM4 GPU. Sequence length varies 512-16k; batch size fixed so total tokens = 16k; hidden dim = 2048.

Forward + backward speed, no causal mask, head dim 64 (A100):

Method / seq. length     512    1k    2k    4k    8k   16k
PyTorch (standard)        68    91    98    95    83   OOM
FlashAttention-1          91    90   104    98    92    76
FA Triton                102    92   110   110   108   100
FlashAttention-2         132   162   171   176   175   173

Units: TFLOPs/s. Peak A100 FP16: 312 TFLOPs/s.

Forward + backward speed, causal mask, head dim 128:

Method / seq. length     512    1k    2k    4k    8k   16k
FlashAttention-1          35    49    59    65    68    71
FlashAttention-2         127   148   192   193   196   203

FA2 with causal masking reaches 203 TFLOPs/s at 16k sequence length = 65% of A100 peak. The forward-pass-only peak reaches 224-227 TFLOPs/s = 72-73% of peak.

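For reference, optimized GEMM (cuBLAS) reaches ~80-90% of peak. FA2’s forward pass (72-73%) is therefore within roughly 7-18 percentage points of pure matmul efficiency, and the combined forward and backward pass (65%) within roughly 15-25.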

H100 results (same code, no H100-specific features):

Without any H100-specific optimizations, FA2 already reaches 294-338 TFLOPs/s (H100 SXM5 dense FP16 peak = 989 TFLOPs/s). This speedup over A100 comes “for free” from the H100’s higher clocks, larger SM count, and faster Tensor Cores. With H100-specific features (TMA, 4th-generation Tensor Cores, FP8), the authors expect another 1.5-2× gain. By simple multiplication, that would put FA2 at roughly 450-680 TFLOPs/s on H100.

4.2 End-to-End Training Throughput

GPT-style models on 8× A100 80GB SXM4 GPUs:

Model        Seq   No FlashAttn   FA1   FA2   Speedup (FA2 vs baseline)
GPT3-1.3B    2k    142            189   196   1.38×
GPT3-1.3B    8k     72            170   220   3.06×
GPT3-2.7B    2k    149            189   205   1.38×
GPT3-2.7B    8k     80            175   225   2.81×

Units: TFLOPs/s per GPU.

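The 8k rows show the largest gains, for two reasons. Against the no-FlashAttention baseline, standard attention’s $O(N^2)$ memory traffic dominates at long context. Against FA1, the gap widens from 196 vs 189 and 205 vs 189 at 2k to 220 vs 170 and 225 vs 175 at 8k. At long sequences the batch size shrinks (memory constraints), reducing batch × heads parallelism, which is precisely the scenario where FA2’s sequence-length parallelism and higher occupancy make the biggest difference.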

At 225 TFLOPs/s for GPT3-2.7B with 8k context:

$$\text{MFU} = \frac{225\ \text{TFLOPs/s}}{312\ \text{TFLOPs/s}} \approx 72\% \quad \text{(model FLOPs utilization)} \tag{24}$$

This is exceptionally high for a real training workload — close to the hardware ceiling when all overheads (data loading, optimizer steps, gradient communication) are factored in.
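The derived columns of the training table can be recomputed from its raw numbers. The sketch below adds an FA2-vs-FA1 ratio and MFU for every row, using the same 312 TFLOPs/s A100 peak. The values are copied from the table above; nothing new is measured.

```python
A100_PEAK_TFLOPS = 312.0

# (model, seq): (no FlashAttention, FA1, FA2), TFLOPs/s per GPU, from the table above
table = {
    ("GPT3-1.3B", "2k"): (142, 189, 196),
    ("GPT3-1.3B", "8k"): (72, 170, 220),
    ("GPT3-2.7B", "2k"): (149, 189, 205),
    ("GPT3-2.7B", "8k"): (80, 175, 225),
}

for (model, seq), (base, fa1, fa2) in table.items():
    print(f"{model} {seq}: FA2/baseline {fa2 / base:.2f}x, "
          f"FA2/FA1 {fa2 / fa1:.2f}x, MFU {fa2 / A100_PEAK_TFLOPS:.0%}")
```

The FA2/FA1 column (1.04×, 1.29×, 1.08×, 1.29×) isolates FA2’s own contribution. It is much smaller than the headline speedup over the baseline, and largest at 8k.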

Design Analysis: Why / Alternative / Boundary Conditions

Why Defer Normalization Rather Than Removing It Entirely?

The question: Could we avoid all per-step normalization by processing all blocks, accumulating raw unnormalized sums, and normalizing once at the very end?

The answer: No, because of numerical overflow. Without the online max subtraction, $e^{S_{ij}}$ overflows FP16/FP32 for large $S_{ij}$. The online max trick ($e^{S_{ij} - m} \leq 1$ always) is essential for numerical stability. FA2 keeps the per-step max correction ($e^{m^{(j-1)} - m^{(j)}}$, which is $\leq 1$ by construction) but removes the per-step sum normalization ($\text{diag}(\ell)^{-1}$). The max correction still costs one row-wise rescaling of $\tilde{O}$ per step. What FA2 eliminates is the second rescaling, the sum normalization, along with its reciprocals.
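A minimal demonstration of why the max subtraction cannot be dropped. Python floats are FP64, so overflow needs logits above about 709.8; FP32 already overflows above about 88.7 and FP16 above about 11.1. The logits below are chosen large enough to overflow even FP64.

```python
import math

scores = [1000.0, 999.0, 998.0]  # large logits, e.g. from unscaled dot products


def naive_softmax(row):
    e = [math.exp(x) for x in row]      # math.exp(1000.0) raises OverflowError
    z = sum(e)
    return [x / z for x in e]


def stable_softmax(row):
    m = max(row)
    e = [math.exp(x - m) for x in row]  # every exponent is <= 0, so every term is <= 1
    z = sum(e)
    return [x / z for x in e]


try:
    naive_softmax(scores)
    overflowed = False
except OverflowError:
    overflowed = True

print(overflowed)              # -> True
print(stable_softmax(scores))  # finite probabilities that sum to 1
```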

Boundary: The max correction step $m \leftarrow \max\big(m, \text{rowmax}(S^{(j)})\big)$ still executes at every inner loop iteration. For numerically flat inputs (all attention logits near zero), the correction factor $e^{m^{(j-1)} - m^{(j)}} \approx 1$ and could be skipped, but this requires a runtime branch that might hurt performance through branch divergence.

Why Split-Q Instead of Splitting Along Another Dimension?

The question: What if you split along the head dimension at the warp level instead of the row dimension?

The answer: Splitting the head dimension $d$ across warps turns each warp’s $Q_i K_j^\top$ into a partial sum over a slice of $d$. Those partial $S$ tiles would have to be reduced through shared memory before softmax could even start, which reintroduces exactly the communication split-Q avoids. It would also make each warp’s GEMM slice too small for efficient Tensor Core utilization.

The only meaningful within-block splits are rows (Q dimension) or columns (K/V dimension) of the attention tile. Split-K (column split) is natural for standard GEMM but fails for attention because of row-wise softmax coupling: warps need to communicate to aggregate $m$ and $\ell$ statistics across columns. Split-Q (row split) gives each warp independent rows for which softmax is self-contained.

Alternative considered in the paper: An “inter-block split” for GQA/MQA where one thread block handles multiple Q heads sharing the same K/V. This adds another parallelism axis and could improve occupancy for models where H is very small. FA2 leaves this to future work.

Why Isn’t FA2 Used for All Sequence Lengths Uniformly?

Observation: The throughput numbers show FA2's advantage is larger at longer sequences. At short sequences (512 tokens), FA2 is still faster than FA1 (132 vs 91 TFLOPs/s, about 1.45×), but its absolute throughput stays well below peak. At short sequences, kernel launch, block scheduling, and SRAM initialization make up a larger fraction of total time.

Boundary condition: For sequences shorter than ~256 tokens, the standard attention implementation (PyTorch's F.scaled_dot_product_attention with the cuDNN backend) may match or beat FA2. At that size the $O(N^2)$ matrices are small enough to fit in L2 cache, which removes the HBM bandwidth bottleneck that FA2 is designed to address.

Relationship to FA1, FA3, and xformers

FA1 to FA2

Both use identical mathematical algorithms for tiled online softmax attention. FA2 changes only the implementation:

  • Same: tiling structure, online softmax recurrence, HBM IO complexity $O(N^2 d^2 / M)$
  • Different: deferred $1/\ell$ normalization; sequence-length parallelism; split-Q warp layout

FA1 and FA2 produce identical numerical outputs (up to floating-point rounding).
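The claim that deferring normalization leaves the output unchanged can be checked on a single row. This is a minimal sketch under simplifying assumptions: one query row, FP64 Python floats, and an arbitrary block size. `fa1_style_row` rescales the accumulator by $1/\ell$ after every column block. `fa2_style_row` keeps it unnormalized and divides once at the end. Both still apply the max correction $e^{m^{(j-1)} - m^{(j)}}$ at each step.

```python
import math

def attention_row_reference(scores, values):
    """Plain softmax(scores) @ values for one query row."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    l = sum(w)
    d = len(values[0])
    return [sum(wi * v[k] for wi, v in zip(w, values)) / l for k in range(d)]

def blocks(seq, size):
    return [seq[i:i + size] for i in range(0, len(seq), size)]

def fa1_style_row(scores, values, block):
    """Keeps O normalized after every block: applies 1/l at each inner step."""
    d = len(values[0])
    m, l, o = -math.inf, 0.0, [0.0] * d
    for s_blk, v_blk in zip(blocks(scores, block), blocks(values, block)):
        m_new = max(m, max(s_blk))
        p = [math.exp(s - m_new) for s in s_blk]
        l_new = math.exp(m - m_new) * l + sum(p)
        pv = [sum(pi * v[k] for pi, v in zip(p, v_blk)) for k in range(d)]
        # Per-step normalization: undo the old 1/l, rescale by the max correction, apply the new 1/l.
        o = [(l * math.exp(m - m_new) * o[k] + pv[k]) / l_new for k in range(d)]
        m, l = m_new, l_new
    return o

def fa2_style_row(scores, values, block):
    """Keeps O unnormalized; divides by l exactly once after the last block."""
    d = len(values[0])
    m, l, o = -math.inf, 0.0, [0.0] * d
    for s_blk, v_blk in zip(blocks(scores, block), blocks(values, block)):
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)  # max correction, still applied every step
        p = [math.exp(s - m_new) for s in s_blk]
        l = scale * l + sum(p)
        o = [scale * o[k] + sum(pi * v[k] for pi, v in zip(p, v_blk)) for k in range(d)]
        m = m_new
    return [x / l for x in o]  # deferred normalization

scores = [0.3, -1.2, 2.5, 0.0, 1.1, -0.4]
values = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0], [2.0, 2.0], [-1.0, 0.0], [0.5, 0.5]]
ref = attention_row_reference(scores, values)
out_fa1 = fa1_style_row(scores, values, block=2)
out_fa2 = fa2_style_row(scores, values, block=2)
print(ref, out_fa1, out_fa2)
```

Both variants agree with the reference up to rounding. The FA2 variant simply drops the per-step division by $\ell$ from the inner loop.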

FA2 to FA3 (Shah et al., NeurIPS 2024)

FA3 targets Hopper (H100/H200) GPUs, adding:

  • TMA (Tensor Memory Accelerator): Hardware-assisted async tile loads, overlapping memory and compute
  • Warp specialization: Some warps dedicated to loading (producers), others to computing (consumers)
  • FP8 support: 2× throughput vs FP16
  • WGMMA instructions: 4th-generation Tensor Core instructions for better utilization

FA3 is 1.5-2.0× faster than FA2 on H100. It reaches up to ~740 TFLOPs/s in FP16 (~75% of peak) and close to 1.2 PFLOPs/s in FP8.

FA2 vs xformers

Facebook's xformers library (CUTLASS-based) was competitive with FA1 but lags behind FA2 by ~10-20% in the benchmarks. A plausible contributor to the gap is that xformers' kernel did not adopt the split-Q warp layout.

Arithmetic Intensity and the Roofline Model

A principled way to understand where attention sits on the GPU's performance landscape is the roofline model. For a GPU with compute throughput $\pi$ (FLOPs/s) and memory bandwidth $\beta$ (bytes/s), the maximum achievable performance for an operation is:

$$P = \min\!\left(\pi,\; \beta \cdot I\right) \tag{25}$$

where $I$ is the arithmetic intensity (FLOPs per byte of memory transferred). The roofline has two regimes:

  • Memory-bound ($I < \pi/\beta$): performance limited by bandwidth $\beta$
  • Compute-bound ($I > \pi/\beta$): performance limited by peak FLOPs $\pi$

For A100: $\pi = 312$ TF/s (FP16 matmul) and $\beta = 2$ TB/s, so the ridge point is:

$$I_\text{ridge} = \frac{312 \text{ TF/s}}{2 \text{ TB/s}} = 156 \text{ FLOPs/byte} \tag{26}$$
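Equations (25) and (26) translate directly into code. Here is a minimal sketch that uses the A100 figures quoted above as constants:

```python
def roofline(peak_flops, bandwidth, intensity):
    """Attainable FLOP/s for an op with the given arithmetic intensity, Eq. (25)."""
    return min(peak_flops, bandwidth * intensity)

PI = 312e12   # A100 FP16 Tensor Core peak, FLOP/s (value quoted in the text)
BETA = 2e12   # A100 HBM bandwidth, bytes/s (value quoted in the text)

ridge = PI / BETA  # Eq. (26)
print(ridge)                            # 156.0
print(roofline(PI, BETA, 64) / 1e12)    # 128.0, memory-bound
print(roofline(PI, BETA, 4096) / 1e12)  # 312.0, compute-bound
```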

Standard attention arithmetic intensity:

  • FLOPs: $4N^2 d$ (two GEMMs, $QK^\top$ and $PV$, at $2N^2 d$ each)
  • Memory: $O(N^2)$ traffic from writing and reading the $S$ and $P$ matrices in HBM. That is about $4N^2$ element accesses, or $8N^2$ bytes in FP16.
  • Intensity: $4N^2 d / (8N^2) = d/2$ FLOPs/byte, i.e. 64 FLOPs/byte for $d = 128$

At $d=128$, standard attention has intensity ~64 FLOPs/byte, below the ridge point of 156. Standard attention is memory-bound.

FlashAttention arithmetic intensity (FA1/FA2):

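  • FLOPs: the same $4N^2 d$
  • Memory: no $N \times N$ intermediates in HBM. In the idealized count, Q, K, and V are read once and O is written once: $4Nd$ elements, or $8Nd$ bytes in FP16. The tiled kernels actually re-read tiles once per outer-loop block, which is where the $O(N^2 d^2 / M)$ bound comes from.
  • Idealized intensity: $\frac{4N^2 d}{8Nd} = N/2$ FLOPs/byte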

For $N=8192$ this idealized intensity is 4096 FLOPs/byte, far above the ridge point. FA1/FA2 are compute-bound, not memory-bound, which is why they can approach peak Tensor Core throughput.
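The two intensity estimates can be reproduced with a few lines of arithmetic. This sketch follows the byte counts in the bullets above and is deliberately idealized. It ignores the Q/K/V reads of standard attention and the per-block tile re-reads of the tiled kernels.

```python
BYTES_FP16 = 2
RIDGE = 312e12 / 2e12  # 156 FLOPs/byte, Eq. (26)

def standard_attention_intensity(n, d):
    flops = 4 * n * n * d                  # QK^T and PV, 2*N^2*d each
    bytes_moved = 4 * n * n * BYTES_FP16   # write S, read S, write P, read P
    return flops / bytes_moved             # = d / 2

def flash_attention_intensity_ideal(n, d):
    flops = 4 * n * n * d
    bytes_moved = 4 * n * d * BYTES_FP16   # read Q, K, V once, write O once
    return flops / bytes_moved             # = N / 2

for n in (2048, 8192):
    std = standard_attention_intensity(n, d=128)
    fa = flash_attention_intensity_ideal(n, d=128)
    print(n, std, std < RIDGE, fa, fa > RIDGE)
# 2048 64.0 True 1024.0 True
# 8192 64.0 True 4096.0 True
```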

graph LR
    subgraph roofline["Roofline Model (A100 FP16)"]
        direction TB
        MEM["Memory-bound region\n(slope = bandwidth 2 TB/s)\nStandard attention lives here\n(intensity ~64 FLOPs/byte)"]
        RIDGE["Ridge point\n(156 FLOPs/byte)\nTransition point"]
        COMPUTE["Compute-bound region\n(ceiling = 312 TF/s)\nFA1/FA2 live here\n(idealized intensity = N/2 ≫ 156)"]
    end
    MEM -->|"FA1/FA2 shift\nattention here!"| COMPUTE
    style MEM fill:#ff9999
    style COMPUTE fill:#99ff99
    style RIDGE fill:#ffff99

Figure 7: Roofline model for A100. Standard attention is memory-bound (intensity < 156 FLOPs/byte). FA1/FA2 eliminate the N×N HBM traffic, making attention compute-bound. The remaining gap from 312 TF/s peak is due to non-matmul FLOPs (FA2’s fix) and warp inefficiency (FA2’s split-Q fix).

The residual gap between FA2 and the 312 TF/s theoretical peak is no longer memory bandwidth. It is the ~27-50% lost to non-matmul operations (exp, max, reductions), occupancy, and scheduling inefficiency. FA2 roughly doubles throughput relative to FA1 by attacking exactly these: the non-matmul FLOPs, SM occupancy via sequence-length parallelism, and the warp layout.

Practical Implications for LLM Training and Inference

Context Length Scaling

The most direct practical impact: FlashAttention-2 makes long-context training tractable.

Before FA2, training a GPT-style model at 8k context cost ~3× more compute per token than at 2k context, due to the $O(N^2)$ growth in attention FLOPs. With FA2's sequence-length parallelism fully utilizing SM capacity:

| Context length | Baseline GPU time (FLOPs) | FA2 real throughput | FA2 effective cost |
|---|---|---|---|
| 2k | | ~196 TF/s | |
| 8k | 16× (attention only) | ~220 TF/s (+12%) | ~1.1× per-token cost |
| 16k | 64× (attention only) | ~224 TF/s | similar |

The reason FA2 scales better than naively expected: at longer sequences the batch size per GPU must shrink because of memory pressure, and FA2's sequence-length parallelism fills the SMs that would otherwise sit idle. The kernel's throughput (TFLOPs/s) therefore does not degrade as $N$ grows. Attention time tracks the $N^2$ FLOP count instead of growing faster than it.
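A rough way to see the SM-vacancy argument is to count thread blocks. The sketch below assumes a hypothetical fixed budget of 16,384 tokens per GPU, 32 heads, and $B_r = 64$. FA1 launches one thread block per (batch, head) pair, while FA2 launches one per (batch, head, row block). Thread-block count is only a proxy for occupancy, since several blocks can share an SM.

```python
import math

NUM_SMS = 108  # A100

def grid_sizes(batch, heads, seqlen, block_rows):
    """Thread-block counts: FA1 parallelizes over (batch, head);
    FA2 additionally over the T_r = ceil(N / B_r) row blocks."""
    fa1_blocks = batch * heads
    t_r = math.ceil(seqlen / block_rows)
    fa2_blocks = batch * heads * t_r
    return fa1_blocks, fa2_blocks

TOKENS_PER_GPU = 16384  # hypothetical fixed token budget per GPU
HEADS = 32              # hypothetical head count
B_R = 64                # hypothetical row-block size

for seqlen in (2048, 8192, 16384):
    batch = TOKENS_PER_GPU // seqlen
    fa1, fa2 = grid_sizes(batch, HEADS, seqlen, B_R)
    print(seqlen, batch, fa1, fa1 >= NUM_SMS, fa2)
# 2048 8 256 True 8192
# 8192 2 64 False 8192
# 16384 1 32 False 8192
```

Under a fixed token budget, $B \times T_r$ stays constant, so FA2's grid size does not shrink as the batch does.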

Inference and KV Cache Impact

For autoregressive inference, the “decode” step attends one new token's query against the full KV cache of length $N$:

  • Query shape: $1 \times d$ (single new token)
  • KV cache: $N \times d$

With a single query row, the workload changes character. It is matrix-vector rather than matrix-matrix multiplication, so FA2's tiling and Tensor Core utilization matter less here. For decode-phase inference:

  • FA2 helps in the prefill phase (processing the prompt, with $N$ tokens in the query)
  • For decode phase, the bottleneck shifts to KV cache bandwidth — KV cache compression techniques (MQA, GQA, quantized KV) are more important
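The decode-phase bottleneck follows from arithmetic intensity. Here is a minimal sketch for one head, assuming an FP16 KV cache and ignoring the (negligible) read of the single query vector and the write of the output:

```python
BYTES_FP16 = 2

def decode_step_intensity(n_ctx, d):
    """One query vector against an N-entry KV cache (one head)."""
    flops = 4 * n_ctx * d                     # q K^T (2*N*d) plus p V (2*N*d)
    bytes_moved = 2 * n_ctx * d * BYTES_FP16  # read the K and V caches once
    return flops / bytes_moved

def prefill_intensity_ideal(n, d):
    """N query rows against N keys, idealized as in the roofline section."""
    return (4 * n * n * d) / (4 * n * d * BYTES_FP16)

print(decode_step_intensity(8192, 128))    # 1.0
print(prefill_intensity_ideal(8192, 128))  # 4096.0
```

At about 1 FLOP/byte, the decode step sits far left of the 156 FLOPs/byte ridge point regardless of $N$. That is why reducing the bytes read matters more there than kernel FLOP efficiency.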

FA2 directly supports MQA and GQA (Section 6), enabling the smaller KV caches these variants provide while maintaining full-speed attention computation.
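Under MQA/GQA, several query heads read the same K/V head. The kernel therefore only needs to map each query head to its K/V head rather than materializing repeated K/V. The sketch below assumes the common convention that consecutive query heads form a group. `kv_head_for_query_head` and `kv_cache_bytes` are hypothetical helpers, not flash-attn functions.

```python
def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    """Hypothetical helper: index of the K/V head read by a query head.
    Assumes consecutive query heads share one K/V head (group size = Hq / Hkv)."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size

def kv_cache_bytes(n_ctx, num_kv_heads, d, bytes_per_elem=2):
    """Hypothetical helper: K plus V cache size for one layer and one sequence."""
    return 2 * n_ctx * num_kv_heads * d * bytes_per_elem

# 32 query heads: MHA (32 K/V heads), GQA (8), MQA (1)
for h_kv in (32, 8, 1):
    mapping = [kv_head_for_query_head(h, 32, h_kv) for h in range(8)]
    mib = kv_cache_bytes(8192, h_kv, 128) / 2**20
    print(h_kv, mapping, mib)
# 32 [0, 1, 2, 3, 4, 5, 6, 7] 128.0
# 8 [0, 0, 0, 0, 1, 1, 1, 1] 32.0
# 1 [0, 0, 0, 0, 0, 0, 0, 0] 4.0
```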

Integration with PyTorch

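As of PyTorch 2.2, torch.nn.functional.scaled_dot_product_attention can automatically dispatch to FA2 kernels (PyTorch 2.0 and 2.1 shipped FlashAttention-1 kernels). Among other conditions, this requires:

  • Running on CUDA (Ampere+)
  • Input is FP16 or BF16
  • No arbitrary attn_mask tensor is passed (causal masking via is_causal=True is supported)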

This makes FA2 the default fused attention path for much of modern LLM training in PyTorch, without code changes. Users who need the latest optimizations (FA3 on H100, custom MQA layouts, FP8) should use the flash-attn package directly.

Implementation Details and Block Size Tuning

Choosing Block Sizes $B_r$ and $B_c$

The block sizes $B_r$ (Q rows per tile) and $B_c$ (K/V columns per tile) are critical hyperparameters that must be tuned per GPU and head dimension. The constraints are:

  1. SRAM capacity: all four tiles must fit simultaneously:
$$B_r \cdot d \cdot 2 + 2 \cdot B_c \cdot d \cdot 2 + B_r \cdot B_c \cdot 2 \leq M_\text{SRAM} \tag{27}$$

(The factor 2 is bytes per FP16 element. The terms are the Q tile, the K and V tiles, and the S tile.)

  2. Tensor Core alignment: Both dimensions of each matmul tile should be multiples of 16 (or 8 for certain precisions).

  3. Register pressure: Larger $B_r$ means more rows of $\tilde{O}$, $m$, and $\ell$ to keep in registers. If registers spill to local memory (which lives in device memory and is only cached in L1/L2), performance degrades severely.

  4. Occupancy vs tile size tradeoff: Larger blocks give higher arithmetic intensity within the block (fewer HBM accesses per FLOP) but a larger SRAM footprint, so fewer thread blocks can reside on the SM simultaneously (lower occupancy).

For A100 with $d=128$ and 192 KB SRAM per SM, typical optimal values are $B_r = B_c = 64$ or $B_r = 64, B_c = 128$. FA2 uses manual lookup tables indexed by (head_dim, GPU architecture). FA3 and the Triton implementation add auto-tuning.
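Equation (27) can be used to screen candidate tile shapes before benchmarking. This sketch only checks the SRAM inequality. It ignores register-resident state ($\tilde{O}$, $m$, $\ell$) and any double buffering. It also ignores the fact that somewhat less than the nominal 192 KB is usable as shared memory, so treat its output as an upper bound on what is feasible.

```python
def tile_footprint_bytes(b_r, b_c, d, bytes_per_elem=2):
    """Left-hand side of Eq. (27): Q tile + K tile + V tile + S tile, FP16."""
    return (b_r * d + 2 * b_c * d + b_r * b_c) * bytes_per_elem

def feasible_tiles(d, sram_bytes, candidates=(16, 32, 64, 128, 256)):
    """All (B_r, B_c) pairs of Tensor-Core-friendly multiples of 16 that satisfy Eq. (27)."""
    return [(b_r, b_c) for b_r in candidates for b_c in candidates
            if tile_footprint_bytes(b_r, b_c, d) <= sram_bytes]

SRAM_PER_SM = 192 * 1024  # nominal A100 figure used in the text; usable shared memory is lower

print(tile_footprint_bytes(64, 64, 128) // 1024)       # 56 (KiB)
print(tile_footprint_bytes(64, 128, 128) // 1024)      # 96 (KiB)
print((256, 256) in feasible_tiles(128, SRAM_PER_SM))  # False
```

Both configurations named above pass the check. The 96 KiB one illustrates the occupancy tradeoff: assuming usable shared memory below the nominal 192 KB, two such blocks cannot share one SM.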

Logsumexp and Numerical Stability

An important subtlety: the logsumexp $L = m + \log(\ell)$ stored by FA2 is computed as:

$$L_i = m_i^{(T_c)} + \log\left(\ell_i^{(T_c)}\right) \tag{28}$$

where $m_i^{(T_c)}$ is the final running max for row $i$, and $\ell_i^{(T_c)}$ is the final running sum of exponentials. This is numerically stable because:

  • $\ell_i = \sum_j e^{S_{ij} - m_i}$ is a sum of at most $N$ terms, each in $(0, 1]$. The term at the row maximum equals exactly 1, so $\ell_i \in [1, N]$.
  • $\log(\ell_i) \in [0, \log N]$: bounded, no overflow
  • $L_i = m_i + \log\ell_i$ is the standard log-sum-exp, equal to $\log \sum_j e^{S_{ij}}$ up to floating-point rounding

In the backward pass, recomputing $P_{ij} = e^{S_{ij} - L_i}$ never overflows, because $S_{ij} - L_i = (S_{ij} - m_i) - \log\ell_i \leq 0$. The first term is $\leq 0$ since $m_i = \max_k S_{ik}$, and $\log\ell_i \geq 0$ since $\ell_i \geq 1$.
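The stability argument can be checked numerically. This sketch uses one row, FP64 Python floats, and logits around 800, so a naive $e^{S_{ij}}$ would overflow even in FP64. The forward pass streams over column blocks to produce $L_i$. The backward-style recomputation $e^{S_{ij} - L_i}$ then yields probabilities that are all $\leq 1$ and sum to 1.

```python
import math

def forward_row_lse(scores, block):
    """Stream over column blocks, returning L = m + log(l) for one query row, Eq. (28)."""
    m, l = -math.inf, 0.0
    for i in range(0, len(scores), block):
        blk = scores[i:i + block]
        m_new = max(m, max(blk))
        l = math.exp(m - m_new) * l + sum(math.exp(s - m_new) for s in blk)
        m = m_new
    return m + math.log(l)

def recompute_probs(scores, lse):
    """Backward-pass style recomputation: P_ij = exp(S_ij - L_i); every exponent is <= 0."""
    return [math.exp(s - lse) for s in scores]

scores = [800.0, 799.5, 795.0, 801.0, 790.0]  # math.exp(801.0) alone would overflow
L = forward_row_lse(scores, block=2)
P = recompute_probs(scores, L)
print(L, P, sum(P))
```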

FA2 Algorithm vs FA1: Complete Comparison

| Property | FlashAttention-1 | FlashAttention-2 |
|---|---|---|
| HBM complexity | $O(N^2 d^2 / M)$ | Same (unchanged) |
| Per-step $\ell$ normalization | Yes, at every inner step | No, deferred to the end of the inner loop |
| Backward statistics | Store $(m, \ell)$: 2 scalars/row | Store $L = m + \log\ell$: 1 scalar/row |
| Sequence-length parallelism | No, outer loop sequential | Yes, each row block is a thread block |
| Warp layout (forward) | Split-K: warps own K/V column slices | Split-Q: warps own Q row slices |
| Inter-warp communication (forward) | Required: shared-memory write/sync | None needed |
| Backward dQ update | Sequential within thread block | Atomic adds across thread blocks |
| MQA/GQA support | No (external handling) | Yes, native index stride manipulation |
| Causal mask optimization | Skip fully-masked blocks | Same (unchanged) |
| Throughput on A100 (fwd+bwd) | 25-40% of peak | 50-73% of peak |
| End-to-end GPT-2.7B, 8k context | 175 TF/s | 225 TF/s |

The key insight from this comparison: FA2 makes no changes to the outer mathematical structure — same tiling, same online softmax recurrence, same HBM complexity. Every improvement is in the implementation mapping from algorithm to GPU hardware.

Summary

FlashAttention-2 is a case study in precise systems engineering. Tri Dao identified three specific bottlenecks in FA1 — each with a measurable cost — and fixed each independently with minimal code complexity:

  1. Deferred $\ell$ normalization: removes $T_c - 1$ extra non-matmul diagonal rescalings per row block
  2. Sequence-length parallelism: fills idle SMs when $B \times H < 108$ by multiplying the number of thread blocks by $T_r = N / B_r$ (e.g. 128 for $N = 8192$, $B_r = 64$)
  3. Split-Q warp layout: eliminates inter-warp shared-memory writes and sync barriers in the forward pass

The combined effect: 2× speedup over FA1, 50-73% of A100 peak, 225 TFLOPs/s for GPT-2.7B at 8k context. The implication: training 16k-context models costs the same as 8k-context training used to.

For practitioners: FA2 is available via pip install flash-attn, and PyTorch's torch.nn.functional.scaled_dot_product_attention dispatches to it on supported CUDA GPUs. For H100 GPUs, FA3 is worth switching to. For CPU or other non-CUDA backends, standard (unfused) attention is the fallback.

The mental model: GPU attention performance depends on three separate dimensions — IO traffic to HBM (fixed by FA1), compute utilization / occupancy (fixed by FA2’s parallelism), and warp-level efficiency (fixed by FA2’s split-Q). FA1 solved dimension one; FA2 solves dimensions two and three.

Reproducibility Notes

The paper reports benchmark numbers on A100 80GB SXM4, with CUDA 11.8, cuDNN 8.7, PyTorch 2.0. Key implementation variables:

  • Block sizes used: head dim 64 → $B_r = B_c = 64$; head dim 128 → $B_r = 64, B_c = 64$ (some configs use 128)
  • Reference implementation: github.com/Dao-AILab/flash-attention (Apache 2.0 license)
  • PyTorch integration: available via pip install flash-attn==2.x.x. It is also integrated into PyTorch as the flash backend of scaled_dot_product_attention (queried with torch.backends.cuda.flash_sdp_enabled(), toggled with torch.backends.cuda.enable_flash_sdp()).
  • FLOP counting convention: the paper counts attention FLOPs as $4N^2 d$ per head (forward) and $10N^2 d$ (backward, 2.5× forward). Note: causal attention divides forward FLOPs by 2 in the per-second efficiency calculations but NOT in the end-to-end model FLOPs (follows Megatron-LM convention for consistency).
  • MFU calculation: $\text{MFU} = \frac{\text{measured throughput (TFLOPs/s)}}{\text{peak TFLOPs/s of GPU}}$; for GPT-2.7B at 8k context, FA2 reports 72% MFU on A100 (225/312)
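The counting convention and the MFU formula are easy to encode. Here is a minimal sketch, with function names of my own choosing and the A100 peak of 312 TFLOPs/s as the default:

```python
def attention_flops(batch, seqlen, heads, head_dim, causal=False, backward=False):
    """FLOP convention described above: 4 * N^2 * d per head for the forward pass,
    2.5x that for the backward pass, halved when a causal mask is used."""
    flops = 4 * batch * seqlen ** 2 * heads * head_dim
    if causal:
        flops /= 2
    if backward:
        flops *= 2.5
    return flops

def tflops_per_s(flops, seconds):
    """Measured throughput from a FLOP count and a wall-clock time."""
    return flops / seconds / 1e12

def mfu(measured_tflops, peak_tflops=312.0):
    """Model FLOPs utilization: measured throughput over hardware peak."""
    return measured_tflops / peak_tflops

print(round(mfu(225.0), 3))  # 0.721
```

For a full training step, the kernel-level total is forward plus backward, i.e. 3.5× the forward count.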

The key result, a ~2× speedup over FA1, is robust across head dimensions, sequence lengths, and batch sizes. The absolute throughput numbers (TFLOPs/s) are sensitive to the specific A100 variant (SXM4 vs PCIe), the driver version, and whether the GPU is thermally throttled. Expect ±5-10% variation in practice.