从单卡到多卡:分布式训练原理与 FSDP 实战

2026-05-06 · Steve Chan

从单卡到多卡:分布式训练原理与 FSDP 实战

本文以 CodeChat 项目为例,解释为什么 8B 模型必须多卡训练,以及 FSDP 是怎么把一个放不下的模型"切"到 8 张卡上的。


1. 单卡训练:一切从这里开始

单卡训练最简单。一张 GPU 上放着三样东西:

┌─────────────────── GPU (80GB) ───────────────────┐
│  模型参数 (params)                                 │
│  梯度 (gradients)                                  │
│  优化器状态 (optimizer states)                      │
│  激活值 (activations,前向传播的中间结果)             │
└──────────────────────────────────────────────────┘

显存都花在哪了?

以 8B 参数模型为例(8.3B params):

组件 计算方式 显存
参数 (bf16) 8.3B × 2 bytes ~16 GB
梯度 (bf16) 8.3B × 2 bytes ~16 GB
AdamW 一阶动量 m (fp32) 8.3B × 4 bytes ~32 GB
AdamW 二阶动量 v (fp32) 8.3B × 4 bytes ~32 GB
合计(不含激活) ~96 GB

一张 A800 只有 80GB,光是模型参数 + 优化器状态就需要 96GB,还没算激活值。

这就是为什么 CodeChat 2B 能单卡训练,而 8B 不行。

CodeChat 2B 单卡的显存预算

# scripts/base_train.py(单卡模式)
# 2B 模型:
#   参数 bf16:    ~5 GB
#   梯度 bf16:    ~5 GB
#   Adam m+v fp32: ~20 GB
#   激活(开启 grad_checkpoint): ~10-20 GB
#   总计: ~40-50 GB  ← 80GB 卡放得下

2. 为什么不能简单地"多卡并行"?

2.1 数据并行 (Data Parallelism / DDP)

最朴素的多卡方案:每张卡上放一份完整的模型,各自处理不同的数据,然后同步梯度。

DDP: 每张卡都有完整的一份

GPU 0: [全部参数] + [全部梯度] + [全部优化器状态]  ← 还是 96GB
GPU 1: [全部参数] + [全部梯度] + [全部优化器状态]  ← 还是 96GB
...
GPU 7: [全部参数] + [全部梯度] + [全部优化器状态]  ← 还是 96GB

DDP 解决了速度问题(8 张卡处理 8 倍数据),但没有解决显存问题。 每张卡仍然需要放下完整的 96GB,单卡 80GB 还是不够。

PyTorch 中 DDP 的核心操作是 AllReduce——所有卡在反向传播后同步梯度的平均值:

反向传播后:
  GPU 0 梯度: [g0_0, g0_1, g0_2, ...]
  GPU 1 梯度: [g1_0, g1_1, g1_2, ...]
  ...

AllReduce 后 (每张卡都拿到平均梯度):
  GPU 0 梯度: [avg_0, avg_1, avg_2, ...]
  GPU 1 梯度: [avg_0, avg_1, avg_2, ...]   ← 全部相同
  ...

2.2 朴素模型并行 (Naive Model Parallelism)

把模型按层切开,不同层放在不同卡上:

GPU 0: Layer 0-9    →  前向传播  →  传给 GPU 1
GPU 1: Layer 10-19  →  前向传播  →  传给 GPU 2
GPU 2: Layer 20-29  →  前向传播  →  传给 GPU 3
GPU 3: Layer 30-39  →  前向传播  →  输出

问题:流水线气泡 (pipeline bubble)。当 GPU 3 在算前向时,GPU 0/1/2 全在等着,利用率极低。

2.3 真正的需求

我们需要一种方案,既切分显存(解决放不下的问题),又让所有卡都在干活(保持高利用率)

这就是 FSDP 要解决的问题。


3. FSDP:把参数、梯度、优化器全部切碎

FSDP = Fully Sharded Data Parallel,是 PyTorch 原生的零冗余并行方案,等价于 DeepSpeed ZeRO-3。

核心思想:不再让每张卡持有完整模型,而是把参数/梯度/优化器状态均匀切片(shard)到所有卡上。需要用的时候临时聚合(all-gather),用完立刻丢弃。

3.1 三个层级的切分

微软 DeepSpeed 论文把零冗余优化分成三个阶段:

阶段 切什么 每卡显存 (8B, 8 卡) 等价
ZeRO-1 仅优化器状态 ~24 GB (params+grads 仍完整)
ZeRO-2 优化器 + 梯度 ~20 GB FSDP SHARD_GRAD_OP
ZeRO-3 优化器 + 梯度 + 参数 ~12 GB FSDP FULL_SHARD

CodeChat 使用 FULL_SHARD(ZeRO-3),三样全切:

# scripts/base_train.py / chat_sft.py / chat_rl.py
ShardingStrategy.FULL_SHARD   # params + grads + optim 全部按 rank 切片

3.2 FSDP 的运行流程

以一个 forward → backward → optimizer step 为例:

═══════════════════════════════════════════════════════
  FORWARD PASS (前向传播)
═══════════════════════════════════════════════════════

平时每张卡只持有 1/8 的参数:

  GPU 0: [shard_0]     (参数的第 0 片,~2GB)
  GPU 1: [shard_1]     (参数的第 1 片,~2GB)
  ...
  GPU 7: [shard_7]     (参数的第 7 片,~2GB)

当需要计算某一层时,FSDP 临时执行 All-Gather:

  ┌─── All-Gather ────────────────────────────────┐
  │ GPU 0 广播 shard_0  ──→  所有卡               │
  │ GPU 1 广播 shard_1  ──→  所有卡               │
  │ ...                                            │
  │ GPU 7 广播 shard_7  ──→  所有卡               │
  │                                                │
  │ 结果: 每张卡临时持有这一层的完整参数             │
  └────────────────────────────────────────────────┘

  → 用完整参数做前向计算
  → 计算完毕后立即丢弃非本 rank 的参数 (释放显存)
  → 进入下一层,重复 All-Gather

═══════════════════════════════════════════════════════
  BACKWARD PASS (反向传播)
═══════════════════════════════════════════════════════

  → 再次 All-Gather 该层参数 (前向时已丢弃)
  → 计算该层梯度
  → Reduce-Scatter: 每张卡只保留属于自己那片的梯度
  → 释放完整梯度,只留 1/8

  ┌─── Reduce-Scatter ────────────────────────────┐
  │ 把梯度切成 8 片,每张卡拿到自己那片的求和结果   │
  │                                                │
  │ GPU 0: grad_shard_0  (对应 param_shard_0)      │
  │ GPU 1: grad_shard_1  (对应 param_shard_1)      │
  │ ...                                            │
  └────────────────────────────────────────────────┘

═══════════════════════════════════════════════════════
  OPTIMIZER STEP (优化器更新)
═══════════════════════════════════════════════════════

  每张卡只更新自己的 1/8:

  GPU 0: Adam(param_shard_0, grad_shard_0, m_shard_0, v_shard_0)
  GPU 1: Adam(param_shard_1, grad_shard_1, m_shard_1, v_shard_1)
  ...

  无需通信!每张卡独立更新自己的碎片。

3.3 显存对比

8B 模型,8×A800:

DDP (每卡完整副本):
  参数:    16 GB
  梯度:    16 GB
  Adam:    64 GB
  ─────────────
  总计:    96 GB  ← 超出 80GB,OOM!

FSDP FULL_SHARD (每卡 1/8):
  参数:     2 GB  (16/8)
  梯度:     2 GB  (16/8)
  Adam:     8 GB  (64/8)
  All-Gather 临时缓冲:  ~2-4 GB (一层的完整参数)
  激活值:   ~10-20 GB
  ─────────────
  总计:    ~24-36 GB  ← 80GB 轻松放下

4. FSDP 的通信机制

4.1 两种核心通信原语

All-Gather: 每张卡贡献自己的碎片,所有卡拿到完整结果
┌───┐ ┌───┐ ┌───┐ ┌───┐
│ A │ │ B │ │ C │ │ D │   (每卡一片)
└───┘ └───┘ └───┘ └───┘
          ↓ All-Gather
┌───────────────────────┐
│ A │ B │ C │ D │         (每卡都拿到全部)
└───────────────────────┘ × 4 份

Reduce-Scatter: 先求和再分发,每张卡拿到一片的求和结果
GPU 0: [a0, a1, a2, a3]
GPU 1: [b0, b1, b2, b3]
GPU 2: [c0, c1, c2, c3]
GPU 3: [d0, d1, d2, d3]
          ↓ Reduce-Scatter
GPU 0: [a0+b0+c0+d0]         (第 0 片的总和)
GPU 1: [a1+b1+c1+d1]         (第 1 片的总和)
GPU 2: [a2+b2+c2+d2]         (第 2 片的总和)
GPU 3: [a3+b3+c3+d3]         (第 3 片的总和)

