笔记日期: 2026-05-24 笔记作者: Zhongzhu Zhou 论文标题: FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning 作者: Tri Dao(普林斯顿大学 / 斯坦福大学) arXiv: 2307.08691 状态 / Venue: arXiv 预印本(2023 年 7 月);ICLR 2024
一句话总结
FlashAttention-2 不是一个新算法,而是对 FlashAttention-1 实现层面的三处精准优化:减少非矩阵乘操作数、增加序列维度并行、改变 warp 内工作划分方式,把 A100 上的注意力吞吐量从 FA1 的 25-40% 峰值提升到 50-73% 峰值,实现了约 2× 加速,并使 GPT-2.7B 的 8k 上下文训练速度达到 225 TFLOPs/s(72% MFU)。
前置知识
1. GPU 内存层次
理解 FlashAttention 系列的起点是 GPU 的两级内存结构:
高带宽内存(HBM): GPU 的主 DRAM。A100-80GB 有 80 GB 容量,带宽约 2 TB/s。所有 PyTorch 张量默认存放在这里。
片上 SRAM(共享内存): 每个流式多处理器(SM)有独立的 SRAM。A100 每个 SM 有 192 KB,共 108 个 SM,总计约 20 MB SRAM。SRAM 带宽约 19 TB/s——大约是 HBM 的 10 倍。
代价是 SRAM 极小(每 SM 只有 192 KB),且是易失存储。
FlashAttention 系列的核心设计原则: 最小化对 HBM 的读写次数,尽量在 SRAM 内完成计算。每一次不必要的 HBM 访问都比 SRAM 贵 10 倍。
graph TD
CPU["主机 CPU"]
subgraph A100["A100 GPU(108个SM)"]
HBM["HBM:80 GB\n带宽约 2 TB/s\n(PyTorch 张量住这里)"]
subgraph SM["单个 SM(×108)"]
SRAM["片上 SRAM:192 KB\n带宽约 19 TB/s\n(10倍于HBM!)"]
TC["Tensor Core\nFP16矩阵乘:312 TF/s\n非矩阵乘FP32:19.5 TF/s\n(差16倍!)"]
end
end
CPU -->|"PCIe ~64 GB/s"| HBM
HBM -->|"瓶颈!减少往返次数"| SRAM
SRAM --> TC
style HBM fill:#ff9999
style SRAM fill:#99ff99
style TC fill:#9999ff
图 1:A100 GPU 内存层次。HBM 是瓶颈——比 SRAM 慢 10 倍。FlashAttention 的设计核心是把注意力计算融合成单次 HBM 遍历(tiling),在 SRAM 内完成中间计算。
2. GPU 执行模型:线程、Warp、线程块
GPU 的线程组织是理解 FA2 第三项优化(split-Q vs split-K)的关键:
- 线程(Thread): 最小执行单元,拥有寄存器。
- Warp: 32 个线程的 SIMD 集合,在同一 SM 上锁步执行。Warp 内线程可以通过 shuffle 指令直接交换寄存器数据(极快,无需同步屏障)。
- 线程块(Thread Block / CTA): 多个 Warp(通常 4-8 个)共享同一块 SRAM 分配区。Warp 间通信只能通过写 SRAM →
__syncthreads()→ 读 SRAM,需要显式同步屏障。 - 网格(Grid): 所有线程块。GPU 调度器把每个线程块分配给一个空闲 SM 执行。
占用率(Occupancy): 实际使用的 SM 资源(Warp 槽位、寄存器、SRAM)与最大可用资源的比值。占用率低意味着 SM 空转,硬件利用率不足。FA2 的序列维度并行改进(第三节)主要解决占用率不足问题。
graph TD
GRID["GPU 核函数启动(线程块网格)"]
GRID --> B0["线程块 0\n分配到 SM 0"]
GRID --> B1["线程块 1\n分配到 SM 1"]
GRID --> BN["线程块 N\n分配到 SM K"]
B0 --> W0["Warp 0(32线程)\n寄存器shuffle: 极快"]
B0 --> W1["Warp 1(32线程)"]
B0 --> W2["Warp 2(32线程)"]
B0 --> W3["Warp 3(32线程)"]
W0 -->|"写/读 + syncthreads()\n有代价!"| SMEM["SRAM(192 KB)\n所有Warp共享"]
W1 --> SMEM
W2 --> SMEM
W3 --> SMEM
style SMEM fill:#ffff99
图 2:GPU 线程层次。Warp 间通信必须经过 SRAM,需要显式同步屏障。FA2 的 split-Q 策略使得前向传播中 Warp 间完全不需要通信。
3. 标准注意力的 O(N²) 问题
设 Q、K、V 的形状均为 ( 为序列长度, 为头维度),标准注意力计算:
问题在于中间矩阵 和 均为 。以 (8k 上下文)、FP16 为例:
对于 40 头、batch size 4 的模型,注意力中间结果就需要 。
更糟的是,标准实现需要 3 次 HBM 往返:① 写 ;② 读 、写 ;③ 读 、写 。这是典型的 内存带宽瓶颈(memory-bound) 操作。
4. FlashAttention-1 回顾:IO 感知 + 在线 Softmax
FlashAttention-1(Dao et al., NeurIPS 2022)的核心思路:
分块(Tiling): 把 Q 按行分成大小为 的块,K/V 按列分成大小为 的块,逐块加载到 SRAM,避免将完整的 矩阵写入 HBM。
在线 Softmax(Online Softmax): Softmax 是按行耦合的——需要全行的 max 和 sum 才能归一化。在线 Softmax 维护一对运行统计量 (行内 running max 和 running 指数和),每处理一个 K/V 分块就更新一次:
这是 精确 的 Softmax——无任何近似。FA1 把 HBM 读写复杂度从 降至 ( 为 SRAM 大小)。
FA1 剩余的问题: 即使 IO 最优,FA1 在 A100 上仍只达到 25-40% 的理论峰值吞吐。对比之下,cuBLAS 的 GEMM 能达到 80-90%。差距来自三个实现层面的低效。
非矩阵乘操作的性能差距
A100 的计算单元存在显著的 不对称性:
- FP16/BF16 Tensor Core(矩阵乘): 312 TFLOPs/s
- FP32 标量单元(非矩阵乘): 19.5 TFLOPs/s
FA1 的内层循环中,每处理一个 K/V 分块就需要执行:
max(m_old, rowmax(S_ij))— 标量 maxexp(S_ij - m_new)— 逐元素 expdiag(ℓ)^{-1}对输出 归一化 — 对角矩阵乘法
这些全部跑在慢速标量单元上,与 Tensor Core 争用时钟周期。序列越长、K/V 分块数 越多,这些操作的总代价越大。
flowchart TB
subgraph "FA1 内层循环(每个 K/V 分块)"
A["从HBM加载 K_j, V_j"]
B["矩阵乘:S = Q_i × K_j^T\nTensor Core,312 TF/s"]
C["非矩阵乘(慢!)\nm_new = max(m, rowmax(S))\n标量max,19.5 TF/s"]
D["非矩阵乘(慢!)\nP̃ = exp(S - m_new)\nℓ_new = exp(...) × ℓ + rowsum(P̃)"]
E["混合:归一化 O\ndiag(ℓ)^{-1} × [...]\n对角乘法(非矩阵乘)"]
F["更新 m ← m_new; ℓ ← ℓ_new"]
A --> B --> C --> D --> E --> F
end
style C fill:#ff9999
style D fill:#ff9999
style E fill:#ffcccc
图 3:FA1 内层循环中的非矩阵乘操作(红色),每个 K/V 分块执行一次完整的 diag(ℓ)^{-1} 归一化。FA2 将这一归一化推迟到外层循环结束后执行一次。
FA2 的三项核心改进
改进一:减少非矩阵乘操作(算法层面调整)
1a. 延迟归一化(前向传播)
FA1 在内层每步都对输出 用 归一化,共执行 次。FA2 的关键发现:只需要在外层循环结束后统一做一次归一化。
FA1 的每步更新(含 归一化):
FA2 的每步更新(维护未归一化的 ):
内层循环只做 max 位移校正(),不做 归一化。结束后:
正确性验证(双分块情形): 令初始 ,
处理第二个分块():
最终归一化:
这与标准 Softmax 注意力结果完全一致。关键在于:max 位移校正确保分子各项相对大小正确,最终 提供正确的分母。
节省的 FLOP: FA1 对 个分块每次都执行 ,FA2 只做一次。对于 、,:FA2 消除了 127 次每步 归一化操作,每次操作的代价是 维标量除法(运行在 19.5 TF/s 的慢速单元上)。
1b. Logsumexp 压缩(后向传播)
FA1 为后向传播存储两个标量 ,FA2 把它们压缩为一个:
这就是逐行的 log-sum-exp。在后向传播中,从 恢复注意力概率只需一步:
不需要先从 重建,减少了后向传播中的非矩阵乘指令数,同时把每行存储的标量数从 2 减少到 1,节省了 个 FP16 的 HBM 写入量(不大但有益)。
FA2 前向传播算法(Algorithm 1):
算法 1:FlashAttention-2 前向传播
──────────────────────────────────────────────────────────────────────────
输入:Q, K, V ∈ ℝ^{N×d}(存于 HBM);分块大小 B_r, B_c
输出:O ∈ ℝ^{N×d}(存于 HBM);L ∈ ℝ^N(logsumexp,存于 HBM)
──────────────────────────────────────────────────────────────────────────
1. Q 划分为 T_r = ⌈N/B_r⌉ 个行块 Q_1 … Q_{T_r}(每个 B_r×d)
K, V 划分为 T_c = ⌈N/B_c⌉ 个列块 K_1 … K_{T_c}, V_1 … V_{T_c}
2. for i = 1 … T_r do ← 外层循环 [FA2中每个行块是独立线程块]
3. HBM→SRAM: 加载 Q_i
4. 初始化 Õ_i = 0, ℓ_i = 0, m_i = -∞
5. for j = 1 … T_c do ← 内层循环
6. HBM→SRAM: 加载 K_j, V_j
7. S_ij = Q_i K_j^T ← 矩阵乘(Tensor Core)
8. m_i_new = max(m_i, rowmax(S_ij)) ← 非矩阵乘
9. P̃_ij = exp(S_ij − m_i_new) ← 非矩阵乘
10. ℓ_i_new = exp(m_i−m_i_new)·ℓ_i + rowsum(P̃_ij) ← 非矩阵乘
11. Õ_i = diag(exp(m_i−m_i_new))^{-1} Õ_i + P̃_ij V_j
← max校正(非矩阵乘) + 矩阵乘
── FA2 关键:此处无 diag(ℓ)^{-1} ──
12. m_i ← m_i_new; ℓ_i ← ℓ_i_new
13. end for
14. O_i = diag(ℓ_i)^{-1} Õ_i ← 每个外层步仅归一化一次!
15. L_i = m_i + log(ℓ_i) ← 存logsumexp供后向传播使用
16. SRAM→HBM: 写回 O_i, L_i
17. end for
──────────────────────────────────────────────────────────────────────────
与 FA1 相比,第 14 行的 从内层循环移到了外层循环,一次完成。
FA2 后向传播算法(Algorithm 2):
算法 2:FlashAttention-2 后向传播
──────────────────────────────────────────────────────────────────────────
输入:Q, K, V, O, dO ∈ ℝ^{N×d}(HBM);L ∈ ℝ^N(HBM)
输出:dQ, dK, dV ∈ ℝ^{N×d}(HBM)
──────────────────────────────────────────────────────────────────────────
预计算:D = rowsum(dO ◦ O) ∈ ℝ^N
(Softmax 雅可比的缩放项:D_i = ∑_k dO_ik·O_ik)
1. for j = 1 … T_c do ← 外层循环(对每个 K/V 列块)
2. 加载 K_j, V_j : HBM→SRAM
3. 初始化 dK_j = 0, dV_j = 0
4. for i = 1 … T_r do ← 内层循环
5. 加载 Q_i, O_i, dO_i, L_i, D_i : HBM→SRAM
6. S_ij = Q_i K_j^T ← 矩阵乘
7. P_ij = exp(S_ij − L_i) ← 用 logsumexp 重算注意力概率
8. dV_j += P_ij^T · dO_i ← 矩阵乘
9. dP_ij = dO_i · V_j^T ← 矩阵乘
10. dS_ij = P_ij ◦ (dP_ij − D_i) ← Softmax 雅可比(非矩阵乘)
11. dQ_i += dS_ij · K_j ← 矩阵乘(原子加回 HBM)
12. dK_j += dS_ij^T · Q_i ← 矩阵乘
13. end for
14. 写回 dK_j, dV_j → HBM
15. end for
──────────────────────────────────────────────────────────────────────────
第 7 行展示了 logsumexp 的作用:直接从存储的 重算 ,无需从 HBM 读取完整的 矩阵 。第 11 行用原子加(atomic add)是因为 接受来自所有列块 的贡献,需要跨线程块安全累加。
Softmax 雅可比第 10 行利用了以下恒等式:若 ,则:
其中 即预计算的行缩放值。这确保了后向传播也是精确的。
改进二:序列维度并行
FA1 沿两个维度并行:Batch 维度和 Head 维度,共启动 个线程块。对于长上下文推理场景(Batch=1, H=32):
FA2 新增序列维度并行:外层循环对每个 Q 行块 是独立的,FA2 把每个行块分配给一个独立线程块,同时在不同 SM 上并行执行。
、:,线程块数增加 128 倍,SM 占用率大幅提升。
graph LR
subgraph FA1_para["FA1:2维并行"]
direction TB
T1["线程块0 = Head0, Batch0\n处理全部128个行块迭代\n(串行)"]
T2["线程块1 = Head1, Batch0"]
TN["线程块31 = Head31\n共32个线程块 < 108 SM"]
end
subgraph FA2_para["FA2:3维并行"]
direction TB
T2_0["线程块0 = Head0, Batch0, 行块0"]
T2_1["线程块1 = Head0, Batch0, 行块1"]
T2_N["线程块4095 = Head31, 行块127\n共32×128=4096个线程块"]
end
FA1_para -->|"长上下文:B×H小\n大量SM闲置"| LOW["低占用率"]
FA2_para -->|"T_r倍更多线程块\n填满全部SM"| HIGH["高占用率"]
style LOW fill:#ff9999
style HIGH fill:#99ff99
图 4:FA2 新增序列维度并行。FA1 的 32 个线程块使 A100 上 76 个 SM 空转。FA2 额外引入行块维度,创建 4096 个线程块,充分填满所有 SM。
前向传播: 外层循环各行块完全独立,无线程块间通信,理想并行。
后向传播: 外层循环迭代 K/V 列块。每个线程块处理一个 列块,遍历全部行块计算 和 。但 需要从全部列块累加,FA2 使用原子加处理跨线程块的 更新(式 11)。
这一思路最早由 OpenAI 的 Phil Tillet 在 Triton 的 FlashAttention 实现中提出,FA2 在 CUDA 内核中将其规范化。
改进三:Warp 内工作划分(Split-Q vs Split-K)
即使线程块级的并行度已经最优,线程块内的 4-8 个 Warp 如何分工仍然影响性能。FA2 的第三项改进就是改变这里的划分策略。
FA1 的 Split-K 策略
FA1 把 K/V 的头维度 分配给各个 Warp:
- Warp 0 处理 (K 的前 1/4 列)
- Warp 1 处理
- Warp 2、3 类似
每个 Warp 得到 的一个部分列切片,与对应的 切片相乘得到部分输出。最终合并各 Warp 的部分输出需要经过 SRAM:
- 每个 Warp 把中间结果写入 SRAM
- 全部 Warp 执行
__syncthreads()同步 - 汇总并写回
问题: Softmax 是按行耦合的——每个 Warp 只看到 的部分列,无法独立计算完整的 和 。必须先聚合各 Warp 的 切片,或者在 Warp 间共享 和 统计量,这又增加了额外的 SRAM 通信轮次。
FA2 的 Split-Q 策略
FA2 改为把 Q 的行维度 分配给各个 Warp:
- Warp 0 处理 Q 的第 0-15 行,独立计算这些行的完整注意力和输出
- Warp 1 处理 Q 的第 16-31 行
- Warp 2、3 类似
所有 Warp 共享 K 和 V 的同一块(由线程块统一加载到 SRAM,只读),但每个 Warp 独立维护自己所负责行的 。
关键: Softmax 按行计算。每个 Warp 拥有完整的若干行,可以独立执行在线 Softmax——不需要和其他 Warp 交换任何数据。零 Warp 间通信,无需 SRAM 写入,无需同步屏障。
graph TD
subgraph fa1["FA1:Split-K"]
direction LR
subgraph W0K["Warp 0"]
q0["Q_i(完整 B_r 行)"]
k0["K_j 列切片 [0:d/4]"]
p0["部分 QK^T"]
end
subgraph W1K["Warp 1"]
q1["Q_i(同)"]
k1["K_j [d/4:d/2]"]
p1["部分 QK^T"]
end
SMEM_K["SRAM: 各Warp写部分结果\n+ syncthreads + 求和\n(瓶颈!)"]
O_K["最终 O_i"]
p0 --> SMEM_K
p1 --> SMEM_K
SMEM_K --> O_K
end
subgraph fa2["FA2:Split-Q"]
direction LR
subgraph W0Q["Warp 0"]
q0q["Q_i 行切片 [0:B_r/4]"]
kv0["K_j, V_j(完整,SRAM 只读)"]
out0["完整输出 O_i[0:B_r/4]"]
end
subgraph W1Q["Warp 1"]
q1q["Q_i [B_r/4:B_r/2]"]
kv1["K_j, V_j(同)"]
out1["完整输出 O_i[B_r/4:B_r/2]"]
end
O_Q["最终 O_i\n(无需同步!)"]
kv0 -.-> q0q
kv1 -.-> q1q
out0 --> O_Q
out1 --> O_Q
end
style SMEM_K fill:#ff9999
style O_Q fill:#99ff99
图 5:FA1 的 Split-K(上)需要所有 Warp 把部分结果写入 SRAM 并同步。FA2 的 Split-Q(下)每个 Warp 独立处理完整行,K/V 由 SRAM 只读共享,完全无 Warp 间通信。
为什么 Split-Q 能奏效: Softmax 沿行方向耦合,各行之间是独立的。只要每个 Warp 负责完整的若干行(不跨 Warp 分割行),它就能独立完成在线 Softmax——、、 全部存在该 Warp 的寄存器中,不需要与其他 Warp 共享。
设计边界: Split-Q 的代价是:所有 4 个 Warp 需要同时访问 K 和 V 的同一个分块,这些分块必须完整地放入 SRAM。这对 (列分块大小)形成了上限约束。FA1 的 Split-K 可以用更大的 (每个 Warp 只访问 K 的一部分),但 FA2 的收益远超这一代价。
后向传播中 Softmax 雅可比的推导
后向传播算法中第 10 步的 dS_ij = P_ij ◦ (dP_ij − D_i) 涉及 Softmax 的雅可比矩阵。这里完整推导其来源,这是后向传播正确性的关键。
Softmax 的雅可比: 设 ,即 。对损失 对 求偏导:
计算 :
代入:
其中 即向量点积。写成矩阵形式:
在注意力后向传播中的对应: (一行注意力权重),(),(预先计算好)。
预计算 而非在内层循环中实时计算的好处: 只依赖 和 ,与 (列块索引)无关,可以一次性算出所有行的 向量,避免 次重复计算。
多 Query 注意力(MQA)和分组 Query 注意力(GQA)的支持
现代大模型推理中,KV 缓存常占据大量显存。减小 KV 缓存的主流方法:
多 Query 注意力(MQA)(Shazeer, 2019): 所有 Query 头共享同一组 K 和 V(只有 1 个 K/V 头)。KV 缓存缩小 倍( 为 Query 头数)。
分组 Query 注意力(GQA)(Ainslie et al., 2023): 组,每组的 Query 头共享一个 K/V 头。 退化为 MQA, 退化为标准多头注意力(MHA)。
FA2 通过张量步长(stride)技巧原生支持这两者,无需在内存中复制 K/V:第 个 Query 头属于第 组,FA2 在索引 K/V 时直接使用 ,而不是 。这是一个零内存开销的实现,不需要在内核外手动广播 K/V。
后向传播中的额外处理: 由于多个 Query 头共享同一个 K/V 头,梯度 和 需要在组内的所有 Query 头上求和。FA2 在后向传播核函数之后附加一个规约(reduction)操作完成这一求和。
因果掩码优化
自回归语言模型需要因果掩码:位置 的 query 只能 attend 到位置 。在分块框架下,FA2 利用两种情况:
情形一——完全被掩码的分块: 如果某个分块的所有列索引 都严格大于行索引 (即整块都在对角线以上),则所有条目经掩码后变为 ,Softmax 后贡献为零。FA2 直接跳过该分块的 GEMM 和 HBM 加载。
对于大序列,约一半的分块是完全被掩码的,跳过后比无掩码注意力快约 1.7-1.8×。
情形二——边界分块: 仅对角线分块(部分有效、部分掩码)需要显式应用因果掩码,每个行块只有 1 个这样的分块。对角线以下的分块全部无掩码,无需任何掩码逻辑。
graph LR
subgraph mask["N×N 因果注意力矩阵(分块视角)"]
ABOVE["上三角分块(j > i)\n全部掩码 → 跳过!\n约50%的分块"]
DIAG["对角线分块(j ≈ i)\n部分掩码 → 显式处理\n每行块仅1个"]
BELOW["下三角分块(j < i)\n全部有效 → 全速计算\n无掩码逻辑"]
end
ABOVE -->|"跳过约50%计算"| SPEEDUP["约1.7-1.8×\n速度提升"]
style ABOVE fill:#cccccc
style SPEEDUP fill:#99ff99
style DIAG fill:#ffff99
style BELOW fill:#ccffcc
图 6:因果掩码在 FA2 分块框架中的处理。上三角分块直接跳过,对角线分块施加掩码,下三角分块全速执行。这带来约 1.7× 相对于非因果注意力的速度提升。
算术强度与屋顶线模型
屋顶线模型(Roofline Model) 提供了理解注意力性能上界的分析框架。对于计算吞吐 (FLOPs/s)和内存带宽 (bytes/s)的硬件:
其中 是算术强度(每字节内存传输对应的 FLOPs 数)。A100 的”脊线点”:
标准注意力的算术强度:
- 计算量:
- 内存访问:(读写 S、P 矩阵)
- 算术强度:(FP16,128/2 = 64 FLOPs/byte)
64 FLOPs/byte < 156 FLOPs/byte:标准注意力是内存带宽瓶颈(memory-bound)。
FlashAttention 的算术强度:
- 计算量:仍为
- 内存访问:降至 (不写 矩阵)
- 算术强度:( 时为 4096 FLOPs/byte)
4096 ≫ 156:FA1/FA2 是计算瓶颈(compute-bound),能够接近峰值 Tensor Core 吞吐。
FA2 进一步提升了计算瓶颈下的效率:减少非矩阵乘 FLOP,让 Tensor Core 占用更高比例的总时间。
实验结果
注意力核吞吐量
所有测试在 A100 80GB SXM4 上进行。序列长度从 512 到 16k,批大小固定使总 token 数为 16k,隐层维度 2048。
前向+后向速度,无因果掩码,头维度 64(A100):
| 方法 | 512 | 1k | 2k | 4k | 8k | 16k |
|---|---|---|---|---|---|---|
| PyTorch 标准实现 | 68 | 91 | 98 | 95 | 83 | OOM |
| FlashAttention-1 | 91 | 90 | 104 | 98 | 92 | 76 |
| FA Triton | 102 | 92 | 110 | 110 | 108 | 100 |
| FlashAttention-2 | 132 | 162 | 171 | 176 | 175 | 173 |
单位:TFLOPs/s。A100 FP16 峰值 = 312 TFLOPs/s。
有因果掩码,头维度 128 的前向速度: FA2 最高达 224-227 TFLOPs/s,相当于 A100 峰值的 72-73%,已接近纯矩阵乘 cuBLAS 的效率(80-90%)。
FA2 在 16k 长序列上更明显:此时批大小受内存限制通常很小,FA2 的序列维度并行恰好填补了因批大小小导致的 SM 闲置问题。
端到端 GPT 训练吞吐量
GPT 模型在 8× A100 80GB SXM4 上训练结果:
| 模型 | 上下文长度 | 无 FA | FA1 | FA2 | FA2 vs 基线加速比 |
|---|---|---|---|---|---|
| GPT3-1.3B | 2k | 142 TF/s | 189 | 196 | 1.38× |
| GPT3-1.3B | 8k | 72 TF/s | 170 | 220 | 3.06× |
| GPT3-2.7B | 2k | 149 TF/s | 189 | 205 | 1.38× |
| GPT3-2.7B | 8k | 80 TF/s | 175 | 225 | 2.81× |
8k 上下文下加速比最大(3×),原因正是:长序列迫使批大小缩小,而 FA2 的序列维度并行解决了由此产生的 SM 低占用率问题。
72% 的模型 FLOPs 利用率对真实训练任务(含数据加载、优化器、梯度通信)而言非常高,已接近硬件上限。
前向传播与后向传播的 IO 复杂度分析
为了从理论上量化 FA2 相对于标准注意力的优势,我们分析每次前向传播的 HBM 读写字节数(以 FP16,即每元素 2 字节计)。
标准注意力 HBM 访问:
- 读 Q, K, V: 字节
- 写 : 字节
- 读 ,写 : 字节
- 读 P, V,写 O: 字节
- 总计: 字节( 时以 主导)
FlashAttention(FA1/FA2)HBM 访问:
- 读 Q: 字节(分 块,每外层步读一次)
- 每内层步读 K_j, V_j: 字节,共 步
- 每外层步写 O_i: 字节,共 步
- 写 L(logsumexp): 字节(可忽略)
- 总计:
由于 ,总 HBM 访问量为 ,与 SRAM 大小 成反比:
对于 (SRAM 能容纳足够大的 ),FA 的 HBM 流量大幅低于标准注意力。以 (故 )为例:
FA 将 HBM 流量降低约 8 倍——这就是为什么 FA1 即使在算法层面没有 FA2 的改进,也能相对标准注意力取得显著加速:它把注意力从 memory-bound 变成了 compute-bound。
设计分析:为什么、替代方案、边界条件
为什么不把所有分块处理完再一次性归一化?
问题: 能否先处理全部 K/V 分块,累积未归一化的和,最后做一次归一化?
答案: 不行——数值溢出。如果不做在线 max 减法, 对大值会溢出 FP16/FP32。在线 max 技巧(每步维护 ,始终 ≤1)是数值稳定性的必要条件。FA2 保留了每步的 max 校正(相对廉价,仅是一个 的标量乘法),只去掉了更昂贵的每步 除法。
边界: 越大(序列越长),节省越显著。对 512 token 的短序列(),仅省 7 次归一化,收益有限。对 16k token(),省 255 次,效果明显。
为什么 Split-Q 可以无 Warp 通信而 Split-K 不行?
核心差异在于 Softmax 的耦合方向:
-
Split-K: 每个 Warp 拥有 的部分列切片。Softmax 按行计算,需要完整行的所有列才能求 max 和 sum。Warp 间必须通信来聚合统计量或部分输出。
-
Split-Q: 每个 Warp 拥有 的部分行。每行的 Softmax 完全包含在单个 Warp 内。 均为该 Warp 私有的寄存器,无需共享。
边界: Split-Q 要求 K 和 V 的分块能同时放入 SRAM 并被所有 Warp 访问,对 有上限。FA1 的 Split-K 每个 Warp 只需 K 的 1/4,可以使用更大的 。但 FA2 消除 Warp 通信的收益远超减小 的代价。
为什么序列维度并行对长序列更有效?
短序列(512-2k): 通常已足够大(例如 batch=32 × heads=32 = 1024 个线程块),108 个 SM 已经能被充分利用。序列维度并行增加的线程块带来的增量有限。
长序列(4k-16k): 内存限制迫使 batch 缩小(batch=1 到 4 很常见)。 可能只有 32-64 个线程块,远少于 108 个 SM。此时序列维度并行的 倍增加效果显著——这也解释了为什么 FA2 在端到端测试中 8k 上下文的加速比(3×)远大于 2k 上下文(1.38×)。
FA1 → FA2 → FA3 全景
| 性质 | FlashAttention-1 | FlashAttention-2 | FlashAttention-3(参考) |
|---|---|---|---|
| HBM IO 复杂度 | 同 | 同 | |
| 每步 归一化 | 每内层步执行 | 仅外层步一次 | 同 FA2 |
| 后向统计量 | 存 ,2个/行 | 存 ,1个/行 | 同 FA2 |
| 序列维度并行 | 无 | 有 | 有,更细粒度 |
| Warp 策略 | Split-K | Split-Q | Warp 专业化(producer/consumer) |
| Warp 间通信 | 需要 | 无 | 无 |
| 异步内存拷贝 | 无 | 无 | TMA(硬件异步加载) |
| FP8 支持 | 无 | 无 | 有 |
| A100 峰值占比(fwd+bwd) | 25-40% | 50-73% | N/A(主要针对H100) |
| H100 FP16 峰值(fwd+bwd) | ~30% | ~35%(直接移植) | ~75% |
FA3 的主要新增是 Hopper 专属特性(TMA、Warp 专业化、第四代 Tensor Core),对 Ampere GPU 意义不大。对于 A100 用户,FA2 是最优选择。
实现细节与分块大小调优
分块大小的约束
(Q 行分块)和 (K/V 列分块)需满足同时放入 SRAM 的约束:
(系数 2 对应 FP16,Q 块 + K 块 + V 块 + S 块)
同时还受到 Tensor Core 对齐要求(维度需为 16 的倍数)和寄存器压力约束。FA2 目前用手工调优的查找表(按头维度索引),典型值:(头维度 64), 或 128(头维度 128)。
Logsumexp 的数值稳定性
FA2 存储 ,其中:
- (始终正有界)
- (有界,无溢出)
后向传播中从 还原 :
因为 ,所以 ,不会溢出。这保证了后向传播与前向传播数值精度一致。
对 LLM 训练和推理的实际影响
上下文长度扩展的意义
FA2 最直接的影响是让长上下文训练变得经济可行。注意力的计算量随序列长度二次增长,但 FA2 通过序列维度并行使得实际 GPU 吞吐随上下文增加的衰减远小于理论预期:
| 上下文长度 | 注意力 FLOPs 增倍 | FA2 实测吞吐 | 每 token 等效成本 |
|---|---|---|---|
| 2k | 1× | ~196 TF/s | 1× |
| 8k | 16×(注意力部分) | ~220 TF/s(+12%) | 约1.1× per token |
| 16k | 64×(注意力部分) | ~224 TF/s | 类似 |
原因:更长序列迫使每卡批大小缩小(内存压力),FA2 的 倍更多线程块恰好填补了因批大小小导致的 SM 闲置。换言之,FA2 的效率提升越长越大,完美对冲了长上下文带来的内存限制。
对 KV 缓存的影响
在自回归推理的解码阶段(逐 token 生成),每步的 query 形状为 (只有一个新 token),需要 attend 到长度 的 KV 缓存。此时操作从矩阵-矩阵乘退化为向量-矩阵乘,Tensor Core 利用率自然下降——FA2 的优势主要体现在:
- 预填充阶段(Prefill): 处理完整 prompt( 个 token)时,FA2 的效率提升完全生效。
- MQA/GQA 支持: FA2 原生支持多 query 注意力(多个 Q 头共享一组 K/V),使推理时 KV 缓存大小可减少 (头数)倍,直接降低 GPU 内存需求和带宽压力。
与 PyTorch 的集成
PyTorch 2.0+ 中,torch.nn.functional.scaled_dot_product_attention 在以下条件下自动调用 FA2:
- CUDA(Ampere 及以上,即 A100/A10G/RTX 3090 等)
- 输入为 FP16 或 BF16
- 序列长度不太短(通常 >128)
这意味着绝大多数现代 LLM 训练无需任何代码改动即可获得 FA2 的性能提升。如需最新优化(H100 FA3、FP8、自定义融合 op),使用 flash-attn 包。
可复现性说明
论文测试环境:A100 80GB SXM4,CUDA 11.8,cuDNN 8.7,PyTorch 2.0。关键实现细节:
- 分块大小: 头维度 64 → ;头维度 128 → (部分配置为 128)
- 参考实现:
github.com/Dao-AILab/flash-attention(Apache 2.0 许可) - PyTorch 集成:
pip install flash-attn==2.x.x;也已合并进 PyTorch 核心 - FLOP 计算惯例: 注意力 FLOPs 按 (前向)和 (后向,为前向 2.5×)计算。因果注意力的 per-second 效率计算中将前向 FLOPs 除以 2,但端到端 MFU 计算遵从 Megatron-LM 惯例,不除以 2,以保持与文献一致性
- MFU 计算: ;GPT-2.7B 在 8k 上下文下:
约 2× 的加速提升在不同头维度、序列长度和批大小下均稳健。绝对 TFLOPs/s 数值对 A100 型号(SXM4 vs PCIe)、驱动版本和热节流状态敏感,实际测量期望 ±5-10% 的偏差。
总结
FlashAttention-2 是一篇关于 GPU 系统优化精确性的精彩示范。Tri Dao 诊断出 FA1 三个独立的性能瓶颈,并为每个瓶颈设计了最小化的修复方案:
- 延迟 归一化: 消除内层循环中每步的对角矩阵乘法,将 次降至 1 次
- 序列维度并行: 为外层行块引入独立线程块,对长上下文场景填满空转 SM,提升占用率
- Split-Q Warp 策略: 基于”Softmax 按行独立”的观察,消除 Warp 间 SRAM 通信和同步屏障
三项改进叠加,实现了 FA1 的 2× 加速,使 A100 上注意力核效率从 25-40% 提升到 50-73% 的理论峰值,与纯矩阵乘的效率差距缩小到 10-20 个百分点。端到端结果:GPT-2.7B 在 8k 上下文下训练吞吐达到 225 TFLOPs/s(72% MFU),比无 FA 基线快 2.8×,比 FA1 快 1.3×。
这篇论文的示范价值在于:它证明了”IO 最优”(FA1 实现)和”计算最优”(FA2 进一步优化)是两个独立的问题,不能混为一谈。FA1 消除了 HBM 流量,使操作从 memory-bound 变成 compute-bound;FA2 则在 compute-bound 的假设下,通过减少非矩阵乘 FLOP、提高占用率、优化 Warp 通信,进一步接近 Tensor Core 的峰值利用率。这套分析框架——先找到瓶颈类型,再有针对性地解决——是 GPU 系统优化的通用方法论。
对实践者的建议: PyTorch 2.0+ 已将 FA2 作为 CUDA 上 scaled_dot_product_attention 的默认后端,无需代码改动即可受益。如需最新优化(H100 FA3、FP8、自定义 MQA/GQA),使用 flash-attn 库的最新版本。
关于 H100: 将 FA2 的同一代码直接移植到 H100 SXM5 上(不做任何 H100 专项优化),可达 294-338 TFLOPs/s,约为 H100 FP16 峰值(989 TFLOPs/s)的 30-34%。利用 H100 专属特性(TMA、Warp 专业化、FP8)的 FA3 能达到约 75% 峰值。因此对于 H100/H200 用户,FA3 是更好的选择;对于 A100/A10G/RTX 系列,FA2 是当前最优解。
(注:FA2 是 Apache 2.0 开源许可,代码地址:github.com/Dao-AILab/flash-attention)
思维框架:GPU 注意力性能受三个独立维度制约——HBM 带宽(FA1 解决)、SM 占用率(FA2 的序列并行解决)、Warp 效率(FA2 的 Split-Q 解决)。理解这三个维度是分析任何注意力优化工作的基础框架。
读完这篇论文,我最大的收获不是具体的技术手段,而是这套分层诊断的思路:先用屋顶线模型判断瓶颈类型(memory-bound 还是 compute-bound),再分析 compute-bound 内部的次级瓶颈(非矩阵乘 FLOP 占比、SM 占用率、Warp 通信),最后对每个瓶颈设计独立的、可量化的修复方案。这种工程文化在国内的 AI infra 团队中比较少见,值得借鉴。
从影响力的角度看,FlashAttention 系列是近年来对实际 LLM 训练和推理贡献最大的少数几篇论文之一——不是通过提出新模型或新算法,而是通过让已有算法在现有硬件上运行得更快。这提醒我们,系统级优化与算法创新同等重要,在资源受限的环境下甚至更为关键。
FA1 vs FA2 完整对照表
| 性质 | FlashAttention-1 | FlashAttention-2 |
|---|---|---|
| HBM IO 复杂度(前向) | 同,不变 | |
| 每内层步 归一化 | ✓ 每 K/V 分块执行 | ✗ 仅外层步结束执行一次 |
| 后向统计量存储 | — 2 个标量/行 | — 1 个标量/行 |
| 序列维度并行 | ✗ 外层循环串行 | ✓ 每行块独立线程块 |
| 线程块总数 | ||
| Warp 工作划分(前向) | Split-K:按 K/V 列维分割 | Split-Q:按 Q 行维分割 |
| Warp 间通信(前向) | ✓ 需要 SRAM 写 + syncthreads | ✗ 完全无需 |
| 后向 dQ 更新 | 线程块内串行 | 跨线程块原子加 |
| MQA/GQA 原生支持 | ✗ | ✓ 步长技巧 |
| 因果掩码优化 | ✓ 跳过全掩码分块 | ✓ 同,不变 |
| A100 峰值利用率(fwd+bwd) | 25-40% | 50-73% |
| GPT-2.7B 8k 端到端吞吐 | 175 TF/s | 225 TF/s |
| 相对 FA1 加速比 | 1× | 约 2× |
数学算法(在线 Softmax 递推、分块结构、IO 复杂度)完全一致——FA2 的所有改进均在实现层面,输出结果与 FA1 按位相同(在浮点舍入范围内)。
值得特别注意的一点:FA2 的每一项改进都是独立可拆解的。延迟 归一化、序列维度并行、Split-Q 三者互不依赖,可以独立打开或关闭。这使得消融实验(ablation)非常清晰,也使论文的结论高度可信——作者清楚地知道每一分速度来自哪里。
另一个值得关注的点:论文中没有 ablation table 明确拆分三项改进各自的贡献,这是一个小遗憾。从论文内容可以推断:序列维度并行主要在长序列、小批大小场景下贡献显著;Split-Q 在各种场景下均有稳定收益;延迟归一化的贡献在长序列下( 大时)更明显。