news 2026/7/1 7:30:14

【学习笔记】Flash Attention 原理与实践:让 Attention 重新成为算力游戏(13/35)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【学习笔记】Flash Attention 原理与实践:让 Attention 重新成为算力游戏(13/35)

在前面的文章里,Flash Attention 这个名字反复出现:

  • 第 2 篇讲 attention 时提到它是「现代推理框架的标配」

  • 第 5 篇讲长上下文时把它列为「四大攻坚维度」之一

  • 第 11 篇讲推理优化时它是 prefill 阶段的核心加速器

这一篇我们正式把它讲透。

为什么 Flash Attention 值得单独一篇?因为它代表了深度学习系统优化的一个里程碑思路——它没有改变任何数学(计算结果完全等价),只通过重新设计数据在显存中的搬运方式,把 attention 的速度提升了 2-4 倍,把显存占用从 O(n²) 降到 O(n)。

如果你做过相关工作,下面这些问题应该不陌生:

  • 为什么 vLLM、SGLang、TensorRT-LLM 都默认用 Flash Attention?

  • 为什么把attn_implementation="flash_attention_2"加上模型就能跑得快很多?

  • Flash Attention 的"分块"、"在线 softmax"到底是什么?

  • H100 的 Flash Attention v3 比 v2 快多少?

  • 端到端训练用了 Flash Attention 后能省多少显存?

读完本文你将能:

  1. 理解 GPU 显存层级(HBM vs SRAM)—— 这是 Flash Attention 的物理基础

  2. 理解 Flash Attention 的两个核心技巧(Tiling + Online Softmax)

  3. 知道 v1 / v2 / v3 之间的演进,针对你的硬件选对版本

  4. 用 PyTorch / HuggingFace / vLLM 三种方式启用 Flash Attention

  5. 判断什么场景 Flash Attention 不适用

我们开始。


一、为什么 Attention 需要专门优化

1.1 一个被忽视的事实:GPU 不是只有算力

很多人对 GPU 的认知停留在「TFLOPS 多少」——比如 H100 SXM 是 989 TFLOPS(FP16)。

但 GPU 还有一个同等重要的指标:显存带宽

GPU

FP16 算力

显存带宽

算力/带宽比

V100

125 TFLOPS

0.9 TB/s

139

A100 80G

312 TFLOPS

2.0 TB/s

156

H100 SXM

989 TFLOPS

3.35 TB/s

295

H200

989 TFLOPS

4.8 TB/s

206

B200

2250 TFLOPS

8 TB/s

281

注意「算力/带宽比」——越高表示单位带宽对应的算力越多

关键认知

GPU 算力增长比显存带宽增长快得多。
从 V100 到 H100,算力翻了 8 倍,带宽只翻了 3.7 倍。
这意味着「IO 瓶颈」越来越严重。

1.2 GPU 显存的三层结构

我们再深入一层——GPU 内部其实有多级存储

HBM (High Bandwidth Memory) 80 GB, 3.3 TB/s ↑ ↑ "显存",所有数据默认在这里 非常大,相对慢 L2 Cache 50 MB, ~12 TB/s ↑ 中间层 SRAM (Shared Memory + Registers) ~228 KB / SM ~19 TB/s ↑ "片上",极快但极小 H100 有 132 个 SM,总共也才 30 MB

简化版:

[ 80 GB ] HBM ← 慢 ↕↕↕↕↕↕ 数据搬运 [ 30 MB ] SRAM ← 极快 ↑ 计算实际发生的地方

核心矛盾

  • 数据默认在 HBM(80 GB 富余)

  • 但计算必须在 SRAM 进行

  • 每次计算都要把数据从 HBM 搬到 SRAM

  • HBM 带宽(3.3 TB/s)远低于 SRAM(19 TB/s)

这就是为什么 IO 成了瓶颈——GPU 算力再强,数据搬不进来也没用。

1.3 传统 Attention 的 IO 噩梦

回顾 attention 计算:

S = Q · K^T # [n, n] 矩阵 P = softmax(S) # [n, n] 矩阵 O = P · V # [n, d] 矩阵

朴素实现把每一步的中间结果写回 HBM,然后下一步再读回来:

1. 读 Q, K 到 SRAM 2. 计算 S = QK^T 3. 把 S 写回 HBM ← O(n²) 写 4. 读 S 回 SRAM 5. 计算 P = softmax(S) 6. 把 P 写回 HBM ← O(n²) 写 7. 读 P, V 到 SRAM 8. 计算 O = PV 9. 写 O 到 HBM

问题:第 3、6 步要把n × n大小的矩阵在 HBM 和 SRAM 之间来回搬。

