FlashAttention-2:更好的并行策略与线程块工作划分

笔记日期: 2026-05-24 笔记作者: Zhongzhu Zhou 论文标题: FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning 作者: Tri Dao(普林斯顿大学 / 斯坦福大学) arXiv: 2307.08691 状态 / Venue: arXiv 预印本(2023 年 7 月);ICLR 2024

一句话总结

FlashAttention-2 不是一个新算法,而是对 FlashAttention-1 实现层面的三处精准优化:减少非矩阵乘操作数、增加序列维度并行、改变 warp 内工作划分方式,把 A100 上的注意力吞吐量从 FA1 的 25-40% 峰值提升到 50-73% 峰值,实现了约 2× 加速,并使 GPT-2.7B 的 8k 上下文训练速度达到 225 TFLOPs/s(72% MFU)。

前置知识

1. GPU 内存层次

理解 FlashAttention 系列的起点是 GPU 的两级内存结构:

高带宽内存(HBM): GPU 的主 DRAM。A100-80GB 有 80 GB 容量,带宽约 2 TB/s。所有 PyTorch 张量默认存放在这里。

片上 SRAM(共享内存): 每个流式多处理器(SM)有独立的 SRAM。A100 每个 SM 有 192 KB,共 108 个 SM,总计约 20 MB SRAM。SRAM 带宽约 19 TB/s——大约是 HBM 的 10 倍。

带宽比:19 TB/s (SRAM)2 TB/s (HBM)10×(1)\text{带宽比:} \frac{19 \text{ TB/s (SRAM)}}{2 \text{ TB/s (HBM)}} \approx 10 \times \tag{1}

代价是 SRAM 极小(每 SM 只有 192 KB),且是易失存储。

FlashAttention 系列的核心设计原则: 最小化对 HBM 的读写次数,尽量在 SRAM 内完成计算。每一次不必要的 HBM 访问都比 SRAM 贵 10 倍。

graph TD
    CPU["主机 CPU"]
    subgraph A100["A100 GPU(108个SM)"]
        HBM["HBM:80 GB\n带宽约 2 TB/s\n(PyTorch 张量住这里)"]
        subgraph SM["单个 SM(×108)"]
            SRAM["片上 SRAM:192 KB\n带宽约 19 TB/s\n(10倍于HBM!)"]
            TC["Tensor Core\nFP16矩阵乘:312 TF/s\n非矩阵乘FP32:19.5 TF/s\n(差16倍!)"]
        end
    end
    CPU -->|"PCIe ~64 GB/s"| HBM
    HBM -->|"瓶颈!减少往返次数"| SRAM
    SRAM --> TC
    style HBM fill:#ff9999
    style SRAM fill:#99ff99
    style TC fill:#9999ff

图 1:A100 GPU 内存层次。HBM 是瓶颈——比 SRAM 慢 10 倍。FlashAttention 的设计核心是把注意力计算融合成单次 HBM 遍历(tiling),在 SRAM 内完成中间计算。

2. GPU 执行模型:线程、Warp、线程块

GPU 的线程组织是理解 FA2 第三项优化(split-Q vs split-K)的关键:

  • 线程(Thread): 最小执行单元,拥有寄存器。
  • Warp: 32 个线程的 SIMD 集合,在同一 SM 上锁步执行。Warp 内线程可以通过 shuffle 指令直接交换寄存器数据(极快,无需同步屏障)。
  • 线程块(Thread Block / CTA): 多个 Warp(通常 4-8 个)共享同一块 SRAM 分配区。Warp 间通信只能通过写 SRAM → __syncthreads() → 读 SRAM,需要显式同步屏障
  • 网格(Grid): 所有线程块。GPU 调度器把每个线程块分配给一个空闲 SM 执行。

占用率(Occupancy): 实际使用的 SM 资源(Warp 槽位、寄存器、SRAM)与最大可用资源的比值。占用率低意味着 SM 空转,硬件利用率不足。FA2 的序列维度并行改进(第三节)主要解决占用率不足问题。

graph TD
    GRID["GPU 核函数启动(线程块网格)"]
    GRID --> B0["线程块 0\n分配到 SM 0"]
    GRID --> B1["线程块 1\n分配到 SM 1"]
    GRID --> BN["线程块 N\n分配到 SM K"]
    B0 --> W0["Warp 0(32线程)\n寄存器shuffle: 极快"]
    B0 --> W1["Warp 1(32线程)"]
    B0 --> W2["Warp 2(32线程)"]
    B0 --> W3["Warp 3(32线程)"]
    W0 -->|"写/读 + syncthreads()\n有代价!"| SMEM["SRAM(192 KB)\n所有Warp共享"]
    W1 --> SMEM
    W2 --> SMEM
    W3 --> SMEM
    style SMEM fill:#ffff99

图 2:GPU 线程层次。Warp 间通信必须经过 SRAM,需要显式同步屏障。FA2 的 split-Q 策略使得前向传播中 Warp 间完全不需要通信。

3. 标准注意力的 O(N²) 问题

设 Q、K、V 的形状均为 N×dN \times dNN 为序列长度,dd 为头维度),标准注意力计算:

S=QKRN×N(2)S = QK^\top \in \mathbb{R}^{N \times N} \tag{2} P=softmax(S)RN×N(3)P = \text{softmax}(S) \in \mathbb{R}^{N \times N} \tag{3} O=PVRN×d(4)O = PV \in \mathbb{R}^{N \times d} \tag{4}

问题在于中间矩阵 SSPP 均为 N×NN \times N。以 N=8192N=8192(8k 上下文)、FP16 为例:

S=P=81922×2 B=128 MB(每个头,每条样本)|S| = |P| = 8192^2 \times 2 \text{ B} = 128 \text{ MB(每个头,每条样本)}

对于 40 头、batch size 4 的模型,注意力中间结果就需要 128 MB×40×4=20 GB128 \text{ MB} \times 40 \times 4 = 20 \text{ GB}

更糟的是,标准实现需要 3 次 HBM 往返:① 写 SS;② 读 SS、写 PP;③ 读 PP、写 OO。这是典型的 内存带宽瓶颈(memory-bound) 操作。

4. FlashAttention-1 回顾:IO 感知 + 在线 Softmax

FlashAttention-1(Dao et al., NeurIPS 2022)的核心思路:

分块(Tiling): 把 Q 按行分成大小为 BrB_r 的块,K/V 按列分成大小为 BcB_c 的块,逐块加载到 SRAM,避免将完整的 N×NN \times N 矩阵写入 HBM。

