noao-chat-7-nonochat 多卡训练指南

12 分钟阅读时长

发布时间:

LLM 多卡训练完全指南

nonochat

本文档全面解析 nanochat 项目中的分布式训练(DDP),从环境准备到代码实现,让小白也能轻松理解和使用多卡训练。


目录

  1. 什么是多卡训练
  2. 为什么需要多卡训练
  3. 环境准备
  4. 多卡启动方式
  5. 代码实现详解
  6. 数据并行机制
  7. 分布式优化器
  8. 实战案例
  9. 常见问题

1. 什么是多卡训练

1.1 基本概念

多卡训练 = 数据并行 (Data Parallel)

简单来说:用多张 GPU 同时训练一个模型。

单卡训练:
GPU 0: [模型副本] → 处理 batch 1, 2, 3, 4, 5, 6, 7, 8
       时间: ████████

多卡训练 (8卡):
GPU 0: [模型副本] → 处理 batch 1
GPU 1: [模型副本] → 处理 batch 2
GPU 2: [模型副本] → 处理 batch 3
...
GPU 7: [模型副本] → 处理 batch 8
       时间: █
       速度提升: 8x (理论)

1.2 关键术语

术语英文解释示例
进程组Process Group所有参与训练的进程8 卡 = 8 个进程
世界大小World Size总共有多少个进程8 张卡 → world_size=8
全局排名Rank当前进程的全局编号0, 1, 2, …, 7
本地排名Local Rank当前节点内的进程编号单节点 = Rank
主进程Master ProcessRank 0,负责日志和保存ddp_rank == 0
DDPDistributedDataParallelPyTorch 分布式训练框架torch.nn.parallel.DDP

1.3 工作流程

初始化阶段:
1. 每个 GPU 加载相同的模型副本
2. 建立进程间通信(NCCL)
3. 同步所有进程

训练阶段:
每个步骤:
  1. 每个 GPU 处理不同的 mini-batch
  2. 各自计算梯度
  3. 通信:所有 GPU 的梯度取平均
  4. 所有 GPU 用相同的梯度更新模型
  5. 模型保持同步

结束阶段:
主进程保存模型检查点

2. 为什么需要多卡训练

2.1 训练加速

实际效果(nanochat 项目):

配置训练时间(Base, d=12)加速比
单卡 A100~16 小时1x
8卡 A100~2 小时~8x

为什么不是完美的 8x?

  • 通信开销(梯度同步)
  • 数据加载开销
  • 通常能达到 85-95% 的线性扩展

2.2 批次大小限制

问题: 单卡显存不够大

# 单卡场景
device_batch_size = 32  # 受限于单卡显存
total_batch_size = 32 * 2048 = 65,536 tokens

# 8卡场景
device_batch_size = 32  # 每卡相同
total_batch_size = 32 * 2048 * 8 = 524,288 tokens
# 更大的批次 → 更稳定的梯度 → 更好的收敛

2.3 实验效率

对比:

# 单卡训练 - 需要 16 小时
python -m scripts.base_train --depth=12

# 8卡训练 - 只需 2 小时
torchrun --nproc_per_node=8 -m scripts.base_train --depth=12

# 节省时间 = 可以做更多实验!

3. 环境准备

3.1 硬件要求

最低配置:

  • 2+ 张 NVIDIA GPU
  • GPU 之间有 NVLink 或 PCIe 连接
  • 推荐:相同型号的 GPU

本项目测试配置:

  • 8x NVIDIA A100 80GB
  • NVLink 互联

3.2 软件环境

必需组件:

# 1. PyTorch (支持分布式)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# 2. 验证 NCCL 后端
python -c "import torch; print(torch.distributed.is_nccl_available())"
# 应输出: True

# 3. 验证多卡可见
python -c "import torch; print(torch.cuda.device_count())"
# 应输出: 8 (如果你有 8 张卡)

3.3 环境变量

自动设置(torchrun 会处理):

RANK          # 全局进程编号: 0, 1, 2, ..., 7
LOCAL_RANK    # 本地进程编号: 0, 1, 2, ..., 7
WORLD_SIZE    # 总进程数: 8
MASTER_ADDR   # 主节点地址: localhost (单节点)
MASTER_PORT   # 主节点端口: 29500 (默认)

手动检查:

# 查看当前环境
echo $CUDA_VISIBLE_DEVICES  # 可见的 GPU: 0,1,2,3,4,5,6,7

4. 多卡启动方式

4.1 torchrun 命令详解

基本语法:

