Ulysses序列并行突破百万级上下文训练瓶颈,显著降低显存需求

5 阅读5分钟前沿
Ulysses序列并行突破百万级上下文训练瓶颈,显著降低显存需求

背景与挑战

随着生成式模型在文档理解、代码分析和长文本推理等任务中的应用日益增多,单条样本的 Token 长度常常突破数十万甚至上百万。传统的 Transformer 注意力计算随序列长度呈二次增长,导致显存和计算成本急剧上升,即便使用 FlashAttention 将显存降至 O(n) 仍难支撑超长序列。常规的数据并行无法解决序列维度的瓶颈,因为每张卡仍需完整加载整个序列。

Ulysses 工作原理

Ulysses(DeepSpeed 提出的 Sequence Parallelism)通过 序列切分 + 注意力头划分 的双向并行实现长序列训练。核心步骤包括:

  1. 序列切片:将长度为 n 的输入在 P 张 GPU 上等分,每张卡持有 n/P 个 Token。
  2. QKV 投影:各卡分别对本地切片进行 Query、Key、Value 计算。
  3. 全局 all‑to‑all:一次 all‑to‑all 将所有卡的 QKV 按注意力头重新分配,使每张卡拥有完整序列但仅负责部分头的计算。
  4. 局部注意力:在本卡上使用 FlashAttention 或原生 SDPA 完成分配到的头的注意力计算。
  5. 逆向 all‑to‑all:再次全互换恢复序列切片布局,完成输出投影。

该方案的通信量为 O(n·d/P),相比 Ring Attention 的 O(n·d) 下降了 P 倍,并且一次 all‑to‑all 能够利用 NVLink 的全带宽,时延更低。

与 Hugging Face 生态的深度集成

  • Accelerate:提供 ParallelismConfigDeepSpeedSequenceParallelConfig,只需在 Accelerator 初始化时声明 sp_sizesp_backend="deepspeed" 等参数,即可自动完成序列切片、数据加载器适配以及损失加权聚合。
  • Transformers Trainer:在 TrainingArguments 中传入相同的并行配置,Trainer 会自行封装 dataloader、处理 shift_labels 并完成跨卡 loss 汇总,无需手写分布式代码。
  • TRL SFTTrainer:基于 Trainer 再进一步优化监督微调流程,支持可变序列长度、自动 padding 对齐(pad_to_multiple_of 必须等于 sp_size),并可结合 Liger‑Kernel、TiledMLP 等内存友好算子。

性能与显存收益

在 H100 80GB GPU 上对 Qwen3‑4B 进行 96K Token 训练的实验显示:

  • 显存:单卡峰值从 22 GB(DP 基线)提升至 66 GB,仍在显卡容量范围内,实现 12 倍序列长度扩展。
  • 吞吐:64K Token 时每秒处理约 13 k Token,约为基线的 3.7 倍。
  • 损失一致性:在相同 token 预算下,Ulysses 与传统 DP 的归一化交叉熵误差相差不到 0.01%,证明分布式拆分不影响模型收敛。

选型建议

  • 当模型的 注意力头数 ≥ sp_size 时优先使用 Ulysses;若头数受限或已有 Ring Attention 实现,可保留后者。
  • 长序列训练必须保证 序列长度可被 sp_size 整除,否则会产生额外 padding。
  • 推荐配合 FlashAttention 2/3 以及 DeepSpeed ZeRO‑3,在显存紧张时开启参数/优化器 CPU offload。

未来展望

Ulysses 为千兆级上下文训练奠定了系统基础,随着 GPU 互联带宽提升和更高阶的 All‑to‑All 优化,进一步提升 256K‑512K Token 规模将成为可能。这将直接推动全书级别的文档检索、跨章节推理以及大规模代码库理解等应用落地。

作者注:本文基于 Snowflake AI Research 公布的 ALST(Arctic Long Sequence Training)协议以及 DeepSpeed 官方文档撰写,所有实验均在公开可复现的配置下完成。

本文是对第三方新闻源的主观解读。消息可能出现过时、不准确、歧义或错误的地方,仅供参考使用。点击此处查看消息源。