在线 Softmax(Online Softmax): Softmax 是按行耦合的——需要全行的 max 和 sum 才能归一化。在线 Softmax 维护一对运行统计量 (m(j),(j))(m^{(j)}, \ell^{(j)})(行内 running max 和 running 指数和),每处理一个 K/V 分块就更新一次:

m(j)=max(m(j1),rowmax(S(j)))(5)m^{(j)} = \max(m^{(j-1)}, \text{rowmax}(S^{(j)})) \tag{5} (j)=em(j1)m(j)(j1)+rowsum(eS(j)m(j))(6)\ell^{(j)} = e^{m^{(j-1)} - m^{(j)}} \ell^{(j-1)} + \text{rowsum}(e^{S^{(j)} - m^{(j)}}) \tag{6} O(j)=diag((j))1[em(j1)m(j)(j1)O(j1)+eS(j)m(j)V(j)](7)O^{(j)} = \text{diag}(\ell^{(j)})^{-1} \left[ e^{m^{(j-1)}-m^{(j)}} \ell^{(j-1)} O^{(j-1)} + e^{S^{(j)}-m^{(j)}} V^{(j)} \right] \tag{7}

这是 精确 的 Softmax——无任何近似。FA1 把 HBM 读写复杂度从 O(N2)O(N^2) 降至 O(N2d/M)O(N^2 d / M)MM 为 SRAM 大小)。

FA1 剩余的问题: 即使 IO 最优,FA1 在 A100 上仍只达到 25-40% 的理论峰值吞吐。对比之下,cuBLAS 的 GEMM 能达到 80-90%。差距来自三个实现层面的低效。

非矩阵乘操作的性能差距

A100 的计算单元存在显著的 不对称性

  • FP16/BF16 Tensor Core(矩阵乘): 312 TFLOPs/s
  • FP32 标量单元(非矩阵乘): 19.5 TFLOPs/s
效率比:31219.5=16×(矩阵乘每 FLOP 快 16 倍)(8)\text{效率比:} \frac{312}{19.5} = 16 \times \text{(矩阵乘每 FLOP 快 16 倍)} \tag{8}

FA1 的内层循环中,每处理一个 K/V 分块就需要执行:

  • max(m_old, rowmax(S_ij)) — 标量 max
  • exp(S_ij - m_new) — 逐元素 exp
  • diag(ℓ)^{-1} 对输出 OO 归一化 — 对角矩阵乘法

这些全部跑在慢速标量单元上,与 Tensor Core 争用时钟周期。序列越长、K/V 分块数 TcT_c 越多,这些操作的总代价越大。

flowchart TB
    subgraph "FA1 内层循环(每个 K/V 分块)"
        A["从HBM加载 K_j, V_j"]
        B["矩阵乘:S = Q_i × K_j^T\nTensor Core,312 TF/s"]
        C["非矩阵乘(慢!)\nm_new = max(m, rowmax(S))\n标量max,19.5 TF/s"]
        D["非矩阵乘(慢!)\nP̃ = exp(S - m_new)\nℓ_new = exp(...) × ℓ + rowsum(P̃)"]
        E["混合:归一化 O\ndiag(ℓ)^{-1} × [...]\n对角乘法(非矩阵乘)"]
        F["更新 m ← m_new; ℓ ← ℓ_new"]
        A --> B --> C --> D --> E --> F
    end
    style C fill:#ff9999
    style D fill:#ff9999
    style E fill:#ffcccc

图 3:FA1 内层循环中的非矩阵乘操作(红色),每个 K/V 分块执行一次完整的 diag(ℓ)^{-1} 归一化。FA2 将这一归一化推迟到外层循环结束后执行一次。

FA2 的三项核心改进

改进一:减少非矩阵乘操作(算法层面调整)

1a. 延迟归一化(前向传播)

FA1 在内层每步都对输出 OOdiag((j))1\text{diag}(\ell^{(j)})^{-1} 归一化,共执行 TcT_c 次。FA2 的关键发现:只需要在外层循环结束后统一做一次归一化。

FA1 的每步更新(含 \ell 归一化):

O(j)=diag((j))1[em(j1)m(j)(j1)O(j1)+eS(j)m(j)V(j)](9)O^{(j)} = \text{diag}(\ell^{(j)})^{-1} \left[ e^{m^{(j-1)}-m^{(j)}} \ell^{(j-1)} O^{(j-1)} + e^{S^{(j)}-m^{(j)}} V^{(j)} \right] \tag{9}

FA2 的每步更新(维护未归一化的 O~\tilde{O}):

O~(j)=diag(em(j1)m(j))1O~(j1)+eS(j)m(j)V(j)(10)\tilde{O}^{(j)} = \text{diag}(e^{m^{(j-1)}-m^{(j)}})^{-1} \tilde{O}^{(j-1)} + e^{S^{(j)}-m^{(j)}} V^{(j)} \tag{10}

内层循环只做 max 位移校正(emoldmnewe^{m_\text{old}-m_\text{new}}),不做 \ell 归一化。结束后:

O=diag((Tc))1O~(Tc)(仅一次)(11)O = \text{diag}(\ell^{(T_c)})^{-1} \tilde{O}^{(T_c)} \quad \text{(仅一次)} \tag{11}

正确性验证(双分块情形): 令初始 O~(0)=0\tilde{O}^{(0)} = 0

O~(1)=eS(1)m(1)V(1)(12)\tilde{O}^{(1)} = e^{S^{(1)}-m^{(1)}} V^{(1)} \tag{12}

处理第二个分块(m(2)=max(m(1),rowmax(S(2)))=mm^{(2)} = \max(m^{(1)}, \text{rowmax}(S^{(2)})) = m):

O~(2)=em(1)m(2)O~(1)+eS(2)m(2)V(2)=eS(1)mV(1)+eS(2)mV(2)(13)\tilde{O}^{(2)} = e^{m^{(1)}-m^{(2)}} \tilde{O}^{(1)} + e^{S^{(2)}-m^{(2)}} V^{(2)} = e^{S^{(1)}-m} V^{(1)} + e^{S^{(2)}-m} V^{(2)} \tag{13}

最终归一化:

