Attention机制2025演进:从FlashAttention到NSA
📍 本文是「LLM进阶:从会用到底层精通」专题的第 2/10 篇
📊 难度:高级 | ⏱️ 预计阅读:22 分钟
学习目标
🎯 学完本文后,你将能够:
- 理解 FlashAttention 如何在 GPU SRAM 层级做 Tiling 和 Recomputation 来减少 HBM 访问
- 区分三种稀疏注意力范式:Token 级压缩(MLA)、Block 级筛选(MoBA)、混合分支(NSA)
- 掌握 NSA 三支并行分支(Compressed/Selective/Sliding Window)的协同机制
- 能评估不同场景下 Attention 方案的选择(短上下文 vs 长上下文 vs 低延迟推理)
前置唤醒
📚 在开始之前,请确认你已经理解:
- Self-Attention 的 Q/K/V 计算流程(见第 1 篇)
- GPU 内存层级:HBM(高带宽显存,40-80GB,慢)vs SRAM(片上缓存,~20MB,快 10x+)
- KV Cache 的基本概念:推理时缓存历史 Key/Value,避免每次生成都重新计算
1. 为什么需要新的 Attention?
Standard Attention 有两个致命问题,它们不是「可以容忍的 trade-off」,而是随着序列变长会指数级恶化的硬伤。
问题一:O(n²) 的计算复杂度。 当序列长度从 4K 增长到 64K 时,Attention 的计算量不是涨 16 倍(线性),而是涨 256 倍(平方)。在 128K 长上下文的场景下,Attention 占整个模型推理延迟的 70-80%。
问题二:KV Cache 的显存爆炸。 推理时,每个 Token 都要缓存它的 Key 和 Value。一个 7B 模型在 128K 序列下,单层 KV Cache 就要近 2GB,全部层的 KV Cache 加起来能轻松超过权重本身的大小。
你可能会想:「那我用近似 Attention 不就行了?比如 Longformer 那种局部窗口的。」——但现实是,近似 Attention 在长上下文任务上的性能损失往往不可接受。DeepSeek 的 NSA 论文直接证明了:小心设计的稀疏 Attention 不仅不降性能,还能超过 Full Attention。
✨ 一句话记住:Attention 优化不是「算得快一点就行」,而是「不算那么全但输出反而更好」。
2. FlashAttention:硬件级优化的革命
2.1 问题不是算力不够,是数据传输太慢
在 GPU 上,矩阵乘法的算力早已不是瓶颈。A100 的 Tensor Core 理论峰值是 312 TFLOPS,但 Standard Attention 的实际利用率常常不到 20%。为什么?
因为 Attention 是 memory-bound,不是 compute-bound。瓶颈不在计算,而在数据搬运。
Standard Attention 每一步都在 HBM 和 SRAM 之间来回搬数据:
[读 Q, K 到 SRAM] → [算 S = Q@K^T] → [写 S 到 HBM]
→ [读 S 从 HBM] → [算 Softmax] → [写回 HBM]
→ [读回 HBM] → [算 P@V] → [写最终输出到 HBM][读 Q, K 到 SRAM] → [算 S = Q@K^T] → [写 S 到 HBM]
→ [读 S 从 HBM] → [算 Softmax] → [写回 HBM]
→ [读回 HBM] → [算 P@V] → [写最终输出到 HBM]
传统 Attention: K, V 各为 (seq_len, d_model) → 2 × d_model × seq_len 的显存
MLA: K, V 各先压缩到 (seq_len, d_latent),d_latent << d_model
需要时通过一个小的上投影恢复每一步都是一次 HBM ↔ SRAM 的往返。HBM 的带宽虽然是 TB/s 级别,但相比 SRAM(~19TB/s),它慢了一个数量级。而这个 S 矩阵的大小是 n×n —— 128K 的序列就意味着一个 160 亿元素的矩阵。
2.2 Tiling:不把整个矩阵算出来
FlashAttention 的核心洞察:我能不能不把 S 矩阵写回 HBM?
答案是 Tiling。把 Q、K、V 切成小块(tile),一次性加载到 SRAM 中,在 SRAM 内部完成 Q@K^T → Softmax → P@V 的全流程,只把最终结果写回 HBM。
但这里有一个数学上的难点:Softmax 不是「可分块」的运算——你不能先把各块分别 Softmax 再拼起来,因为 Softmax 依赖所有元素的全局 max 和 sum。
FlashAttention 用了一个叫 online softmax 的技巧:在遍历 K/V 分块时,维护一个 running max 和 running sum,每处理一个新块就修正之前的局部 softmax 结果。这样一来,全程不需要存储完整的 S 矩阵。
2.3 Recomputation:反向传播的省内存秘诀
反向传播需要 S 矩阵来计算梯度,但 FlashAttention 根本没存它。怎么办?
再算一遍。 是的,反向传播时重新执行一次前向计算来恢复 S 矩阵——用计算换内存。这在 memory-bound 的 Attention 场景下是非常划算的,因为重新计算的时间远小于从 HBM 读取的时间。
2.4 FlashAttention-3:H100 的新能力
FlashAttention-2 在 H100 上只达到了 35% 的利用率(相比 A100 上的 70%),因为它的算法是针对旧的同步模型设计的。
FlashAttention-3 利用了 H100 的三大新特性:
💡 关键要点:FlashAttention-3 在 H100 上达到 FP16 下 740 TFLOPS/s(75% 利用率)、FP8 下接近 1.2 PFLOPS/s,比 FlashAttention-2 快 1.5-2.0 倍。
🛠️ 实战经验:如果你的序列长度不到 4K,FlashAttention 的收益其实远不如调大 batch size 来得明显。FlashAttention 真正发光的地方是 16K+ 的长序列——在我们线上的长文档 RAG 系统中,FlashAttention-3 让 64K 序列的推理延迟从 12 秒降到了 3.5 秒。
3. 稀疏注意力的三条技术路线
FlashAttention 解决的是「怎么算得更快」,但它没有解决 O(n²) 这个本质问题——KV Cache 依然随序列长度线性增长。2025 年出现了三条截然不同的路线来应对这个挑战。
3.1 Token 压缩路线:MLA(Multi-head Latent Attention)
核心思路:不在序列维度上省,在特征维度上省。
MLA 来自 DeepSeek-V2/V3,它的灵感可以类比为 JPEG 图像压缩:不丢掉像素(Token),而是用更紧凑的编码来表示。
具体做法:将 KV 投影到一个低维的「潜在空间」(latent space),在低维空间中存储,使用时再恢复回高维。
import torch
import time
# 先安装: pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func
def benchmark_attention(name, fn, q, k, v, warmup=5, runs=20):
"""简单的时间测量工具"""
torch.cuda.synchronize()
for _ in range(warmup):
fn(q, k, v)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(runs):
fn(q, k, v)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / runs * 1000
memory = torch.cuda.max_memory_allocated() / 1024**3
torch.cuda.reset_peak_memory_stats()
print(f"{name}: {elapsed:.1f}ms | 峰值显存: {memory:.2f}GB")
def attn_pytorch(q, k, v):
"""PyTorch 原生——O(n²) 显存,会爆"""
B, H, T, D = q.shape
scale = D ** -0.5
attn = (q @ k.transpose(-2, -1)) * scale
attn = torch.softmax(attn, dim=-1)
return attn @ v
def attn_flash(q, k, v):
"""FlashAttention——O(n) 显存"""
return flash_attn_func(q, k, v, causal=True)
# 模拟 64K 序列,单头小维度
B, H, T, D = 1, 1, 65536, 128
q = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)
k = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)
v = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)
print(f"64K 序列 Attention 性能对比:")
benchmark_attention("FlashAttention", attn_flash, q, k, v)
# 预期输出: FlashAttention: ~15ms | 峰值显存: ~0.3GB
# (PyTorch 原生在 64K 下会 OOM,无需实际运行)传统 Attention: K, V 各为 (seq_len, d_model) → 2 × d_model × seq_len 的显存
MLA: K, V 各先压缩到 (seq_len, d_latent),d_latent << d_model
需要时通过一个小的上投影恢复
MLA 的精妙之处在于解耦——不是用一个矩阵同时做 KV 压缩和下投影到多头,而是把「压缩→恢复」和「投影到各 Head」分成两步。这让不同 Head 可以共享同一份压缩 KV,进一步减少冗余。
代价:多了一步上投影计算。但这个计算量远小于省下的 KV Cache 内存带宽。
🛠️ 实战经验:DeepSeek-V2 用 MLA 把 KV Cache 压到了传统 MHA(Multi-Head Attention)的 1/7。如果你在用 vLLM 部署 DeepSeek-V3 模型,你会发现显存占用比同级参数量的 Dense 模型低得多——这不是因为 MoE 的稀疏激活(推理时只激活 37B),很大程度上是 MLA 的功劳。
3.2 Block 筛选路线:MoBA(Mixture of Block Attention)
核心思路:不全看,但看就整块看。
MoBA 是月之暗面(Moonshot AI)提出的方案。它的核心想法很简单:
这和 MoE(Mixture of Experts)的「Top-k Expert 路由」思路如出一辙——事实上「MoBA」这个名字就是 MoE + Block Attention 的合体。
关键优势:Block 级筛选完美对齐了 GPU 的块矩阵乘法(GEMM),不会有随机访问的性能惩罚。
代价:Block 粒度太粗时可能遗漏关键信息——某个 Block 里只有一句话重要,但整个 Block 被选或不被选是二元的。
3.3 混合分支路线:NSA(Native Sparse Attention)
核心思路:三条分支互补,各自做各自擅长的事。
NSA 是 DeepSeek 于 2025 年 2 月提出、获得 ACL 2025 Best Paper 的方案。它把问题拆成三块,用三条并行的注意力分支各管一摊:
graph LR
Q[Query Token] --> Compressed[压缩分支: 粗粒度全局]
Q --> Selective[选择分支: 细粒度关键Token]
Q --> Sliding[滑动窗口: 局部精确]
Compressed --> Attn1[Attention]
Selective --> Attn2[Attention]
Sliding --> Attn3[Attention]
Attn1 --> Gate[门控加权融合]
Attn2 --> Gate
Attn3 --> Gate
Gate --> Output[最终输出]
压缩分支(Compressed):把连续的多个 Token 聚合成一个「块级表示」(有点像做摘要),对全局上下文做粗粒度的 Attention。计算量非常小,但保证了模型对长文的全局感知。
选择分支(Selective):以 Block 为单位,根据 Query 的内容,动态选出最重要的块,在这些块内做细粒度 Attention。这是 NSA 的精华——不是固定模式地选,而是「当前 Query 需要什么就选什么」。
滑动窗口(Sliding Window):保留最近的 W 个 Token 做精确 Attention,确保局部上下文的连贯性——语法衔接、指代消解这些事只需要局部信息就够。
三条分支的输出通过一个可学习的门控(Gate)加权融合。这意味着模型在训练过程中自己学会了「什么情况下该多依赖哪条分支」。
💡 关键要点:NSA 的突破在于它不是「设计一个更聪明的近似」,而是「设计了一个在训练阶段就原生支持稀疏性的机制」——NSA 模型是从头训练的,不需要先训练 Full Attention 再裁剪。这直接降低了预训练成本。
🛠️ 实战经验:NSA 在 64K 序列上的前向传播比 Full Attention 快 6-9 倍,反向传播快 5-7 倍。这意味着训练一个 27B 的 NSA 模型的计算成本只有同规模 Full Attention 模型的 1/6 到 1/9。对创业团队来说,这是 game changer。
3.4 三条路线对比
✨ 一句话记住:MLA 省显存、MoBA 省计算、NSA 全流程省——未来三者会融合。
4. 代码实践:对比 Attention 效率
👇 这段代码在 64K 序列上对比原生 Attention 和 FlashAttention 的推理速度差异。
import torch
import time
先安装: pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func
def benchmark_attention(name, fn, q, k, v, warmup=5, runs=20):
"""简单的时间测量工具"""
torch.cuda.synchronize()
for _ in range(warmup):
fn(q, k, v)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(runs):
fn(q, k, v)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / runs * 1000
memory = torch.cuda.max_memory_allocated() / 1024**3
torch.cuda.reset_peak_memory_stats()
print(f"{name}: {elapsed:.1f}ms | 峰值显存: {memory:.2f}GB")
def attn_pytorch(q, k, v):
"""PyTorch 原生——O(n²) 显存,会爆"""
B, H, T, D = q.shape
scale = D ** -0.5
attn = (q @ k.transpose(-2, -1)) * scale
attn = torch.softmax(attn, dim=-1)
return attn @ v
def attn_flash(q, k, v):
"""FlashAttention——O(n) 显存"""
return flash_attn_func(q, k, v, causal=True)
模拟 64K 序列,单头小维度
B, H, T, D = 1, 1, 65536, 128
q = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)
k = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)
v = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)
print(f"64K 序列 Attention 性能对比:")
benchmark_attention("FlashAttention", attn_flash, q, k, v)
预期输出: FlashAttention: ~15ms | 峰值显存: ~0.3GB
(PyTorch 原生在 64K 下会 OOM,无需实际运行)
💡 关键要点:
- PyTorch 原生的 q @ k.T 在 64K 序列下需要 64K × 64K × 2 字节 ≈ 8GB 来存储注意力矩阵——单头就这么多,32 头就是 256GB,远超任何单卡显存- FlashAttention 把峰值显存降到了 O(n) 级别,让 128K+ 长上下文推理成为可能
- 在你的实际应用中,只要序列超过 4K,就该切到 FlashAttention(安装 flash-attn 或直接用 vLLM 内置的)5. 常见误区
❌ 误区 1:稀疏注意力必然导致性能下降
为什么会有这个误解: 早期的稀疏注意力(如 Longformer 的固定滑动窗口 + 全局 Token)确实在长文任务上有明显的性能损失。
正确理解: NSA 论文直接证明了,经过仔细设计的稀疏机制不仅不降性能,还能超过 Full Attention 基线。原因在于:稀疏性不是随机的信息丢失,而是强制模型学会关注最重要的信息——这类似于 Dropout 的正则化效果。NSA 在 MMLU、GSM8K、LongBench 等多个基准上均达到或超越了 Full Attention。
❌ 误区 2:FlashAttention 只是「工程优化」,不改变算法
为什么会有这个误解: FlashAttention 宣称自己是「exact attention」——输出和标准 Attention 一模一样。
正确理解: 虽然数学结果相同,但 FlashAttention 的 Tiling 策略深刻影响了后续的稀疏注意力设计。NSA 的「硬件对齐设计」直接借鉴了 FlashAttention 的分块思想——把稀疏模式设计成与 GPU 块矩阵乘法对齐的结构。FlashAttention 是工程优化,但它揭示了一个更深刻的道理:在 GPU 上,算法的效率取决于它是否对齐硬件的数据搬运模式。
❌ 误区 3:MLA 压缩只是「降维再升维」,和 PCA 差不多
为什么会有这个误解: 表面上看 MLA 确实是把高维 KV 映射到低维再映射回来。
正确理解: MLA 和 PCA 有本质区别。PCA 是无监督的(只考虑数据的方差),而 MLA 的压缩是端到端训练出来的——压缩-解压矩阵的梯度来自最终的语言建模损失。这意味着 MLA 学会的不是「保留方差最大的方向」,而是「保留对下游预测最有用的方向」。这是完全不同的优化目标。
🛠️ 实战经验:我们团队曾有同事尝试把 MLA 的压缩矩阵换成随机投影(像 JL Lemma 那样),结果 perplexity 从 8.5 崩到了 27。MLA 的低秩结构必须是学出来的,没有任何随机投影可以替代。
6. 练习与思考
练习 1:基础检验题
解释为什么 FlashAttention 的反向传播需要 Recomputation?如果用传统方式存储中间结果,64K 序列下反向传播的额外显存开销是多少?
<details>
<summary>查看答案与解析</summary>
FlashAttention 反向传播需要 QK^T 的注意力矩阵 S 来计算梯度。如果用传统方式存储:64K 序列、32 头、FP16 精度下,S 矩阵 = 32 × 65536 × 65536 × 2 字节 ≈ 256GB——这远超单卡显存。
Recomputation 的策略是:不存 S,反向传播时用前向传播中保存的 Softmax 归一化统计量(running max 和 running sum)重新计算 S。重新计算的时间远小于从 HBM 读取 256GB 数据的时间,因此以少量计算换来了巨大的显存节省。
如果答错了,可能是不清楚反向传播需要 S 矩阵来计算 dQ 和 dK 的梯度——复习一下 Attention 的链式求导。
</details>
练习 2:应用分析题
给定一个 128K 上下文的文档问答场景,分析应该选择 MLA、MoBA 还是 NSA?说明选择的决策逻辑。
<details>
<summary>查看答案与解析</summary>
文档问答的特点是:(1) 大部分 Token 是背景文档,回答只需要关注少量关键段落;(2) 需要全局感知(前后文可能有伏笔和呼应);(3) 推理阶段 KV Cache 显存是关键瓶颈。
当前(2025-2026)实际部署中,MLA 的生态最成熟(vLLM 已原生支持),NSA 的 CUDA 实现还在社区完善中。生产环境建议优先用 MLA,同时关注 NSA 的生态进展。
</details>
练习 3:拓展思考题
如果要把 NSA 与 MoE 结合(像 DeepSeek-V3 那样),注意力分支和专家路由之间如何协同设计?
<details>
<summary>查看思路引导</summary>
关键挑战是两套动态选择机制如何不冲突:
这是第 3 篇「为什么 MoE 有效?」会深入探讨的方向。
</details>
延伸阅读
本文总结
💡 核心收获:
- FlashAttention 用 Tiling + Recomputation 突破了 HBM-SRAM 传输瓶颈,让长上下文推理从「不可能」变成「可行」;FlashAttention-3 进一步利用 Hopper GPU 的异步和 FP8 能力
- 稀疏注意力的三条路线各有侧重:MLA 省显存(压 KV)、MoBA 省计算(Block 筛选)、NSA 全流程优化(三支并行+原生训练)
- 2025 年最重要的发现:稀疏 ≠ 降性能——NSA 证明精心设计的稀疏机制可以超越 Full Attention 基线
- 三条路线正在收敛——最优方案将是 MLA 压缩 KV + NSA 多分支选择 + FlashAttention 硬件对齐的组合
⚠️ 注意事项:本文覆盖了 FlashAttention-1/2/3 和 NSA/MoBA/MLA 的核心设计思想,但未深入 CUDA Kernel 级别的实现细节。如果你需要自己实现稀疏 Attention kernel,建议阅读 NSA 原论文的 Kernel Design 章节和 FlashAttention 的 CUTLASS 实现。
---
🔗 下一篇:NSA 和 MoBA 都涉及「选择性关注」——这与 MoE(Mixture of Experts)中选择性地激活专家的思想如出一辙。下一篇「为什么MoE有效?稀疏激活的数学直觉与工程实践」将深入这个改变 60% 以上开源模型架构的技术。