torchrun [torchrun参数] -m [模块名] [脚本参数]

常用参数:

torchrun \
    --standalone \              # 单节点模式
    --nproc_per_node=8 \       # 每个节点的进程数(GPU数)
    -m scripts.base_train \    # 要运行的 Python 模块
    --depth=12 \               # 脚本参数
    --device_batch_size=32

4.2 完整示例

训练 Base 模型

# 单卡(不使用 torchrun)
python -m scripts.base_train \
    --depth=12 \
    --device_batch_size=8 \
    --run=base_d12_single

# 8卡(使用 torchrun)
torchrun --standalone --nproc_per_node=8 \
    -m scripts.base_train \
    --depth=12 \
    --device_batch_size=32 \
    --run=base_d12_multi

训练 Mid 模型

torchrun --standalone --nproc_per_node=8 \
    -m scripts.mid_train \
    --source=base \
    --device_batch_size=16

训练 SFT 模型

torchrun --standalone --nproc_per_node=8 \
    -m scripts.chat_sft \
    --source=mid \
    --device_batch_size=8

训练 RL 模型

torchrun --standalone --nproc_per_node=8 \
    -m scripts.chat_rl \
    --source=sft \
    --device_batch_size=8 \
    --num_samples=16

4.3 参数对比

参数单卡多卡说明
启动方式python -mtorchrun多卡需要 torchrun
device_batch_size832多卡可以更大
梯度累积多卡减少累积步数
训练时间多卡显著加速

5. 代码实现详解

5.1 分布式初始化

位置: nanochat/common.py

def compute_init(device_type="cuda"):
    """基础初始化,包括分布式设置"""
    
    # 1. 获取分布式信息
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    
    # 2. 如果是分布式 + CUDA
    if ddp and device_type == "cuda":
        # 为每个进程指定不同的 GPU
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device)
        
        # 初始化进程组(使用 NCCL 后端)
        dist.init_process_group(
            backend="nccl",      # NVIDIA GPU 专用
            device_id=device
        )
        
        # 同步所有进程
        dist.barrier()
    else:
        device = torch.device(device_type)
    
    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device

关键点:

  1. 进程绑定 GPU
    # Rank 0 → GPU 0
    # Rank 1 → GPU 1
    # ...
    # Rank 7 → GPU 7
    device = torch.device("cuda", ddp_local_rank)
    
  2. 初始化通信
    dist.init_process_group(backend="nccl")
    # 建立所有进程之间的通信通道
    
  3. 同步屏障
    dist.barrier()
    # 等待所有进程都到达这里
    

5.2 判断是否启用 DDP

def is_ddp():
    """检查是否在分布式环境中"""
    return int(os.environ.get('RANK', -1)) != -1

def get_dist_info():
    """获取分布式信息"""
    if is_ddp():
        # torchrun 会设置这些环境变量
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        return True, ddp_rank, ddp_local_rank, ddp_world_size
    else:
        # 单卡模式
        return False, 0, 0, 1

5.3 主进程判断

为什么需要主进程?

  • 避免重复日志(8 个进程打印 8 遍)
  • 避免重复保存(8 个进程保存 8 次)
  • 避免重复评估(8 个进程评估 8 次)
# 方法 1: 直接判断
if ddp_rank == 0:
    print("只有主进程打印这条消息")
    save_checkpoint(model)

# 方法 2: 使用 print0 工具
print0("自动判断主进程的打印")

print0 实现:

def print0(s="", **kwargs):
    ddp_rank = int(os.environ.get('RANK', 0))
    if ddp_rank == 0:
        print(s, **kwargs)

5.4 模型包装

不需要显式包装!

nanochat 项目不使用 torch.nn.parallel.DistributedDataParallel 包装模型。

# 传统 DDP 做法(其他项目):
if ddp:
    model = DDP(model, device_ids=[local_rank])

# nanochat 做法:
# 不包装!直接使用原始模型
# 梯度同步在优化器中处理

为什么?

  • 使用自定义的分布式优化器(DistMuon, DistAdamW)
  • 优化器内部处理梯度通信
  • 更灵活、更高效

6. 数据并行机制

6.1 数据分配策略

核心思想: 每个进程处理不同的数据

# 循环分配(Round-Robin)
for example_idx in range(ddp_rank, len(dataset), ddp_world_size):
    process_example(example_idx)

# 示例:8 卡处理 24 个样本
# Rank 0: 处理 0, 8, 16
# Rank 1: 处理 1, 9, 17
# Rank 2: 处理 2, 10, 18
# ...
# Rank 7: 处理 7, 15, 23