A800 之间通过 NVLink 互联,带宽 ~400 GB/s(双向),远高于 PCIe 的 ~32 GB/s。

FSDP 中 All-Gather 一层 8B 模型的参数量大约 400MB,在 NVLink 上只需 ~1ms。这使得"用时聚合、用完丢弃"的策略在实际中几乎不增加额外开销。

4.3 通信与计算重叠

FSDP 的一个关键优化是 backward prefetch:在计算第 N 层反向传播的同时,提前 All-Gather 第 N-1 层的参数。

# CodeChat 中的配置
backward_prefetch=BackwardPrefetch.BACKWARD_PRE  # 提前预取
时间线:
  ┌── GPU 计算 ──┐┌── GPU 计算 ──┐
  │ Layer N 反向  ││ Layer N-1 反向│
  └──────────────┘└──────────────┘
  ┌── NVLink ────────┐
  │ All-Gather N-1   │  ← 与 Layer N 计算并行
  └──────────────────┘

5. FSDP 的切分粒度:Auto Wrap Policy

FSDP 不是把整个模型当作一块来切,而是按 transformer block (层) 为单位包装。每个 Block 是一个独立的 FSDP 单元。

# CodeChat 的包装策略
auto_wrap = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={Block},   # ← 每个 Block 单独包一层 FSDP
)

为什么按 Block 切?

不切 (整个模型一个 FSDP 单元):
  All-Gather 整个模型 → 临时需要 16GB → 显存峰值太高

按 Block 切 (40 个 FSDP 单元):
  每次 All-Gather 一个 Block → 临时需要 ~400MB → 显存峰值可控

模型结构与 FSDP 包装的对应关系:

GPT (顶层 FSDP)
├── wte (embedding)           ← 跟顶层一起切
├── Block 0  (FSDP 单元 0)    ← 独立切片
│   ├── LayerNorm
│   ├── MultiHeadAttention
│   ├── LayerNorm
│   └── MLP
├── Block 1  (FSDP 单元 1)    ← 独立切片
│   └── ...
├── ...
├── Block 39 (FSDP 单元 39)   ← 独立切片
├── ln_f (final norm)          ← 跟顶层一起切
└── lm_head                    ← 跟顶层一起切

6. 混合精度训练 (Mixed Precision)

FSDP 支持三种粒度的精度控制:

mp_policy = MixedPrecision(
    param_dtype=COMPUTE_DTYPE,    # bf16: 参数在 GPU 上以 bf16 存储和计算
    reduce_dtype=COMPUTE_DTYPE,   # bf16: 梯度同步时用 bf16 通信
    buffer_dtype=COMPUTE_DTYPE,   # bf16: BatchNorm 等 buffer 也用 bf16
)
组件 精度 原因
参数 (前向/反向) bf16 节省显存和计算,A800 有 bf16 Tensor Core
梯度通信 bf16 减少 NVLink 传输量
AdamW 状态 (m, v) fp32 优化器精度必须高,否则训练不稳定
优化器更新 fp32 小的梯度更新在 bf16 下会被截断为 0

这就是为什么 Adam 状态占最多显存——它必须用 fp32:

bf16 参数:      8.3B × 2 bytes = 16 GB
fp32 Adam m:    8.3B × 4 bytes = 32 GB  ← 是参数的 2 倍!
fp32 Adam v:    8.3B × 4 bytes = 32 GB

7. CodeChat 从单卡到多卡的实际改动

7.1 启动方式

# 单卡 (2B)
python -m scripts.base_train --preset 2b ...

# 多卡 (8B, 8×A800)
torchrun --nproc_per_node=8 -m scripts.base_train --preset 8b ...

torchrun 会启动 8 个进程,每个进程设置 LOCAL_RANK=0..7 环境变量。

7.2 分布式初始化

def setup_distributed():
    if "LOCAL_RANK" not in os.environ:
        return False, 0, 0, 1            # 单卡模式,什么都不做
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)     # 每个进程绑定自己的 GPU
    dist.init_process_group(backend="nccl")  # 初始化 NCCL 通信后端
    return True, local_rank, dist.get_rank(), dist.get_world_size()

7.3 Checkpoint 加载

# 错误做法 (所有进程同时加载到 GPU 0):
state = load_ckpt(path)  # 默认 map_location="cuda" → 全部挤到 GPU 0

# 正确做法 (先加载到 CPU,再分发):
state = load_ckpt(path, map_location="cpu")  # CPU 内存充足
model.load_state_dict(state["model"])
del state  # 释放 CPU 内存
model = model.to(device)  # 移到各自的 GPU

7.4 FSDP 包装

model = model.to(device).to(COMPUTE_DTYPE)
if is_dist:
    model = wrap_fsdp(model)  # 参数被切片到各卡

# 之后 model 的行为对外不变:
#   logits, loss = model(x, y)   ← 内部自动 All-Gather + Reduce-Scatter

7.5 梯度同步优化:no_sync

当使用梯度累积时,只有最后一个 micro-batch 需要同步梯度:

for micro in range(grad_accum):
    x, y = loader.next_batch()
    if is_dist and micro < grad_accum - 1:
        with model.no_sync():          # ← 跳过中间步的梯度同步
            _, loss = model(x, y)
            (loss / grad_accum).backward()
    else:                              # ← 最后一步才真正同步
        _, loss = model(x, y)
        (loss / grad_accum).backward()

没有 no_sync 的话,每个 micro-batch 都会触发 Reduce-Scatter,白白浪费 7 次通信。

7.6 梯度裁剪

# 单卡:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

# FSDP: 必须用 FSDP 自己的方法,因为参数是分片的
grad_norm = model.clip_grad_norm_(1.0)

FSDP 的 clip_grad_norm_ 内部会先 All-Gather 梯度范数,计算全局范数后再裁剪。

7.7 Checkpoint 保存

# codechat/checkpoint.py
if _is_fsdp(model):
    # FSDP 模型的参数分散在各卡,需要先聚合
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, ...):
        model_state = model.state_dict()  # 在 rank 0 上聚合完整参数
    # 只有 rank 0 写文件

7.8 只在 rank 0 做 I/O

writer = SummaryWriter(log_dir=tb_path) if is_master else None

if is_master:
    print(f"step {step} | loss {loss:.4f}")
    writer.add_scalar("train/loss", loss, step)

如果 8 个进程都写 TensorBoard,日志会重复 8 份且可能互相覆盖。


8. RL 阶段的特殊处理

RL (GRPO) 比 SFT 更复杂,因为它需要采样——自回归生成 token。在 FSDP 下有两个额外的同步点:

8.1 问题索引同步

每个 step 随机选一个 MBPP 问题。8 个进程必须选同一个:

idx = torch.randint(0, len(mbpp), (1,), device=device)
if is_dist:
    dist.broadcast(idx, src=0)  # rank 0 的选择广播给所有 rank

8.2 采样 token 同步

自回归采样时,torch.multinomial 在不同 GPU 上可能产生不同结果(CUDA RNG 差异)。所有 rank 必须生成相同的 token 序列,否则后续的 forward 输入不一致会导致 FSDP 死锁。

nxt = torch.multinomial(probs, 1)     # 各 rank 可能不同
if is_dist:
    dist.broadcast(nxt, src=0)         # 强制所有 rank 用 rank 0 的结果

9. 各种并行策略对比

策略 切什么 适用场景 通信量 CodeChat 使用
DDP 不切,同步梯度 模型放得下单卡 低 (AllReduce 梯度) 2B 可用
FSDP SHARD_GRAD_OP (ZeRO-2) 梯度 + 优化器 参数放得下但优化器放不下
FSDP FULL_SHARD (ZeRO-3) 参数 + 梯度 + 优化器 模型放不下单卡 高 (每层 All-Gather) 8B 使用
Tensor Parallel 切矩阵乘法 超大模型 (70B+) 每层内通信
Pipeline Parallel 切层到不同卡 超大模型 + 多节点 层间传递激活

实际中大模型训练常常组合使用多种策略(如 Llama 3 用了 FSDP + Tensor Parallel + Pipeline Parallel)。CodeChat 8B 只需 FSDP 就够了。


10. 总结

单卡 2B:   什么都不用想,model.to("cuda") 就行
            显存 ~40GB,一张 A800 足够

多卡 8B:   DDP 不够(每卡仍需 96GB)
            必须用 FSDP FULL_SHARD(每卡 ~24GB)
            代价:每层前向/反向各一次 All-Gather + Reduce-Scatter
            NVLink 高带宽 (~400 GB/s) 使通信开销可接受

代码改动:   5 处核心改动
            1. setup_distributed() — 初始化 NCCL
            2. map_location="cpu" — checkpoint 先载 CPU
            3. wrap_fsdp(model) — FSDP 包装
            4. model.no_sync() — 梯度累积优化
            5. model.clip_grad_norm_() — FSDP 感知的梯度裁剪

FSDP 的精髓就一句话:用通信换显存,让 N 张卡合力装下一个本来放不下的模型,同时保持接近 N 倍的训练速度。