O=O~(2)(2)=eS(1)mV(1)+eS(2)mV(2)eS(1)m1+eS(2)m1=softmax ⁣([S(1)S(2)])[V(1)V(2)](14)O = \frac{\tilde{O}^{(2)}}{\ell^{(2)}} = \frac{e^{S^{(1)}-m} V^{(1)} + e^{S^{(2)}-m} V^{(2)}}{e^{S^{(1)}-m}\mathbf{1} + e^{S^{(2)}-m}\mathbf{1}} = \text{softmax}\!\left([S^{(1)}\,S^{(2)}]\right)\begin{bmatrix}V^{(1)} \\ V^{(2)}\end{bmatrix} \tag{14} \checkmark

这与标准 Softmax 注意力结果完全一致。关键在于:max 位移校正确保分子各项相对大小正确,最终 \ell 提供正确的分母。

节省的 FLOP: FA1 对 TcT_c 个分块每次都执行 diag()1\text{diag}(\ell)^{-1},FA2 只做一次。对于 N=8192N=8192Bc=64B_c=64Tc=128T_c=128:FA2 消除了 127 次每步 \ell 归一化操作,每次操作的代价是 BrB_r 维标量除法(运行在 19.5 TF/s 的慢速单元上)。

1b. Logsumexp 压缩(后向传播)

FA1 为后向传播存储两个标量 (m(j),(j))(m^{(j)}, \ell^{(j)}),FA2 把它们压缩为一个:

L=m+log()=log ⁣(jeSij)(15)L = m + \log(\ell) = \log\!\left(\sum_j e^{S_{ij}}\right) \tag{15}

这就是逐行的 log-sum-exp。在后向传播中,从 LL 恢复注意力概率只需一步:

Pij=eSijLi(16)P_{ij} = e^{S_{ij} - L_i} \tag{16}

不需要先从 (m,)(m, \ell) 重建,减少了后向传播中的非矩阵乘指令数,同时把每行存储的标量数从 2 减少到 1,节省了 NN 个 FP16 的 HBM 写入量(不大但有益)。

FA2 前向传播算法(Algorithm 1):

算法 1:FlashAttention-2 前向传播
──────────────────────────────────────────────────────────────────────────
输入:Q, K, V ∈ ℝ^{N×d}(存于 HBM);分块大小 B_r, B_c
输出:O ∈ ℝ^{N×d}(存于 HBM);L ∈ ℝ^N(logsumexp,存于 HBM)
──────────────────────────────────────────────────────────────────────────
1.  Q 划分为 T_r = ⌈N/B_r⌉ 个行块 Q_1 … Q_{T_r}(每个 B_r×d)
    K, V 划分为 T_c = ⌈N/B_c⌉ 个列块 K_1 … K_{T_c}, V_1 … V_{T_c}

2.  for i = 1 … T_r do                 ← 外层循环 [FA2中每个行块是独立线程块]
3.      HBM→SRAM: 加载 Q_i
4.      初始化 Õ_i = 0,  ℓ_i = 0,  m_i = -∞

5.      for j = 1 … T_c do             ← 内层循环
6.          HBM→SRAM: 加载 K_j, V_j
7.          S_ij = Q_i K_j^T           ← 矩阵乘(Tensor Core)
8.          m_i_new = max(m_i, rowmax(S_ij))  ← 非矩阵乘
9.          P̃_ij = exp(S_ij − m_i_new)       ← 非矩阵乘
10.         ℓ_i_new = exp(m_i−m_i_new)·ℓ_i + rowsum(P̃_ij)  ← 非矩阵乘
11.         Õ_i = diag(exp(m_i−m_i_new))^{-1} Õ_i + P̃_ij V_j
                                        ← max校正(非矩阵乘) + 矩阵乘
            ── FA2 关键:此处无 diag(ℓ)^{-1} ──
12.         m_i ← m_i_new;  ℓ_i ← ℓ_i_new
13.     end for

14.     O_i = diag(ℓ_i)^{-1} Õ_i      ← 每个外层步仅归一化一次!
15.     L_i = m_i + log(ℓ_i)           ← 存logsumexp供后向传播使用
16.     SRAM→HBM: 写回 O_i, L_i
17. end for
──────────────────────────────────────────────────────────────────────────

与 FA1 相比,第 14 行的 diag()1\text{diag}(\ell)^{-1} 从内层循环移到了外层循环,一次完成。

FA2 后向传播算法(Algorithm 2):

算法 2:FlashAttention-2 后向传播
──────────────────────────────────────────────────────────────────────────
输入:Q, K, V, O, dO ∈ ℝ^{N×d}(HBM);L ∈ ℝ^N(HBM)
输出:dQ, dK, dV ∈ ℝ^{N×d}(HBM)
──────────────────────────────────────────────────────────────────────────
预计算:D = rowsum(dO ◦ O) ∈ ℝ^N
  (Softmax 雅可比的缩放项:D_i = ∑_k dO_ik·O_ik)

1.  for j = 1 … T_c do                ← 外层循环(对每个 K/V 列块)
2.      加载 K_j, V_j : HBM→SRAM
3.      初始化 dK_j = 0,  dV_j = 0

4.      for i = 1 … T_r do            ← 内层循环
5.          加载 Q_i, O_i, dO_i, L_i, D_i : HBM→SRAM
6.          S_ij = Q_i K_j^T           ← 矩阵乘
7.          P_ij = exp(S_ij − L_i)     ← 用 logsumexp 重算注意力概率
8.          dV_j += P_ij^T · dO_i      ← 矩阵乘
9.          dP_ij = dO_i · V_j^T       ← 矩阵乘
10.         dS_ij = P_ij ◦ (dP_ij − D_i) ← Softmax 雅可比(非矩阵乘)
11.         dQ_i += dS_ij · K_j        ← 矩阵乘(原子加回 HBM)
12.         dK_j += dS_ij^T · Q_i      ← 矩阵乘
13.     end for

14.     写回 dK_j, dV_j → HBM
15. end for
──────────────────────────────────────────────────────────────────────────

第 7 行展示了 logsumexp 的作用:直接从存储的 LiL_i 重算 PijP_{ij},无需从 HBM 读取完整的 N×NN \times N 矩阵 PP。第 11 行用原子加(atomic add)是因为 dQidQ_i 接受来自所有列块 jj 的贡献,需要跨线程块安全累加。

Softmax 雅可比第 10 行利用了以下恒等式:若 p=softmax(s)p = \text{softmax}(s),则:

ds=p(dpp,dp1)=p(dpD)(17)ds = p \circ (dp - \langle p, dp \rangle \cdot \mathbf{1}) = p \circ (dp - D) \tag{17}

其中 D=p,dpD = \langle p, dp \rangle 即预计算的行缩放值。这确保了后向传播也是精确的。

改进二:序列维度并行