6.2 Base 训练的数据加载

位置: nanochat/dataloader.py

def tokenizing_distributed_data_loader(B, T, split, device="cuda"):
    """分布式数据加载器"""
    
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    
    # 读取 Parquet 文件
    parquet_paths = list_parquet_files()
    
    # 每个 rank 从不同的 row group 开始
    rg_idx = ddp_rank  # 起始索引 = rank
    
    while rg_idx < pf.num_row_groups:
        # 读取数据
        rg = pf.read_row_group(rg_idx)
        batch = rg.column('text').to_pylist()
        
        # ... tokenize ...
        
        # 跳到下一个 row group(间隔 = world_size)
        rg_idx += ddp_world_size  # 0, 8, 16, ...

可视化:

Row Groups (Parquet 文件):
[0] [1] [2] [3] [4] [5] [6] [7] [8] [9] [10] ...

8 卡分配:
Rank 0: [0]             [8]              [16] ...
Rank 1:     [1]             [9]              ...
Rank 2:         [2]             [10]         ...
...
Rank 7:                     [7]              ...

结果: 
- 所有 rank 处理不同的数据
- 不重复、不遗漏
- 负载均衡

6.3 SFT/RL 的数据加载

对话数据的分配:

# SFT 训练
conversations = load_conversations()
for i in range(ddp_rank, len(conversations), ddp_world_size):
    conversation = conversations[i]
    # 处理这个对话...

# RL 训练
questions = load_questions()
for idx in range(ddp_rank, len(questions), ddp_world_size):
    question = questions[idx]
    # 采样 16 个答案...

7. 分布式优化器

7.1 为什么需要分布式优化器

问题: 每个 GPU 计算的梯度不同

GPU 0: grad_0 = [-0.1, 0.2, -0.3]
GPU 1: grad_1 = [-0.2, 0.1, -0.4]
...
GPU 7: grad_7 = [-0.15, 0.25, -0.35]

需要: 平均梯度
avg_grad = mean([grad_0, ..., grad_7])

解决方案: 分布式优化器自动同步梯度

7.2 DistAdamW 实现

位置: nanochat/adamw.py