对于 n=8K 序列:

  • S 矩阵显存:8K × 8K × 4 bytes =256 MB

  • 这 256 MB 反复在 HBM ↔ SRAM 间来回搬

实测数据(A100 上 attention 计算):

  • 实际算力消耗:约 5% 的 GPU 算力

  • 实际 IO 消耗:约 95% 的 GPU 时间

也就是说,95% 的时间在搬数据,5% 的时间在算——这是工程优化的巨大空间。

1.4 Flash Attention 的「Aha Moment」

Flash Attention 论文(Tri Dao, 2022)的一句话总结了它的核心思想:

能不能让 attention 计算"不要"物化中间矩阵 S 和 P?

如果可以,那么:

  • IO 量从 O(n²) 降到 O(n)

  • 显存占用从 O(n²) 降到 O(n)

  • 速度提升 2-4×(算力终于能跑满)

但难点在于:softmax 需要看到整行才能归一化——你不知道总和之前,怎么知道每个元素的归一化值?

Flash Attention 的天才之处在于:它用一种叫"在线 softmax"的算法,让 softmax 可以流式计算


二、Flash Attention v1 原理深入

2.1 核心技巧 1:Tiling(分块)

Flash Attention 不一次计算整个 attention,而是按块计算。

把 Q、K、V 切成 block:

Q : [n, d] → 切成 Tr 块,每块 [Br, d] K : [n, d] → 切成 Tc 块,每块 [Bc, d] V : [n, d] → 切成 Tc 块,每块 [Bc, d]

Br、Bc 设计成能装进 SRAM(典型值 128)。

然后双层循环

for j in range(Tc): # 外层循环 K, V 块 把 Kj, Vj 加载到 SRAM for i in range(Tr): # 内层循环 Q 块 把 Qi 加载到 SRAM 在 SRAM 中计算 Qi · Kj^T → Sij (小矩阵) 在 SRAM 中应用 softmax → Pij 在 SRAM 中计算 Pij · Vj → 输出累积 把累积结果写回 HBM

关键

  • 整个n × n大矩阵 S 从未物化在 HBM

  • 只有小的Br × Bc块在 SRAM 里

  • HBM ↔ SRAM 的数据搬运量从 O(n²) 降到 O(n²/M)(M 是 SRAM 大小)

2.2 核心技巧 2:在线 Softmax

但 softmax 是个全局操作——它需要先看到整行才能归一化:

softmax(x) = exp(x_i) / Σ exp(x_j) ↑ 需要总和!

Flash Attention 用在线 softmax解决:

# 增量计算 softmax # 假设我们已经处理了前 i 个 block # m_i = 前 i 个 block 的最大值 # s_i = 前 i 个 block 的 exp 总和 新来一个 block,计算它的 softmax: m_new = max(m_i, max(new_block)) s_new = exp(m_i - m_new) * s_i + exp(m_new - m_new) * sum(exp(new_block - m_new)) 输出 = 用 m_new 和 s_new 重新归一化所有已处理的部分

这个算法的核心数学技巧

exp(a) + exp(b) = exp(max) * [exp(a - max) + exp(b - max)] ↑ 防止 overflow + 可流式合并

直观上

  • 每个 block 自己算 softmax(用本地 max 防 overflow)

  • 处理完后保存 (max, sum) 两个状态

  • 来新 block 时,用两个 max 之间的"换算因子"调整之前的累积

这个算法数学上完全等价于一次性 softmax——没有任何精度损失。

2.3 完整伪代码

def flash_attention(Q, K, V): n, d = Q.shape M = SRAM_SIZE # SRAM 大小,约 100 KB Br, Bc = derive_block_size(M, d) # 通常 128 Tr, Tc = n // Br, n // Bc # 初始化输出和状态 O = zeros((n, d), in_hbm=True) l = zeros(n, in_hbm=True) # 累积的 sum m = full(n, -inf, in_hbm=True) # 累积的 max for j inrange(Tc): Kj = load_to_sram(K[j*Bc:(j+1)*Bc]) Vj = load_to_sram(V[j*Bc:(j+1)*Bc]) for i inrange(Tr): Qi = load_to_sram(Q[i*Br:(i+1)*Br]) Oi = load_to_sram(O[i*Br:(i+1)*Br]) li = load_to_sram(l[i*Br:(i+1)*Br]) mi = load_to_sram(m[i*Br:(i+1)*Br]) # 在 SRAM 内计算 Sij = Qi @ Kj.T / sqrt(d) # [Br, Bc] mij = row_max(Sij) # [Br] Pij = exp(Sij - mij[:, None]) # [Br, Bc] lij = row_sum(Pij) # [Br] # 在线 softmax 合并 m_new = max(mi, mij) l_new = exp(mi - m_new) * li + exp(mij - m_new) * lij # 更新输出 Oi_new = ( (li * exp(mi - m_new))[:, None] * Oi + exp(mij - m_new)[:, None] * (Pij @ Vj) ) / l_new[:, None] # 写回 HBM write_to_hbm(O[i*Br:(i+1)*Br], Oi_new) write_to_hbm(l[i*Br:(i+1)*Br], l_new) write_to_hbm(m[i*Br:(i+1)*Br], m_new) return O