FA1 沿两个维度并行:Batch 维度和 Head 维度,共启动 B×HB \times H 个线程块。对于长上下文推理场景(Batch=1, H=32):

B×H=32 个线程块108 个 SM    76 个 SM 空转(18)B \times H = 32 \text{ 个线程块} \ll 108 \text{ 个 SM} \implies 76 \text{ 个 SM 空转} \tag{18}

FA2 新增序列维度并行:外层循环对每个 Q 行块 QiQ_i 是独立的,FA2 把每个行块分配给一个独立线程块,同时在不同 SM 上并行执行。

FA2 线程块数:B×H×Tr其中Tr=N/Br(19)\text{FA2 线程块数:} B \times H \times T_r \quad \text{其中} \quad T_r = \lceil N/B_r \rceil \tag{19}

N=8192N=8192Br=64B_r=64Tr=128T_r=128,线程块数增加 128 倍,SM 占用率大幅提升。

graph LR
    subgraph FA1_para["FA1:2维并行"]
        direction TB
        T1["线程块0 = Head0, Batch0\n处理全部128个行块迭代\n(串行)"]
        T2["线程块1 = Head1, Batch0"]
        TN["线程块31 = Head31\n共32个线程块 < 108 SM"]
    end
    subgraph FA2_para["FA2:3维并行"]
        direction TB
        T2_0["线程块0 = Head0, Batch0, 行块0"]
        T2_1["线程块1 = Head0, Batch0, 行块1"]
        T2_N["线程块4095 = Head31, 行块127\n共32×128=4096个线程块"]
    end
    FA1_para -->|"长上下文:B×H小\n大量SM闲置"| LOW["低占用率"]
    FA2_para -->|"T_r倍更多线程块\n填满全部SM"| HIGH["高占用率"]
    style LOW fill:#ff9999
    style HIGH fill:#99ff99

图 4:FA2 新增序列维度并行。FA1 的 32 个线程块使 A100 上 76 个 SM 空转。FA2 额外引入行块维度,创建 4096 个线程块,充分填满所有 SM。

前向传播: 外层循环各行块完全独立,无线程块间通信,理想并行。

后向传播: 外层循环迭代 K/V 列块。每个线程块处理一个 Kj/VjK_j/V_j 列块,遍历全部行块计算 dKjdK_jdVjdV_j。但 dQidQ_i 需要从全部列块累加,FA2 使用原子加处理跨线程块的 dQdQ 更新(式 11)。

这一思路最早由 OpenAI 的 Phil Tillet 在 Triton 的 FlashAttention 实现中提出,FA2 在 CUDA 内核中将其规范化。

改进三:Warp 内工作划分(Split-Q vs Split-K)

即使线程块级的并行度已经最优,线程块内的 4-8 个 Warp 如何分工仍然影响性能。FA2 的第三项改进就是改变这里的划分策略。

FA1 的 Split-K 策略

FA1 把 K/V 的头维度 dd 分配给各个 Warp:

  • Warp 0 处理 Kj,0:d/4K_{j,0:d/4}(K 的前 1/4 列)
  • Warp 1 处理 Kj,d/4:d/2K_{j,d/4:d/2}
  • Warp 2、3 类似

每个 Warp 得到 QKQK^\top 的一个部分列切片,与对应的 VV 切片相乘得到部分输出。最终合并各 Warp 的部分输出需要经过 SRAM:

  1. 每个 Warp 把中间结果写入 SRAM
  2. 全部 Warp 执行 __syncthreads() 同步
  3. 汇总并写回

问题: Softmax 是按行耦合的——每个 Warp 只看到 SijS_{ij} 的部分列,无法独立计算完整的 max\max\sum。必须先聚合各 Warp 的 SS 切片,或者在 Warp 间共享 mm\ell 统计量,这又增加了额外的 SRAM 通信轮次。

FA2 的 Split-Q 策略

FA2 改为把 Q 的行维度 BrB_r 分配给各个 Warp:

  • Warp 0 处理 Q 的第 0-15 行,独立计算这些行的完整注意力和输出
  • Warp 1 处理 Q 的第 16-31 行
  • Warp 2、3 类似

所有 Warp 共享 K 和 V 的同一块(由线程块统一加载到 SRAM,只读),但每个 Warp 独立维护自己所负责行的 (m,,O~)(m, \ell, \tilde{O})

关键: Softmax 按行计算。每个 Warp 拥有完整的若干行,可以独立执行在线 Softmax——不需要和其他 Warp 交换任何数据。零 Warp 间通信,无需 SRAM 写入,无需同步屏障。

graph TD
    subgraph fa1["FA1:Split-K"]
        direction LR
        subgraph W0K["Warp 0"]
            q0["Q_i(完整 B_r 行)"]
            k0["K_j 列切片 [0:d/4]"]
            p0["部分 QK^T"]
        end
        subgraph W1K["Warp 1"]
            q1["Q_i(同)"]
            k1["K_j [d/4:d/2]"]
            p1["部分 QK^T"]
        end
        SMEM_K["SRAM: 各Warp写部分结果\n+ syncthreads + 求和\n(瓶颈!)"]
        O_K["最终 O_i"]
        p0 --> SMEM_K
        p1 --> SMEM_K
        SMEM_K --> O_K
    end
    subgraph fa2["FA2:Split-Q"]
        direction LR
        subgraph W0Q["Warp 0"]
            q0q["Q_i 行切片 [0:B_r/4]"]
            kv0["K_j, V_j(完整,SRAM 只读)"]
            out0["完整输出 O_i[0:B_r/4]"]
        end
        subgraph W1Q["Warp 1"]
            q1q["Q_i [B_r/4:B_r/2]"]
            kv1["K_j, V_j(同)"]
            out1["完整输出 O_i[B_r/4:B_r/2]"]
        end
        O_Q["最终 O_i\n(无需同步!)"]
        kv0 -.-> q0q
        kv1 -.-> q1q
        out0 --> O_Q
        out1 --> O_Q
    end
    style SMEM_K fill:#ff9999
    style O_Q fill:#99ff99

图 5:FA1 的 Split-K(上)需要所有 Warp 把部分结果写入 SRAM 并同步。FA2 的 Split-Q(下)每个 Warp 独立处理完整行,K/V 由 SRAM 只读共享,完全无 Warp 间通信。

为什么 Split-Q 能奏效: Softmax 沿行方向耦合,各行之间是独立的。只要每个 Warp 负责完整的若干行(不跨 Warp 分割行),它就能独立完成在线 Softmax——mm\ellO~\tilde{O} 全部存在该 Warp 的寄存器中,不需要与其他 Warp 共享。

