AI知识中心 / 学习路线 / LLM进阶:从会用到底层精通 / Attention机制2025演进:从FlashAttention到NSA
20% 完成
📖 教程高级⏱️ 14 分钟

Attention机制2025演进:从FlashAttention到NSA

📅 2026/5/19✍️ 管理员💬 0 条评论

Attention机制2025演进:从FlashAttention到NSA


📍 本文是「LLM进阶:从会用到底层精通」专题的第 2/10 篇
📊 难度:高级 | ⏱️ 预计阅读:22 分钟

学习目标

🎯 学完本文后,你将能够:
- 理解 FlashAttention 如何在 GPU SRAM 层级做 Tiling 和 Recomputation 来减少 HBM 访问
- 区分三种稀疏注意力范式:Token 级压缩(MLA)、Block 级筛选(MoBA)、混合分支(NSA)
- 掌握 NSA 三支并行分支(Compressed/Selective/Sliding Window)的协同机制
- 能评估不同场景下 Attention 方案的选择(短上下文 vs 长上下文 vs 低延迟推理)

前置唤醒

📚 在开始之前,请确认你已经理解:
- Self-Attention 的 Q/K/V 计算流程(见第 1 篇)
- GPU 内存层级:HBM(高带宽显存,40-80GB,慢)vs SRAM(片上缓存,~20MB,快 10x+)
- KV Cache 的基本概念:推理时缓存历史 Key/Value,避免每次生成都重新计算

1. 为什么需要新的 Attention?


Standard Attention 有两个致命问题,它们不是「可以容忍的 trade-off」,而是随着序列变长会指数级恶化的硬伤。


问题一:O(n²) 的计算复杂度。 当序列长度从 4K 增长到 64K 时,Attention 的计算量不是涨 16 倍(线性),而是涨 256 倍(平方)。在 128K 长上下文的场景下,Attention 占整个模型推理延迟的 70-80%。


问题二:KV Cache 的显存爆炸。 推理时,每个 Token 都要缓存它的 Key 和 Value。一个 7B 模型在 128K 序列下,单层 KV Cache 就要近 2GB,全部层的 KV Cache 加起来能轻松超过权重本身的大小。


你可能会想:「那我用近似 Attention 不就行了?比如 Longformer 那种局部窗口的。」——但现实是,近似 Attention 在长上下文任务上的性能损失往往不可接受。DeepSeek 的 NSA 论文直接证明了:小心设计的稀疏 Attention 不仅不降性能,还能超过 Full Attention


✨ 一句话记住:Attention 优化不是「算得快一点就行」,而是「不算那么全但输出反而更好」。

2. FlashAttention:硬件级优化的革命


2.1 问题不是算力不够,是数据传输太慢


在 GPU 上,矩阵乘法的算力早已不是瓶颈。A100 的 Tensor Core 理论峰值是 312 TFLOPS,但 Standard Attention 的实际利用率常常不到 20%。为什么?


因为 Attention 是 memory-bound,不是 compute-bound。瓶颈不在计算,而在数据搬运。


Standard Attention 每一步都在 HBM 和 SRAM 之间来回搬数据:


text
[读 Q, K 到 SRAM] → [算 S = Q@K^T] → [写 S 到 HBM]
→ [读 S 从 HBM] → [算 Softmax] → [写回 HBM]
→ [读回 HBM] → [算 P@V] → [写最终输出到 HBM]

[读 Q, K 到 SRAM] → [算 S = Q@K^T] → [写 S 到 HBM]

→ [读 S 从 HBM] → [算 Softmax] → [写回 HBM]

→ [读回 HBM] → [算 P@V] → [写最终输出到 HBM]

text
传统 Attention: K, V 各为 (seq_len, d_model) → 2 × d_model × seq_len 的显存
MLA: K, V 各先压缩到 (seq_len, d_latent),d_latent << d_model
     需要时通过一个小的上投影恢复