整体效果

  • 数学等价于标准 attention

  • 中间矩阵 S, P 从未离开过 SRAM

  • HBM IO 量降为原来的 1/M(M = SRAM 大小,约 100 KB)

2.4 Flash Attention v1 的实际收益

序列长度

朴素 Attention

Flash Attention

速度提升

512

1.2×

1.2×

1024

1.8×

1.8×

4096

2.7×

2.7×

16384

3.5×

3.5×

结论:序列越长,Flash Attention 越赚。

显存占用:

序列长度

朴素 (n² 矩阵)

Flash Attention

8K

256 MB

< 1 MB

32K

4 GB

< 4 MB

128K

64 GB

< 16 MB

这就是为什么没有 Flash Attention 根本搞不动长上下文——光 attention 矩阵就把显存吃光了。


三、Flash Attention v2 / v3 的演进

3.1 v2 (2023.07):进一步加速

Flash Attention v2 的改进点:

改进 1:减少非矩阵乘法的开销

v1 中有不少 "rescale"、"max compare" 等非 matmul 操作,这些操作虽然简单但累积起来不少。v2 重新设计算法,把它们减少到最少。

改进 2:更好的并行化

v1 内层循环只在 Q 上并行。v2 把外层循环也并行化,更充分利用 GPU 的多个 SM

改进 3:分配更好的 warp

把 SRAM 分配给更细粒度的 warp,进一步提升计算密度。

实测

  • 比 v1 快~2×

  • 在 A100 上达到 50-70% 的理论算力

  • 在长序列下尤其明显

3.2 v3 (2024.07):H100 时代的飞跃

Flash Attention v3 专为 H100 设计,引入了 H100 的特殊功能:

特性 1:异步加载(async TMA)

H100 引入了TMA(Tensor Memory Accelerator)——可以异步搬运数据,让计算和数据搬运 overlap。

v3 充分利用这个:

计算 block 1 ── 同时加载 block 2 计算 block 2 ── 同时加载 block 3 计算 block 3 ── 同时加载 block 4 ...
特性 2:FP8 支持

v3 第一次支持 FP8 attention:

  • • 精度:约 0.1% 掉点

  • • 速度:比 FP16 再快 2×

特性 3:Warpgroup 异步矩阵乘法

H100 的WGMMA(Warpgroup MMA)让矩阵乘法本身就是异步的。v3 充分利用这个,让算力打满。

实测

  • v3 在 H100 上达到75% 的理论算力(vs v2 的 35%)

  • FP8 模式下接近1.5 PFLOPS

3.3 三个版本性能对比

测试设置:H100 SXM,序列长度 8K,d=128,BF16:

版本

TFLOPS

利用率

标准 PyTorch

11

1.1%

Flash v1

195

19.7%

Flash v2

348

35.2%

Flash v3

740

74.8%

Flash v3 FP8

1417

71.6% (vs FP8 ceil)

结论

  • 标准 PyTorch → Flash v3:67× 加速

  • v2 → v3:约2× 加速(H100 专属)

3.4 哪个版本配哪个硬件

GPU

推荐 Flash 版本

V100 / T4

v1(v2/v3 不一定支持)

A100 / L40

v2

(v3 部分支持但优化不到位)

H100 / H200

v3

B200

v3(v4 即将出,专门为 Blackwell 优化)


四、工程实战:怎么用上 Flash Attention

4.1 用 HuggingFace Transformers 自动启用

最简单的方式:

from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained( "Qwen/Qwen3-32B-Instruct", attn_implementation="flash_attention_2", # ← 关键 torch_dtype="auto", device_map="auto", )

支持的选项:

attn_implementation = "eager" # 朴素,PyTorch 实现,慢 attn_implementation = "sdpa" # PyTorch 2.0 内置,使用 backend attn_implementation = "flash_attention_2" # Flash v2 attn_implementation = "flash_attention_3" # Flash v3 (Transformers 4.46+ 支持)

Tip

  • "sdpa"是 PyTorch 内置的scaled_dot_product_attention,它在底层会自动选择 Flash 或 Memory-Efficient 实现——很多情况下这就够用

  • "flash_attention_2"/"flash_attention_3"需要pip install flash-attn

4.2 用 PyTorch SDPA(最通用)

PyTorch 2.0+ 内置了scaled_dot_product_attention,会自动用 Flash Attention 后端

import torch.nn.functional as F def my_attention(q, k, v, mask=None): # 自动用 Flash Attention 如果可用 output = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=True, ) return output

控制后端:

from torch.nn.attention import SDPBackend, sdpa_kernel with sdpa_kernel(SDPBackend.FLASH_ATTENTION): output = F.scaled_dot_product_attention(q, k, v, is_causal=True)

可选 backend:

  • FLASH_ATTENTION── Flash Attention 实现

  • EFFICIENT_ATTENTION── Memory Efficient Attention

  • MATH── 标准实现(fallback)

  • CUDNN_ATTENTION── cuDNN 实现(新)

4.3 在 vLLM 中

vLLM默认就用 Flash Attention,你什么都不用做:

vllm serve Qwen/Qwen3-32B-Instruct # 自动用 Flash Attention v2/v3(看硬件)

强制版本:

# vLLM 0.6+ 支持 VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve ... VLLM_ATTENTION_BACKEND=FLASH_ATTN_3 vllm serve ...

4.4 训练时的 Flash Attention

训练阶段 Flash Attention 收益更明显——因为序列更长、需要反向传播。

from transformers import AutoModelForCausalLM, TrainingArguments model = AutoModelForCausalLM.from_pretrained( "...", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, ) training_args = TrainingArguments( ..., bf16=True, # 必须用低精度才能用 Flash Attention gradient_checkpointing=True, # 配合用,节省更多显存 )

实测

  • 训练 70B 模型 + 8K context

  • 不开 Flash Attention:每 step 4.2 秒,显存 78 GB

  • 开 Flash Attention v2:每 step 1.8 秒,显存 42 GB

2× 加速 + 47% 显存节省。这就是为什么训练大模型必须用 Flash Attention。

4.5 安装

# Flash Attention v2 pip install flash-attn --no-build-isolation # Flash Attention v3(H100 only,目前仍在 hopper 分支) pip install git+https://github.com/Dao-AILab/flash-attention.git@hopper

常见安装坑

  • CUDA 版本要匹配(建议 12.x+)

  • 编译时间长(首次安装 30-60 分钟)

  • 需要 ≥ 8 GB 内存编译

  • 没有预编译 wheel 时编译失败 → 安装 ninja 试试


五、扩展话题:Flash 家族还在演进

5.1 Flash Decoding(推理专用)

Flash Attention v2 主要为训练优化(长 seq、batch 大)。推理有不同的瓶颈:

  • Decode 阶段每次只处理 1 个 token

  • KV Cache 上的 attention 是 1 × N 矩阵(不是 N × N)

  • 真正的瓶颈是并行度不足

Flash Decoding(Dao 2023.10)专门解决这个:

  • 把 KV 序列也切到多个 SM上并行

  • 每个 SM 处理 KV 的一部分

  • 最后用 log-sum-exp 合并

效果

  • Decode 速度提升2-8×(看 batch 和 seq)

  • 长上下文场景尤其明显(128K decode 提升 5×+)

当下地位vLLM、SGLang 等推理框架都已集成

5.2 Ring Attention(跨卡 Flash)

第 5 篇我们讲过 Ring Attention——它本质上就是 Flash Attention 的分布式版本

  • 把 KV 切到多张卡

  • 每张卡持有局部 KV

  • KV 在卡间环形传递,每张卡轮流和其他卡的 KV 做 Flash Attention

这是训练 / 推理 1M+ 上下文的基础。

5.3 Triton 实现

Flash Attention 原生用 CUDA 写,但Triton 版本越来越流行

  • Triton 是 OpenAI 开源的 GPU kernel DSL

  • 比 CUDA 简单

  • 性能接近 CUDA(v2 大概 90%,v3 仍在追赶)

  • 可读性极强——你可以读 Triton 版的 Flash Attention 来理解算法

vLLM 部分 backend 就是 Triton 实现。

5.4 什么时候 Flash Attention 不适用

虽然 Flash Attention 是"标配",但有一些场景不适用收益有限

场景

原因

序列极短(< 256)

IO 占比不大,传统 attention 反而更快

自定义 attention(如 ALiBi 老版)

Flash 默认不支持任意 mask,要专门修改

FP32 训练

Flash v1/v2 仅支持 FP16/BF16,v3 加 FP8

老 GPU(Pascal / Volta)

Flash 需要 Ampere+ 架构

极特殊 attention 模式(局部 + 全局混合)

需要专门定制

但 95% 的场景,Flash Attention 都是无脑选项。


六、Flash Attention 给工程师的启示

6.1 算法 + 硬件 = 真正的优化

Flash Attention 的成功不是算法创新(softmax 还是那个 softmax),也不是新硬件(GPU 没变),而是两者结合

  • 理解算法的数学结构

  • 理解硬件的物理特性

  • 重新设计两者的接口

这是大模型系统优化的核心方法论:不要只看算法,也不要只看硬件,而是两者协同

6.2 IO 优化的普适性

Flash Attention 的"分块 + 流式合并"思路在很多地方都能用:

  • 量化:W4 + FP16 也用类似思想分块

  • MoE:专家计算和数据搬运的 overlap

  • 分布式训练:通信和计算的 overlap

  • 训练 checkpointing:分块保存激活

如果你做系统优化,多想想"能不能不物化中间结果"——这是个屡试不爽的优化方向。

6.3 不要害怕底层

Flash Attention 的实现要写 CUDA / Triton kernel,这让很多工程师望而却步。但理解它的原理并不要求你能从零写——理解 Tiling、在线 softmax、IO/compute 平衡这些概念,已经足够你做正确的部署决策。


七、结语:Flash Attention 是大模型时代的基础设施

读完本文你应该明白:

  • GPU 算力增长比带宽快,IO 是大模型的主要瓶颈

  • Flash Attention 用 Tiling + Online Softmax 把 attention IO 量从 O(n²) 降到 O(n)

  • v1 / v2 / v3 演进:v1 开创、v2 优化、v3 适配 H100 + FP8

  • 使用方式:HuggingFace 加attn_implementation、PyTorch 用SDPA、vLLM 默认启用

  • 训练比推理收益更大:训练 70B + 8K 上下文,2× 加速 + 47% 显存节省

  • Flash Decoding / Ring Attention:Flash 家族在持续演进

参考文献:

13.Flash Attention 原理与实践:让 Attention 重新成为算力游戏

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/7/1 7:26:08

Rust的闭包中的接口回调

Rust的闭包与接口回调&#xff1a;灵活与高效的结合 在现代编程中&#xff0c;回调机制是处理异步逻辑和事件驱动编程的核心工具之一。Rust作为一门注重安全与性能的系统级语言&#xff0c;其闭包特性为接口回调提供了强大且灵活的支持。闭包不仅可以捕获环境变量&#xff0c;…

作者头像 李华
网站建设 2026/7/1 7:25:39

Codex CLI 服务器无痕运行教程:API Key 不落盘,退出即清理

Linux 终端临时运行 Codex CLI&#xff1a;不写配置、不保存历史、退出自动清理 前言 在服务器或容器中使用 Codex CLI 时&#xff0c;有时不希望执行全局安装&#xff0c;也不希望 API Key、配置文件、npm 缓存和 Codex 会话长期保存在系统中。 本文介绍一种临时运行方案&a…

作者头像 李华
网站建设 2026/7/1 7:23:08

Hirebotics推出无代码防爆协作机器人,专为工业喷涂设计

协作机器人解决方案提供商Hirebotics宣布推出首款防爆协作机器人解决方案——Cobot Painter&#xff0c;该产品基于其无代码Beacon平台&#xff0c;并搭载发那科CRX-10iA/L Paint硬件构建而成。Cobot Painter目前已正式上市&#xff0c;为金属制造商提供一种灵活实用的高混合、…

作者头像 李华
网站建设 2026/7/1 7:21:18

保姆级教程:用CANoe 17.2.88的Easy实例,5分钟搞懂汽车总线数据模拟

零基础5分钟实战&#xff1a;用CANoe Easy实例解锁汽车总线模拟第一次打开CANoe软件时&#xff0c;满屏的英文界面和专业术语确实容易让人望而生畏。作为汽车电子领域最常用的总线开发测试工具&#xff0c;CANoe的强大功能背后是陡峭的学习曲线。但别担心&#xff0c;Vector官方…

作者头像 李华
网站建设 2026/7/1 7:21:06

NTN卫星通信实战:手把手教你理解SSB波束配置与R17协议限制

NTN卫星通信实战&#xff1a;SSB波束配置与R17协议限制深度解析当卫星通信遇上5G NR协议&#xff0c;SSB波束配置成为系统设计中最关键的参数之一。对于参与NTN项目的工程师而言&#xff0c;理解不同频段下SSB波束数量与子载波间隔的配置逻辑&#xff0c;以及R17协议64个波束限…

作者头像 李华