设计边界: Split-Q 的代价是:所有 4 个 Warp 需要同时访问 K 和 V 的同一个分块,这些分块必须完整地放入 SRAM。这对 BcB_c(列分块大小)形成了上限约束。FA1 的 Split-K 可以用更大的 BcB_c(每个 Warp 只访问 K 的一部分),但 FA2 的收益远超这一代价。

后向传播中 Softmax 雅可比的推导

后向传播算法中第 10 步的 dS_ij = P_ij ◦ (dP_ij − D_i) 涉及 Softmax 的雅可比矩阵。这里完整推导其来源,这是后向传播正确性的关键。

Softmax 的雅可比:p=softmax(s)p = \text{softmax}(s),即 pk=esk/mesmp_k = e^{s_k}/\sum_m e^{s_m}。对损失 L\mathcal{L}sis_i 求偏导:

Lsi=kLpkpksi(27)\frac{\partial \mathcal{L}}{\partial s_i} = \sum_k \frac{\partial \mathcal{L}}{\partial p_k} \frac{\partial p_k}{\partial s_i} \tag{27}

计算 pk/si\partial p_k / \partial s_i

pksi={pk(1pk)k=ipkpiki=pk(δkipi)(28)\frac{\partial p_k}{\partial s_i} = \begin{cases} p_k(1 - p_k) & k = i \\ -p_k p_i & k \neq i \end{cases} = p_k(\delta_{ki} - p_i) \tag{28}

代入:

Lsi=kdpkpk(δkipi)=dpipipikdpkpk=pi(dpikpkdpk)(29)\frac{\partial \mathcal{L}}{\partial s_i} = \sum_k dp_k \cdot p_k (\delta_{ki} - p_i) = dp_i \cdot p_i - p_i \sum_k dp_k \cdot p_k = p_i \left(dp_i - \sum_k p_k \, dp_k\right) \tag{29}

其中 kpkdpk=p,dp\sum_k p_k \, dp_k = \langle p, dp \rangle 即向量点积。写成矩阵形式:

ds=p(dpp,dp1)(30)ds = p \circ (dp - \langle p, dp \rangle \cdot \mathbf{1}) \tag{30}

在注意力后向传播中的对应: pPijp \to P_{ij}(一行注意力权重),dpdPijdp \to dP_{ij}=dOiVj= dO_i V_j^\top),p,dpDi=rowsum(dOiOi)\langle p, dp \rangle \to D_i = \text{rowsum}(dO_i \circ O_i)(预先计算好)。

预计算 DD 而非在内层循环中实时计算的好处:DD 只依赖 dOdOOO,与 jj(列块索引)无关,可以一次性算出所有行的 DD 向量,避免 TcT_c 次重复计算。

多 Query 注意力(MQA)和分组 Query 注意力(GQA)的支持

现代大模型推理中,KV 缓存常占据大量显存。减小 KV 缓存的主流方法:

多 Query 注意力(MQA)(Shazeer, 2019): 所有 Query 头共享同一组 K 和 V(只有 1 个 K/V 头)。KV 缓存缩小 HH 倍(HH 为 Query 头数)。

分组 Query 注意力(GQA)(Ainslie et al., 2023): GG 组,每组的 Query 头共享一个 K/V 头。G=1G=1 退化为 MQA,G=HG=H 退化为标准多头注意力(MHA)。

FA2 通过张量步长(stride)技巧原生支持这两者,无需在内存中复制 K/V:第 ii 个 Query 头属于第 g=i/Gg = \lfloor i/G \rfloor 组,FA2 在索引 K/V 时直接使用 gg,而不是 ii。这是一个零内存开销的实现,不需要在内核外手动广播 K/V。

后向传播中的额外处理: 由于多个 Query 头共享同一个 K/V 头,梯度 dKdKdVdV 需要在组内的所有 Query 头上求和。FA2 在后向传播核函数之后附加一个规约(reduction)操作完成这一求和。

因果掩码优化

自回归语言模型需要因果掩码:位置 ii 的 query 只能 attend 到位置 jij \leq i。在分块框架下,FA2 利用两种情况:

情形一——完全被掩码的分块: 如果某个分块的所有列索引 jj 都严格大于行索引 ii(即整块都在对角线以上),则所有条目经掩码后变为 -\infty,Softmax 后贡献为零。FA2 直接跳过该分块的 GEMM 和 HBM 加载

对于大序列,约一半的分块是完全被掩码的,跳过后比无掩码注意力快约 1.7-1.8×

情形二——边界分块: 仅对角线分块(部分有效、部分掩码)需要显式应用因果掩码,每个行块只有 1 个这样的分块。对角线以下的分块全部无掩码,无需任何掩码逻辑。

graph LR
    subgraph mask["N×N 因果注意力矩阵(分块视角)"]
        ABOVE["上三角分块(j > i)\n全部掩码 → 跳过!\n约50%的分块"]
        DIAG["对角线分块(j ≈ i)\n部分掩码 → 显式处理\n每行块仅1个"]
        BELOW["下三角分块(j < i)\n全部有效 → 全速计算\n无掩码逻辑"]
    end
    ABOVE -->|"跳过约50%计算"| SPEEDUP["约1.7-1.8×\n速度提升"]
    style ABOVE fill:#cccccc
    style SPEEDUP fill:#99ff99
    style DIAG fill:#ffff99
    style BELOW fill:#ccffcc

图 6:因果掩码在 FA2 分块框架中的处理。上三角分块直接跳过,对角线分块施加掩码,下三角分块全速执行。这带来约 1.7× 相对于非因果注意力的速度提升。

算术强度与屋顶线模型

屋顶线模型(Roofline Model) 提供了理解注意力性能上界的分析框架。对于计算吞吐 π\pi(FLOPs/s)和内存带宽 β\beta(bytes/s)的硬件:

Pmin(π,  βI)(20)P \leq \min(\pi,\; \beta \cdot I) \tag{20}

其中 II算术强度(每字节内存传输对应的 FLOPs 数)。A100 的”脊线点”:

Iridge=312 TF/s2 TB/s=156 FLOPs/byte(21)I_\text{ridge} = \frac{312 \text{ TF/s}}{2 \text{ TB/s}} = 156 \text{ FLOPs/byte} \tag{21}

标准注意力的算术强度:

  • 计算量:O(N2d)O(N^2 d)
  • 内存访问:O(N2)O(N^2)(读写 S、P 矩阵)
  • 算术强度:d=128\approx d = 128(FP16,128/2 = 64 FLOPs/byte)

64 FLOPs/byte < 156 FLOPs/byte:标准注意力是内存带宽瓶颈(memory-bound)

FlashAttention 的算术强度:

  • 计算量:仍为 O(N2d)O(N^2 d)
  • 内存访问:降至 O(Nd)O(Nd)(不写 N×NN \times N 矩阵)
  • 算术强度:N\approx NN=8192N=8192 时为 4096 FLOPs/byte)

4096 ≫ 156:FA1/FA2 是计算瓶颈(compute-bound),能够接近峰值 Tensor Core 吞吐。

FA2 进一步提升了计算瓶颈下的效率:减少非矩阵乘 FLOP,让 Tensor Core 占用更高比例的总时间。

实验结果

注意力核吞吐量

所有测试在 A100 80GB SXM4 上进行。序列长度从 512 到 16k,批大小固定使总 token 数为 16k,隐层维度 2048。

前向+后向速度,无因果掩码,头维度 64(A100):

方法5121k2k4k8k16k
PyTorch 标准实现6891989583OOM
FlashAttention-19190104989276
FA Triton10292110110108100
FlashAttention-2132162171176175173

单位:TFLOPs/s。A100 FP16 峰值 = 312 TFLOPs/s。

有因果掩码,头维度 128 的前向速度: FA2 最高达 224-227 TFLOPs/s,相当于 A100 峰值的 72-73%,已接近纯矩阵乘 cuBLAS 的效率(80-90%)。

FA2 在 16k 长序列上更明显:此时批大小受内存限制通常很小,FA2 的序列维度并行恰好填补了因批大小小导致的 SM 闲置问题。

端到端 GPT 训练吞吐量

GPT 模型在 8× A100 80GB SXM4 上训练结果:

模型上下文长度无 FAFA1FA2FA2 vs 基线加速比
GPT3-1.3B2k142 TF/s1891961.38×
GPT3-1.3B8k72 TF/s1702203.06×
GPT3-2.7B2k149 TF/s1892051.38×
GPT3-2.7B8k80 TF/s1752252.81×

8k 上下文下加速比最大(3×),原因正是:长序列迫使批大小缩小,而 FA2 的序列维度并行解决了由此产生的 SM 低占用率问题。

MFU=225 TFLOPs/s312 TFLOPs/s=72%(22)\text{MFU} = \frac{225 \text{ TFLOPs/s}}{312 \text{ TFLOPs/s}} = 72\% \tag{22}

72% 的模型 FLOPs 利用率对真实训练任务(含数据加载、优化器、梯度通信)而言非常高,已接近硬件上限。

前向传播与后向传播的 IO 复杂度分析

为了从理论上量化 FA2 相对于标准注意力的优势,我们分析每次前向传播的 HBM 读写字节数(以 FP16,即每元素 2 字节计)。

标准注意力 HBM 访问:

  • 读 Q, K, V: 3×2Nd3 \times 2Nd 字节
  • S=QKS = QK^\top: 2N22N^2 字节
  • SS,写 P=softmax(S)P = \text{softmax}(S): 2×2N22 \times 2N^2 字节
  • 读 P, V,写 O: 2N2+2Nd+2Nd2N^2 + 2Nd + 2Nd 字节
  • 总计:4×2N2+6Nd\approx 4 \times 2N^2 + 6Nd 字节NdN \gg d 时以 O(N2)O(N^2) 主导)

FlashAttention(FA1/FA2)HBM 访问:

  • 读 Q: 2Nd2Nd 字节(分 TrT_r 块,每外层步读一次)
  • 每内层步读 K_j, V_j: 2×2Bcd2 \times 2 B_c d 字节,共 Tr×TcT_r \times T_c
  • 每外层步写 O_i: 2Brd2B_r d 字节,共 TrT_r
  • 写 L(logsumexp): 2N2N 字节(可忽略)
  • 总计:(Tr22BcdTc)+2Nd+2Nd=O(NdTc+Nd)\approx (T_r \cdot 2 \cdot 2B_c d \cdot T_c) + 2Nd + 2Nd = O(Nd \cdot T_c + Nd)

由于 Tc=N/BcT_c = N/B_c,总 HBM 访问量为 O(NdN/Bc)=O(N2d/Bc)O(Nd \cdot N/B_c) = O(N^2 d / B_c),与 SRAM 大小 MBcdM \sim B_c \cdot d 成反比:

FA HBM 访问量=O ⁣(N2dM)vs 标准注意力的 O(N2)(25)\text{FA HBM 访问量} = O\!\left(\frac{N^2 d}{M}\right) \quad \text{vs 标准注意力的 } O(N^2) \tag{25}

对于 MdM \gg d(SRAM 能容纳足够大的 BcB_c),FA 的 HBM 流量大幅低于标准注意力。以 N=8192,d=128,M192KB=98304B,Bc=64N=8192, d=128, M\approx 192\text{KB}=98304\text{B}, B_c=64(故 M/(2d)=384M/(2d)=384)为例:

标准:4N2=4×81922268 MBvs FA:8Nd/Tc34 MB(26)\text{标准:} 4N^2 = 4 \times 8192^2 \approx 268 \text{ MB} \quad \text{vs FA:} \approx 8Nd/T_c \approx 34 \text{ MB} \tag{26}

FA 将 HBM 流量降低约 8 倍——这就是为什么 FA1 即使在算法层面没有 FA2 的改进,也能相对标准注意力取得显著加速:它把注意力从 memory-bound 变成了 compute-bound。

设计分析:为什么、替代方案、边界条件

为什么不把所有分块处理完再一次性归一化?

问题: 能否先处理全部 K/V 分块,累积未归一化的和,最后做一次归一化?

答案: 不行——数值溢出。如果不做在线 max 减法,eSije^{S_{ij}} 对大值会溢出 FP16/FP32。在线 max 技巧(每步维护 eSijme^{S_{ij}-m},始终 ≤1)是数值稳定性的必要条件。FA2 保留了每步的 max 校正(相对廉价,仅是一个 1\leq 1 的标量乘法),只去掉了更昂贵的每步 \ell 除法。

边界: TcT_c 越大(序列越长),节省越显著。对 512 token 的短序列(Tc=8T_c=8),仅省 7 次归一化,收益有限。对 16k token(Tc=256T_c=256),省 255 次,效果明显。

为什么 Split-Q 可以无 Warp 通信而 Split-K 不行?

核心差异在于 Softmax 的耦合方向:

  • Split-K: 每个 Warp 拥有 SijS_{ij} 的部分切片。Softmax 按行计算,需要完整行的所有列才能求 max 和 sum。Warp 间必须通信来聚合统计量或部分输出。

  • Split-Q: 每个 Warp 拥有 SijS_{ij} 的部分。每行的 Softmax 完全包含在单个 Warp 内。(m,,O~)(m, \ell, \tilde{O}) 均为该 Warp 私有的寄存器,无需共享。

边界: Split-Q 要求 K 和 V 的分块能同时放入 SRAM 并被所有 Warp 访问,对 BcB_c 有上限。FA1 的 Split-K 每个 Warp 只需 K 的 1/4,可以使用更大的 BcB_c。但 FA2 消除 Warp 通信的收益远超减小 BcB_c 的代价。

为什么序列维度并行对长序列更有效?

短序列(512-2k): B×HB \times H 通常已足够大(例如 batch=32 × heads=32 = 1024 个线程块),108 个 SM 已经能被充分利用。序列维度并行增加的线程块带来的增量有限。

长序列(4k-16k): 内存限制迫使 batch 缩小(batch=1 到 4 很常见)。B×HB \times H 可能只有 32-64 个线程块,远少于 108 个 SM。此时序列维度并行的 TrT_r 倍增加效果显著——这也解释了为什么 FA2 在端到端测试中 8k 上下文的加速比(3×)远大于 2k 上下文(1.38×)。

FA1 → FA2 → FA3 全景

性质FlashAttention-1FlashAttention-2FlashAttention-3(参考)
HBM IO 复杂度O(N2d/M)O(N^2 d/M)
每步 \ell 归一化每内层步执行仅外层步一次同 FA2
后向统计量(m,)(m, \ell),2个/行LL,1个/行同 FA2
序列维度并行有,更细粒度
Warp 策略Split-KSplit-QWarp 专业化(producer/consumer)
Warp 间通信需要
异步内存拷贝TMA(硬件异步加载)
FP8 支持
A100 峰值占比(fwd+bwd)25-40%50-73%N/A(主要针对H100)
H100 FP16 峰值(fwd+bwd)~30%~35%(直接移植)~75%

FA3 的主要新增是 Hopper 专属特性(TMA、Warp 专业化、第四代 Tensor Core),对 Ampere GPU 意义不大。对于 A100 用户,FA2 是最优选择。

实现细节与分块大小调优

分块大小的约束

BrB_r(Q 行分块)和 BcB_c(K/V 列分块)需满足同时放入 SRAM 的约束:

Brd2+2Bcd2+BrBc2MSRAM=192 KB(23)B_r \cdot d \cdot 2 + 2 \cdot B_c \cdot d \cdot 2 + B_r \cdot B_c \cdot 2 \leq M_\text{SRAM} = 192 \text{ KB} \tag{23}

(系数 2 对应 FP16,Q 块 + K 块 + V 块 + S 块)

同时还受到 Tensor Core 对齐要求(维度需为 16 的倍数)和寄存器压力约束。FA2 目前用手工调优的查找表(按头维度索引),典型值:Br=Bc=64B_r = B_c = 64(头维度 64),Br=64,Bc=64B_r = 64, B_c = 64 或 128(头维度 128)。

Logsumexp 的数值稳定性

FA2 存储 L=m+logL = m + \log\ell,其中:

  • =jeSijm(0,Tc]\ell = \sum_j e^{S_{ij}-m} \in (0, T_c](始终正有界)
  • log(,logTc]\log\ell \in (-\infty, \log T_c](有界,无溢出)

后向传播中从 LL 还原 Pij=eSijLiP_{ij} = e^{S_{ij} - L_i}

SijLi=Sijmilogi0(24)S_{ij} - L_i = S_{ij} - m_i - \log\ell_i \leq 0 \tag{24}

因为 mi=maxkSikm_i = \max_k S_{ik},所以 eSijLi(0,1]e^{S_{ij}-L_i} \in (0, 1],不会溢出。这保证了后向传播与前向传播数值精度一致。

对 LLM 训练和推理的实际影响

上下文长度扩展的意义

FA2 最直接的影响是让长上下文训练变得经济可行。注意力的计算量随序列长度二次增长,但 FA2 通过序列维度并行使得实际 GPU 吞吐随上下文增加的衰减远小于理论预期:

上下文长度注意力 FLOPs 增倍FA2 实测吞吐每 token 等效成本
2k~196 TF/s
8k16×(注意力部分)~220 TF/s(+12%)约1.1× per token
16k64×(注意力部分)~224 TF/s类似

原因:更长序列迫使每卡批大小缩小(内存压力),FA2 的 TrT_r 倍更多线程块恰好填补了因批大小小导致的 SM 闲置。换言之,FA2 的效率提升越长越大,完美对冲了长上下文带来的内存限制。

对 KV 缓存的影响

在自回归推理的解码阶段(逐 token 生成),每步的 query 形状为 1×d1 \times d(只有一个新 token),需要 attend 到长度 NN 的 KV 缓存。此时操作从矩阵-矩阵乘退化为向量-矩阵乘,Tensor Core 利用率自然下降——FA2 的优势主要体现在:

  1. 预填充阶段(Prefill): 处理完整 prompt(NN 个 token)时,FA2 的效率提升完全生效。
  2. MQA/GQA 支持: FA2 原生支持多 query 注意力(多个 Q 头共享一组 K/V),使推理时 KV 缓存大小可减少 HH(头数)倍,直接降低 GPU 内存需求和带宽压力。

与 PyTorch 的集成

PyTorch 2.0+ 中,torch.nn.functional.scaled_dot_product_attention 在以下条件下自动调用 FA2:

  • CUDA(Ampere 及以上,即 A100/A10G/RTX 3090 等)
  • 输入为 FP16 或 BF16
  • 序列长度不太短(通常 >128)

这意味着绝大多数现代 LLM 训练无需任何代码改动即可获得 FA2 的性能提升。如需最新优化(H100 FA3、FP8、自定义融合 op),使用 flash-attn 包。

可复现性说明

论文测试环境:A100 80GB SXM4,CUDA 11.8,cuDNN 8.7,PyTorch 2.0。关键实现细节:

  • 分块大小: 头维度 64 → Br=Bc=64B_r=B_c=64;头维度 128 → Br=64,Bc=64B_r=64, B_c=64(部分配置为 128)
  • 参考实现: github.com/Dao-AILab/flash-attention(Apache 2.0 许可)
  • PyTorch 集成: pip install flash-attn==2.x.x;也已合并进 PyTorch 核心
  • FLOP 计算惯例: 注意力 FLOPs 按 4N2d4N^2d(前向)和 10N2d10N^2d(后向,为前向 2.5×)计算。因果注意力的 per-second 效率计算中将前向 FLOPs 除以 2,但端到端 MFU 计算遵从 Megatron-LM 惯例,不除以 2,以保持与文献一致性
  • MFU 计算: MFU=实测 TFLOPs/s/峰值 TFLOPs/s\text{MFU} = \text{实测 TFLOPs/s} / \text{峰值 TFLOPs/s};GPT-2.7B 在 8k 上下文下:225/312=72%225/312 = 72\%

约 2× 的加速提升在不同头维度、序列长度和批大小下均稳健。绝对 TFLOPs/s 数值对 A100 型号(SXM4 vs PCIe)、驱动版本和热节流状态敏感,实际测量期望 ±5-10% 的偏差。

总结

FlashAttention-2 是一篇关于 GPU 系统优化精确性的精彩示范。Tri Dao 诊断出 FA1 三个独立的性能瓶颈,并为每个瓶颈设计了最小化的修复方案:

  1. 延迟 \ell 归一化: 消除内层循环中每步的对角矩阵乘法,将 TcT_c 次降至 1 次
  2. 序列维度并行: 为外层行块引入独立线程块,对长上下文场景填满空转 SM,提升占用率
  3. Split-Q Warp 策略: 基于”Softmax 按行独立”的观察,消除 Warp 间 SRAM 通信和同步屏障

三项改进叠加,实现了 FA1 的 2× 加速,使 A100 上注意力核效率从 25-40% 提升到 50-73% 的理论峰值,与纯矩阵乘的效率差距缩小到 10-20 个百分点。端到端结果:GPT-2.7B 在 8k 上下文下训练吞吐达到 225 TFLOPs/s(72% MFU),比无 FA 基线快 2.8×,比 FA1 快 1.3×。

这篇论文的示范价值在于:它证明了”IO 最优”(FA1 实现)和”计算最优”(FA2 进一步优化)是两个独立的问题,不能混为一谈。FA1 消除了 O(N2)O(N^2) HBM 流量,使操作从 memory-bound 变成 compute-bound;FA2 则在 compute-bound 的假设下,通过减少非矩阵乘 FLOP、提高占用率、优化 Warp 通信,进一步接近 Tensor Core 的峰值利用率。这套分析框架——先找到瓶颈类型,再有针对性地解决——是 GPU 系统优化的通用方法论。

对实践者的建议: PyTorch 2.0+ 已将 FA2 作为 CUDA 上 scaled_dot_product_attention 的默认后端,无需代码改动即可受益。如需最新优化(H100 FA3、FP8、自定义 MQA/GQA),使用 flash-attn 库的最新版本。

关于 H100: 将 FA2 的同一代码直接移植到 H100 SXM5 上(不做任何 H100 专项优化),可达 294-338 TFLOPs/s,约为 H100 FP16 峰值(989 TFLOPs/s)的 30-34%。利用 H100 专属特性(TMA、Warp 专业化、FP8)的 FA3 能达到约 75% 峰值。因此对于 H100/H200 用户,FA3 是更好的选择;对于 A100/A10G/RTX 系列,FA2 是当前最优解。

(注:FA2 是 Apache 2.0 开源许可,代码地址:github.com/Dao-AILab/flash-attention

思维框架:GPU 注意力性能受三个独立维度制约——HBM 带宽(FA1 解决)、SM 占用率(FA2 的序列并行解决)、Warp 效率(FA2 的 Split-Q 解决)。理解这三个维度是分析任何注意力优化工作的基础框架。

读完这篇论文,我最大的收获不是具体的技术手段,而是这套分层诊断的思路:先用屋顶线模型判断瓶颈类型(memory-bound 还是 compute-bound),再分析 compute-bound 内部的次级瓶颈(非矩阵乘 FLOP 占比、SM 占用率、Warp 通信),最后对每个瓶颈设计独立的、可量化的修复方案。这种工程文化在国内的 AI infra 团队中比较少见,值得借鉴。

从影响力的角度看,FlashAttention 系列是近年来对实际 LLM 训练和推理贡献最大的少数几篇论文之一——不是通过提出新模型或新算法,而是通过让已有算法在现有硬件上运行得更快。这提醒我们,系统级优化与算法创新同等重要,在资源受限的环境下甚至更为关键。

FA1 vs FA2 完整对照表

性质FlashAttention-1FlashAttention-2
HBM IO 复杂度(前向)O(N2d/M)O(N^2 d / M)同,不变
每内层步 \ell 归一化✓ 每 K/V 分块执行✗ 仅外层步结束执行一次
后向统计量存储(m,)(m, \ell) — 2 个标量/行L=m+logL = m+\log\ell — 1 个标量/行
序列维度并行✗ 外层循环串行✓ 每行块独立线程块
线程块总数B×HB \times HB×H×TrB \times H \times T_r
Warp 工作划分(前向)Split-K:按 K/V 列维分割Split-Q:按 Q 行维分割
Warp 间通信(前向)✓ 需要 SRAM 写 + syncthreads✗ 完全无需
后向 dQ 更新线程块内串行跨线程块原子加
MQA/GQA 原生支持✓ 步长技巧
因果掩码优化✓ 跳过全掩码分块✓ 同,不变
A100 峰值利用率(fwd+bwd)25-40%50-73%
GPT-2.7B 8k 端到端吞吐175 TF/s225 TF/s
相对 FA1 加速比约 2×

数学算法(在线 Softmax 递推、分块结构、IO 复杂度)完全一致——FA2 的所有改进均在实现层面,输出结果与 FA1 按位相同(在浮点舍入范围内)。

值得特别注意的一点:FA2 的每一项改进都是独立可拆解的。延迟 \ell 归一化、序列维度并行、Split-Q 三者互不依赖,可以独立打开或关闭。这使得消融实验(ablation)非常清晰,也使论文的结论高度可信——作者清楚地知道每一分速度来自哪里。

另一个值得关注的点:论文中没有 ablation table 明确拆分三项改进各自的贡献,这是一个小遗憾。从论文内容可以推断:序列维度并行主要在长序列、小批大小场景下贡献显著;Split-Q 在各种场景下均有稳定收益;延迟归一化的贡献在长序列下(TcT_c 大时)更明显。