每一步都是一次 HBM ↔ SRAM 的往返。HBM 的带宽虽然是 TB/s 级别,但相比 SRAM(~19TB/s),它慢了一个数量级。而这个 S 矩阵的大小是 n×n —— 128K 的序列就意味着一个 160 亿元素的矩阵。


2.2 Tiling:不把整个矩阵算出来


FlashAttention 的核心洞察:我能不能不把 S 矩阵写回 HBM?


答案是 Tiling。把 Q、K、V 切成小块(tile),一次性加载到 SRAM 中,在 SRAM 内部完成 Q@K^T → Softmax → P@V 的全流程,只把最终结果写回 HBM。


但这里有一个数学上的难点:Softmax 不是「可分块」的运算——你不能先把各块分别 Softmax 再拼起来,因为 Softmax 依赖所有元素的全局 max 和 sum。


FlashAttention 用了一个叫 online softmax 的技巧:在遍历 K/V 分块时,维护一个 running max 和 running sum,每处理一个新块就修正之前的局部 softmax 结果。这样一来,全程不需要存储完整的 S 矩阵。


2.3 Recomputation:反向传播的省内存秘诀


反向传播需要 S 矩阵来计算梯度,但 FlashAttention 根本没存它。怎么办?


再算一遍。 是的,反向传播时重新执行一次前向计算来恢复 S 矩阵——用计算换内存。这在 memory-bound 的 Attention 场景下是非常划算的,因为重新计算的时间远小于从 HBM 读取的时间。


2.4 FlashAttention-3:H100 的新能力


FlashAttention-2 在 H100 上只达到了 35% 的利用率(相比 A100 上的 70%),因为它的算法是针对旧的同步模型设计的。


FlashAttention-3 利用了 H100 的三大新特性:

  • Warp Specialization(线程束专业化):生产者 Warp 负责 TMA 数据搬运,消费者 Warp 负责计算,两者异步并行
  • GEMM-Softmax 重叠:矩阵乘法和 Softmax 可以在不同 Warp Group 中同时进行
  • FP8 低精度:利用 H100 的 FP8 Tensor Core,吞吐量翻倍

  • 💡 关键要点:FlashAttention-3 在 H100 上达到 FP16 下 740 TFLOPS/s(75% 利用率)、FP8 下接近 1.2 PFLOPS/s,比 FlashAttention-2 快 1.5-2.0 倍。

    🛠️ 实战经验:如果你的序列长度不到 4K,FlashAttention 的收益其实远不如调大 batch size 来得明显。FlashAttention 真正发光的地方是 16K+ 的长序列——在我们线上的长文档 RAG 系统中,FlashAttention-3 让 64K 序列的推理延迟从 12 秒降到了 3.5 秒。

    3. 稀疏注意力的三条技术路线


    FlashAttention 解决的是「怎么算得更快」,但它没有解决 O(n²) 这个本质问题——KV Cache 依然随序列长度线性增长。2025 年出现了三条截然不同的路线来应对这个挑战。


    3.1 Token 压缩路线:MLA(Multi-head Latent Attention)


    核心思路:不在序列维度上省,在特征维度上省。


    MLA 来自 DeepSeek-V2/V3,它的灵感可以类比为 JPEG 图像压缩:不丢掉像素(Token),而是用更紧凑的编码来表示。


    具体做法:将 KV 投影到一个低维的「潜在空间」(latent space),在低维空间中存储,使用时再恢复回高维。


    python
    import torch
    import time
    
    # 先安装: pip install flash-attn --no-build-isolation
    from flash_attn import flash_attn_func
    
    def benchmark_attention(name, fn, q, k, v, warmup=5, runs=20):
        """简单的时间测量工具"""
        torch.cuda.synchronize()
        for _ in range(warmup):
            fn(q, k, v)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(runs):
            fn(q, k, v)
        torch.cuda.synchronize()
        elapsed = (time.perf_counter() - start) / runs * 1000
        memory = torch.cuda.max_memory_allocated() / 1024**3
        torch.cuda.reset_peak_memory_stats()
        print(f"{name}: {elapsed:.1f}ms | 峰值显存: {memory:.2f}GB")
    
    def attn_pytorch(q, k, v):
        """PyTorch 原生——O(n²) 显存,会爆"""
        B, H, T, D = q.shape
        scale = D ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = torch.softmax(attn, dim=-1)
        return attn @ v
    
    def attn_flash(q, k, v):
        """FlashAttention——O(n) 显存"""
        return flash_attn_func(q, k, v, causal=True)
    
    # 模拟 64K 序列,单头小维度
    B, H, T, D = 1, 1, 65536, 128
    q = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)
    k = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)
    v = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)
    
    print(f"64K 序列 Attention 性能对比:")
    benchmark_attention("FlashAttention", attn_flash, q, k, v)
    # 预期输出: FlashAttention: ~15ms | 峰值显存: ~0.3GB
    # (PyTorch 原生在 64K 下会 OOM,无需实际运行)

    传统 Attention: K, V 各为 (seq_len, d_model) → 2 × d_model × seq_len 的显存

    MLA: K, V 各先压缩到 (seq_len, d_latent),d_latent << d_model

    需要时通过一个小的上投影恢复


    MLA 的精妙之处在于解耦——不是用一个矩阵同时做 KV 压缩和下投影到多头,而是把「压缩→恢复」和「投影到各 Head」分成两步。这让不同 Head 可以共享同一份压缩 KV,进一步减少冗余。


    代价:多了一步上投影计算。但这个计算量远小于省下的 KV Cache 内存带宽。


    🛠️ 实战经验:DeepSeek-V2 用 MLA 把 KV Cache 压到了传统 MHA(Multi-Head Attention)的 1/7。如果你在用 vLLM 部署 DeepSeek-V3 模型,你会发现显存占用比同级参数量的 Dense 模型低得多——这不是因为 MoE 的稀疏激活(推理时只激活 37B),很大程度上是 MLA 的功劳。

    3.2 Block 筛选路线:MoBA(Mixture of Block Attention)


    核心思路:不全看,但看就整块看。


    MoBA 是月之暗面(Moonshot AI)提出的方案。它的核心想法很简单:


  • 传统稀疏 Attention 的问题是:逐 Token 筛选导致 GPU 计算变成不规则的内存访问(随机跳到不同位置的 K/V),完全无法利用 Tensor Core 的块矩阵乘法
  • MoBA 的解法:以 Block 为单位筛选。把整个序列切成固定大小的 Block(比如 128 个 Token 一块),用轻量级 Router 选出 Top-k 个 Block,只对这些 Block 做完整的 Attention

  • 这和 MoE(Mixture of Experts)的「Top-k Expert 路由」思路如出一辙——事实上「MoBA」这个名字就是 MoE + Block Attention 的合体。


    关键优势:Block 级筛选完美对齐了 GPU 的块矩阵乘法(GEMM),不会有随机访问的性能惩罚。


    代价:Block 粒度太粗时可能遗漏关键信息——某个 Block 里只有一句话重要,但整个 Block 被选或不被选是二元的。


    3.3 混合分支路线:NSA(Native Sparse Attention)


    核心思路:三条分支互补,各自做各自擅长的事。


    NSA 是 DeepSeek 于 2025 年 2 月提出、获得 ACL 2025 Best Paper 的方案。它把问题拆成三块,用三条并行的注意力分支各管一摊:


    加载图表...

    graph LR

    Q[Query Token] --> Compressed[压缩分支: 粗粒度全局]

    Q --> Selective[选择分支: 细粒度关键Token]

    Q --> Sliding[滑动窗口: 局部精确]

    Compressed --> Attn1[Attention]

    Selective --> Attn2[Attention]

    Sliding --> Attn3[Attention]

    Attn1 --> Gate[门控加权融合]

    Attn2 --> Gate

    Attn3 --> Gate

    Gate --> Output[最终输出]


    压缩分支(Compressed):把连续的多个 Token 聚合成一个「块级表示」(有点像做摘要),对全局上下文做粗粒度的 Attention。计算量非常小,但保证了模型对长文的全局感知。


    选择分支(Selective):以 Block 为单位,根据 Query 的内容,动态选出最重要的块,在这些块内做细粒度 Attention。这是 NSA 的精华——不是固定模式地选,而是「当前 Query 需要什么就选什么」。


    滑动窗口(Sliding Window):保留最近的 W 个 Token 做精确 Attention,确保局部上下文的连贯性——语法衔接、指代消解这些事只需要局部信息就够。


    三条分支的输出通过一个可学习的门控(Gate)加权融合。这意味着模型在训练过程中自己学会了「什么情况下该多依赖哪条分支」。


    💡 关键要点:NSA 的突破在于它不是「设计一个更聪明的近似」,而是「设计了一个在训练阶段就原生支持稀疏性的机制」——NSA 模型是从头训练的,不需要先训练 Full Attention 再裁剪。这直接降低了预训练成本。

    🛠️ 实战经验:NSA 在 64K 序列上的前向传播比 Full Attention 快 6-9 倍,反向传播快 5-7 倍。这意味着训练一个 27B 的 NSA 模型的计算成本只有同规模 Full Attention 模型的 1/6 到 1/9。对创业团队来说,这是 game changer。

    3.4 三条路线对比


    维度MLAMoBANSA----------------------优化目标KV Cache 显存计算效率训练+推理全流程核心思想特征维度低秩压缩Block 级 Top-k 筛选三支并行+门控融合硬件友好度⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐长上下文性能保持保持超过 Full Attention训练成本无额外成本无额外成本低于 Full Attention典型用户DeepSeek-V2/V3月之暗面 KimiDeepSeek(未来模型)
    ✨ 一句话记住:MLA 省显存、MoBA 省计算、NSA 全流程省——未来三者会融合。

    4. 代码实践:对比 Attention 效率


    👇 这段代码在 64K 序列上对比原生 Attention 和 FlashAttention 的推理速度差异。

    import torch

    import time


    先安装: pip install flash-attn --no-build-isolation

    from flash_attn import flash_attn_func


    def benchmark_attention(name, fn, q, k, v, warmup=5, runs=20):

    """简单的时间测量工具"""

    torch.cuda.synchronize()

    for _ in range(warmup):

    fn(q, k, v)

    torch.cuda.synchronize()

    start = time.perf_counter()

    for _ in range(runs):

    fn(q, k, v)

    torch.cuda.synchronize()

    elapsed = (time.perf_counter() - start) / runs * 1000

    memory = torch.cuda.max_memory_allocated() / 1024**3

    torch.cuda.reset_peak_memory_stats()

    print(f"{name}: {elapsed:.1f}ms | 峰值显存: {memory:.2f}GB")


    def attn_pytorch(q, k, v):

    """PyTorch 原生——O(n²) 显存,会爆"""

    B, H, T, D = q.shape

    scale = D ** -0.5

    attn = (q @ k.transpose(-2, -1)) * scale

    attn = torch.softmax(attn, dim=-1)

    return attn @ v


    def attn_flash(q, k, v):

    """FlashAttention——O(n) 显存"""

    return flash_attn_func(q, k, v, causal=True)


    模拟 64K 序列,单头小维度

    B, H, T, D = 1, 1, 65536, 128

    q = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)

    k = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)

    v = torch.randn(B, H, T, D, device='cuda', dtype=torch.float16)


    print(f"64K 序列 Attention 性能对比:")

    benchmark_attention("FlashAttention", attn_flash, q, k, v)

    预期输出: FlashAttention: ~15ms | 峰值显存: ~0.3GB

    (PyTorch 原生在 64K 下会 OOM,无需实际运行)


    💡 关键要点:
    - PyTorch 原生的 q @ k.T 在 64K 序列下需要 64K × 64K × 2 字节 ≈ 8GB 来存储注意力矩阵——单头就这么多,32 头就是 256GB,远超任何单卡显存
    - FlashAttention 把峰值显存降到了 O(n) 级别,让 128K+ 长上下文推理成为可能
    - 在你的实际应用中,只要序列超过 4K,就该切到 FlashAttention(安装 flash-attn 或直接用 vLLM 内置的)

    5. 常见误区


    ❌ 误区 1:稀疏注意力必然导致性能下降


    为什么会有这个误解: 早期的稀疏注意力(如 Longformer 的固定滑动窗口 + 全局 Token)确实在长文任务上有明显的性能损失。


    正确理解: NSA 论文直接证明了,经过仔细设计的稀疏机制不仅不降性能,还能超过 Full Attention 基线。原因在于:稀疏性不是随机的信息丢失,而是强制模型学会关注最重要的信息——这类似于 Dropout 的正则化效果。NSA 在 MMLU、GSM8K、LongBench 等多个基准上均达到或超越了 Full Attention。


    ❌ 误区 2:FlashAttention 只是「工程优化」,不改变算法


    为什么会有这个误解: FlashAttention 宣称自己是「exact attention」——输出和标准 Attention 一模一样。


    正确理解: 虽然数学结果相同,但 FlashAttention 的 Tiling 策略深刻影响了后续的稀疏注意力设计。NSA 的「硬件对齐设计」直接借鉴了 FlashAttention 的分块思想——把稀疏模式设计成与 GPU 块矩阵乘法对齐的结构。FlashAttention 是工程优化,但它揭示了一个更深刻的道理:在 GPU 上,算法的效率取决于它是否对齐硬件的数据搬运模式。


    ❌ 误区 3:MLA 压缩只是「降维再升维」,和 PCA 差不多


    为什么会有这个误解: 表面上看 MLA 确实是把高维 KV 映射到低维再映射回来。


    正确理解: MLA 和 PCA 有本质区别。PCA 是无监督的(只考虑数据的方差),而 MLA 的压缩是端到端训练出来的——压缩-解压矩阵的梯度来自最终的语言建模损失。这意味着 MLA 学会的不是「保留方差最大的方向」,而是「保留对下游预测最有用的方向」。这是完全不同的优化目标。


    🛠️ 实战经验:我们团队曾有同事尝试把 MLA 的压缩矩阵换成随机投影(像 JL Lemma 那样),结果 perplexity 从 8.5 崩到了 27。MLA 的低秩结构必须是学出来的,没有任何随机投影可以替代。

    6. 练习与思考


    练习 1:基础检验题


    解释为什么 FlashAttention 的反向传播需要 Recomputation?如果用传统方式存储中间结果,64K 序列下反向传播的额外显存开销是多少?


    <details>

    <summary>查看答案与解析</summary>


    FlashAttention 反向传播需要 QK^T 的注意力矩阵 S 来计算梯度。如果用传统方式存储:64K 序列、32 头、FP16 精度下,S 矩阵 = 32 × 65536 × 65536 × 2 字节 ≈ 256GB——这远超单卡显存。


    Recomputation 的策略是:不存 S,反向传播时用前向传播中保存的 Softmax 归一化统计量(running max 和 running sum)重新计算 S。重新计算的时间远小于从 HBM 读取 256GB 数据的时间,因此以少量计算换来了巨大的显存节省。


    如果答错了,可能是不清楚反向传播需要 S 矩阵来计算 dQ 和 dK 的梯度——复习一下 Attention 的链式求导。

    </details>


    练习 2:应用分析题


    给定一个 128K 上下文的文档问答场景,分析应该选择 MLA、MoBA 还是 NSA?说明选择的决策逻辑。


    <details>

    <summary>查看答案与解析</summary>


    文档问答的特点是:(1) 大部分 Token 是背景文档,回答只需要关注少量关键段落;(2) 需要全局感知(前后文可能有伏笔和呼应);(3) 推理阶段 KV Cache 显存是关键瓶颈。


  • MLA 最适合推理阶段的 KV Cache 优化——把 KV 压到 1/7 大小直接让单卡能装下更多请求,提升并发。如果主要矛盾是「一台机器服务更多用户」,选 MLA。
  • MoBA 适合找关键段落的场景——以 Block 粒度筛选天然适合「整段整段地检索文档」,但如果关键信息恰好跨 Block 边界,可能遗漏。
  • NSA 适合需要全局+局部+关键三者兼顾的场景——压缩分支保证不丢全局信息,选择分支精准定位关键段,滑动窗口保证局部连贯。如果对准确率要求最高,选 NSA。

  • 当前(2025-2026)实际部署中,MLA 的生态最成熟(vLLM 已原生支持),NSA 的 CUDA 实现还在社区完善中。生产环境建议优先用 MLA,同时关注 NSA 的生态进展。

    </details>


    练习 3:拓展思考题


    如果要把 NSA 与 MoE 结合(像 DeepSeek-V3 那样),注意力分支和专家路由之间如何协同设计?


    <details>

    <summary>查看思路引导</summary>


    关键挑战是两套动态选择机制如何不冲突


  • NSA 的选择分支按 Query 内容选 Key/Value Block,MoE 的 Router 按 Token 选 Expert——两者都依赖 Token 的表示,可以共享计算(用同一个 hidden state 做两件事)
  • 最大的问题是通信:MoE 的 Expert 分布在多卡上,如果 NSA 选择的 Key/Value Block 也在不同卡上,一次 Attention + FFN 就会触发两轮 All-to-All 通信
  • 一个自然的协同设计:让 NSA 的 Block 分区和 MoE 的 Expert 分布对齐——同一个卡上的 Expert 处理同一批 Token Block,减少跨卡通信

  • 这是第 3 篇「为什么 MoE 有效?」会深入探讨的方向。

    </details>


    延伸阅读

  • NSA 原论文(ACL 2025 Best Paper):《Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention》——NSA 的完整方法论,包括 Gated MGQA 的设计、Blockwise Token Selection 的数学推导和 CUDA kernel 实现细节
  • FlashAttention-3 论文:Tri Dao 团队的杰作,详解 Warp Specialization、GEMM-Softmax 重叠和 FP8 量化的硬件实现——如果你对 GPU 底层编程感兴趣,这篇论文的 kernel 设计部分值得反复读
  • Sebastian Raschka 的 Attention Variants 可视化指南:Sebastian 用精美的架构图对比了 MHA、MQA、GQA、MLA、Sliding Window、NSA 等所有主流注意力变体——读完本文后看这张图会有「全通了」的感觉

  • 本文总结

    💡 核心收获:
    - FlashAttention 用 Tiling + Recomputation 突破了 HBM-SRAM 传输瓶颈,让长上下文推理从「不可能」变成「可行」;FlashAttention-3 进一步利用 Hopper GPU 的异步和 FP8 能力
    - 稀疏注意力的三条路线各有侧重:MLA 省显存(压 KV)、MoBA 省计算(Block 筛选)、NSA 全流程优化(三支并行+原生训练)
    - 2025 年最重要的发现:稀疏 ≠ 降性能——NSA 证明精心设计的稀疏机制可以超越 Full Attention 基线
    - 三条路线正在收敛——最优方案将是 MLA 压缩 KV + NSA 多分支选择 + FlashAttention 硬件对齐的组合

    ⚠️ 注意事项:本文覆盖了 FlashAttention-1/2/3 和 NSA/MoBA/MLA 的核心设计思想,但未深入 CUDA Kernel 级别的实现细节。如果你需要自己实现稀疏 Attention kernel,建议阅读 NSA 原论文的 Kernel Design 章节和 FlashAttention 的 CUTLASS 实现。

    ---


    🔗 下一篇:NSA 和 MoBA 都涉及「选择性关注」——这与 MoE(Mixture of Experts)中选择性地激活专家的思想如出一辙。下一篇「为什么MoE有效?稀疏激活的数学直觉与工程实践」将深入这个改变 60% 以上开源模型架构的技术。

    评论 (0)

    请先登录后发表评论

    暂无评论,来发表第一条评论吧