class DistAdamW(torch.optim.Optimizer):
    """分布式 AdamW"""
    
    @torch.no_grad()
    def step(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        
        # 1. Reduce-Scatter: 梯度切片并平均
        for param in params:
            grad = param.grad
            rank_size = grad.shape[0] // world_size
            grad_slice = torch.empty_like(grad[:rank_size])
            
            # 分发梯度的一部分给每个 rank
            dist.reduce_scatter_tensor(
                grad_slice,      # 输出: 当前 rank 的切片
                grad,            # 输入: 完整梯度
                op=dist.ReduceOp.AVG  # 操作: 平均
            )
        
        # 2. 更新参数(每个 rank 只更新自己的切片)
        for param in params:
            p_slice = param[rank * rank_size:(rank + 1) * rank_size]
            # ... AdamW 更新逻辑 ...
            p_slice.add_(update, alpha=-lr)
        
        # 3. All-Gather: 收集所有 rank 的参数
        for param in params:
            dist.all_gather_into_tensor(
                param,      # 输出: 完整参数
                p_slice     # 输入: 当前 rank 的切片
            )

三个关键步骤:

  1. Reduce-Scatter(梯度平均)
    Before:
    GPU 0: [grad_full]
    GPU 1: [grad_full]
    ...
       
    After:
    GPU 0: [avg_grad_slice_0]
    GPU 1: [avg_grad_slice_1]
    ...
    
  2. Update(本地更新)
    GPU 0: 更新 param_slice_0
    GPU 1: 更新 param_slice_1
    ...
    
  3. All-Gather(参数同步)
    Before:
    GPU 0: [param_slice_0]
    GPU 1: [param_slice_1]
    ...
       
    After:
    GPU 0: [param_full]
    GPU 1: [param_full]
    ...
    

7.3 DistMuon 实现

位置: nanochat/muon.py

class DistMuon(torch.optim.Optimizer):
    """分布式 Muon 优化器"""
    
    @torch.no_grad()
    def step(self):
        # 1. Reduce-Scatter: 梯度平均
        for group in self.param_groups:
            for param in group["params"]:
                # 将 world_size 个梯度收集并平均
                dist.reduce_scatter(
                    rs_output,
                    rs_input,
                    op=dist.ReduceOp.AVG
                )
        
        # 2. 正交化 + 更新(每个 rank 的切片)
        for param in params:
            g = averaged_grad
            # ... Muon 逻辑 ...
            g = zeropower_via_newtonschulz5(g)
            param.add_(g, alpha=-lr)
        
        # 3. All-Gather: 同步参数
        for param in params:
            dist.all_gather(ag_output, ag_input)

与 DistAdamW 的区别:

  • 梯度平均方式相同
  • 更新算法不同(正交化 vs 自适应学习率)

7.4 通信开销分析

通信量:

# 假设模型参数量: 100M
# 梯度大小: 100M * 4 bytes (fp32) = 400MB

# Reduce-Scatter: 400MB / 8 GPUs = 50MB per GPU
# All-Gather: 50MB * 8 = 400MB per GPU

# 总通信: 450MB per GPU per step

通信时间:

NVLink 带宽: ~300 GB/s
通信时间: 450MB / 300GB/s ≈ 1.5ms

前向+反向: ~50ms
通信占比: 1.5 / 50 = 3%

效率: 97% (非常高!)

8. 实战案例

8.1 完整训练流程(8 卡)

步骤 1: Base 训练

# 启动 Base 训练
torchrun --standalone --nproc_per_node=8 \
    -m scripts.base_train \
    --depth=12 \
    --device_batch_size=32 \
    --num_iterations=50000 \
    --run=base_d12_8gpu

# 预期输出(每个 GPU):
# Step 0: loss=4.234 (8 个进程同时打印)
# Step 1: loss=4.123 (但只有 rank 0 保存日志)

步骤 2: 监控训练

# 查看 GPU 使用
watch -n 1 nvidia-smi

# 应该看到 8 张 GPU 都在运行
# GPU 0: Python (28GB / 80GB)
# GPU 1: Python (28GB / 80GB)
# ...
# GPU 7: Python (28GB / 80GB)

步骤 3: Mid 训练

torchrun --standalone --nproc_per_node=8 \
    -m scripts.mid_train \
    --source=base \
    --device_batch_size=16

步骤 4: SFT 训练

torchrun --standalone --nproc_per_node=8 \
    -m scripts.chat_sft \
    --source=mid \
    --device_batch_size=8

步骤 5: RL 训练

torchrun --standalone --nproc_per_node=8 \
    -m scripts.chat_rl \
    --source=sft \
    --device_batch_size=8 \
    --num_samples=16

8.2 单卡 vs 多卡对比

相同的总批次大小

# 目标: total_batch_size = 524,288 tokens

# ===== 单卡 =====
python -m scripts.base_train \
    --device_batch_size=8 \
    --max_seq_len=2048
# 计算: 8 * 2048 = 16,384 tokens/step
# 需要梯度累积: 524,288 / 16,384 = 32 steps
# 训练时间: ~16 小时

# ===== 8 卡 =====
torchrun --nproc_per_node=8 -m scripts.base_train \
    --device_batch_size=32 \
    --max_seq_len=2048
# 计算: 32 * 2048 * 8 = 524,288 tokens/step
# 需要梯度累积: 1 step (无累积!)
# 训练时间: ~2 小时

对比表格:

指标单卡8卡提升
device_batch_size8324x
梯度累积步数32132x 更快
训练时间16h2h8x
最终 Loss3.453.45相同

8.3 调试技巧

检查分布式状态

# 在训练脚本中添加
print(f"Rank {ddp_rank}/{ddp_world_size} on GPU {ddp_local_rank}")

# 应该看到:
# Rank 0/8 on GPU 0
# Rank 1/8 on GPU 1
# ...
# Rank 7/8 on GPU 7

验证数据不重复

# 在数据加载器中
print(f"Rank {ddp_rank} processing examples: {list(range(ddp_rank, 100, ddp_world_size))}")

# 输出:
# Rank 0: [0, 8, 16, 24, ...]
# Rank 1: [1, 9, 17, 25, ...]
# ...

验证梯度同步

# 在训练循环中
if step % 100 == 0 and ddp:
    # 检查所有 rank 的第一个参数是否相同
    param = list(model.parameters())[0]
    local_sum = param.sum().item()
    print(f"Rank {ddp_rank}: param sum = {local_sum}")
    
    # 所有 rank 应该打印相同的值!

9. 常见问题

9.1 环境问题

Q: NCCL error: unhandled system error

原因: 网络配置问题

解决:

# 方法 1: 设置网络接口
export NCCL_SOCKET_IFNAME=eth0  # 或 ens3, ib0 等

# 方法 2: 使用 IB 网络(如果有)
export NCCL_IB_DISABLE=0

# 方法 3: 禁用 IB(如果不需要)
export NCCL_IB_DISABLE=1

Q: RuntimeError: Address already in use

原因: 端口被占用

解决:

# 更换端口
torchrun --master_port=29501 --nproc_per_node=8 ...

# 或者杀死之前的进程
pkill -9 python

9.2 性能问题

Q: 多卡训练没有加速

可能原因 1:通信瓶颈

# 检查 GPU 连接方式
nvidia-smi topo -m

# 理想情况: NVLink
# GPU0 <-> GPU1: NV12 (12条 NVLink)

# 不好: PCIe
# GPU0 <-> GPU1: PHB (通过 PCIe Host Bridge)

可能原因 2:数据加载慢

# 增加数据加载线程
loader = DataLoader(
    dataset,
    num_workers=4,  # 增加到 4-8
    pin_memory=True
)

可能原因 3:批次太小

# 批次太小,通信开销占比高
device_batch_size = 1  # 太小 ❌
device_batch_size = 32  # 合适 ✅

Q: 显存占用不均衡

现象:

GPU 0: 60GB / 80GB
GPU 1: 30GB / 80GB

原因: 数据不均衡或主进程额外操作

解决:

# 确保数据均匀分配
assert len(dataset) % world_size == 0

# 主进程避免额外的显存操作
if ddp_rank == 0:
    # 在 CPU 上评估,不在 GPU 上
    model.cpu().eval()

9.3 训练问题

Q: Loss 在不同 rank 不一致

原因: 模型没有正确同步

检查:

# 1. 确保使用分布式优化器
optimizer = DistAdamW(...)  # ✅
optimizer = torch.optim.AdamW(...)  # ❌ 不会同步

# 2. 确保调用 barrier
if ddp:
    dist.barrier()

Q: 检查点保存/加载问题

最佳实践:

# 保存: 只有主进程保存
if ddp_rank == 0:
    torch.save(model.state_dict(), path)

# 加载: 所有进程都加载
model.load_state_dict(torch.load(path))

# 同步
if ddp:
    dist.barrier()

Q: 如何从单卡切换到多卡?

非常简单!

# 1. 单卡训练
python -m scripts.base_train --depth=12

# 2. 多卡训练(只需添加 torchrun)
torchrun --nproc_per_node=8 -m scripts.base_train --depth=12

# 代码完全不需要修改!
# nanochat 已经处理好了所有细节

10. 高级话题

10.1 多节点训练

如果你有多台机器(例如 2 台,每台 8 卡):

# 节点 0(主节点)
torchrun \
    --nnodes=2 \
    --node_rank=0 \
    --master_addr=192.168.1.100 \
    --master_port=29500 \
    --nproc_per_node=8 \
    -m scripts.base_train

# 节点 1
torchrun \
    --nnodes=2 \
    --node_rank=1 \
    --master_addr=192.168.1.100 \
    --master_port=29500 \
    --nproc_per_node=8 \
    -m scripts.base_train

10.2 混合精度训练

nanochat 已经使用 bfloat16:

# 自动混合精度
autocast_ctx = torch.amp.autocast(
    device_type="cuda",
    dtype=torch.bfloat16  # 训练更稳定
)

with autocast_ctx:
    loss = model(inputs, targets)

10.3 梯度检查点

对于更大的模型(depth > 20):

# 在模型中启用
model = torch.compile(model)
# PyTorch 2.0 自动优化内存使用

总结

多卡训练的核心要点

  1. 启动方式
    torchrun --nproc_per_node=8 -m scripts.xxx
    
  2. 数据分配
    for i in range(ddp_rank, len(data), ddp_world_size):
        process(data[i])
    
  3. 梯度同步
    # 自动由 DistAdamW / DistMuon 处理
    optimizer.step()
    
  4. 主进程检查
    if ddp_rank == 0:
        save_checkpoint()
    

为什么 nanochat 的多卡训练这么简单?

  1. 自动检测:代码自动检测是否在分布式环境
  2. 透明处理:分布式优化器自动同步梯度
  3. 无需修改:单卡代码直接支持多卡
  4. 工具函数print0(), get_dist_info() 等简化开发

下一步

  • 实际运行多卡训练
  • 监控 GPU 使用和训练速度
  • 尝试不同的 batch size 配置
  • 理解通信模式和优化策略

参考资源:

标签: