noao-chat-2-mid阶段训练

24 分钟阅读时长

发布时间:

nonochat - LLM Mid 训练完整解析

nonochat

本文档详细解析 mid_train.py,讲解中间训练(Midtraining)的作用和实现细节。
适合 LLM 领域的初学者,从 Base 模型到应用模型的关键过渡阶段。


目录

  1. 什么是 Mid 训练
  2. 数据来源的巨大变化
  3. 训练流程详解
  4. 与 Base 训练的对比
  5. 实战案例分析

1. 什么是 Mid 训练

1.1 训练阶段定位

Base Training → Mid Training → SFT → RL
     ↓              ↓            ↓     ↓
  通用能力      领域能力      对话   对齐
  (海量文本)  (结构化任务)   (指令)  (偏好)

Mid Training(中间训练) 是连接预训练和微调的桥梁阶段。

1.2 为什么需要 Mid Training?

Base 模型的局限:

  • ✅ 有通用语言理解能力
  • ❌ 不懂对话格式
  • ❌ 不会使用工具
  • ❌ 缺乏特定领域知识
  • ❌ 在某些任务上表现不佳

直接 SFT 的问题:

  • SFT 数据量通常很少(几千到几万条)
  • 如果 Base 模型在某个能力上很弱,SFT 很难补救
  • 某些基础能力需要更多数据才能学会

Mid Training 的解决方案:

  • 🎯 使用结构化对话数据继续训练
  • 🎯 注入特定领域知识(数学、对话、工具使用)
  • 🎯 补齐 Base 模型的能力短板
  • 🎯 为后续 SFT 打下良好基础

1.3 Mid Training 的核心目标

# 训练数据组成(总共 ~850K 条对话)
train_dataset = TaskMixture([
    SmolTalk(split="train"),              # 460K 通用对话
    MMLU(subset="auxiliary_train"),        # 100K 知识问答
    GSM8K(subset="main", split="train"),   # 8K 数学推理
    CustomJSON(identity_file),             # 1K 身份对话 × 2 轮
    SimpleSpelling(size=200000),           # 200K 拼写
    SpellingBee(size=80000),               # 80K 字母计数
])

关键能力注入:

能力类型数据来源数量目的
对话能力SmolTalk460K学会自然对话
知识储备MMLU100K多领域知识
数学推理GSM8K8K工具使用 + 推理
身份设定自定义1K×2模型人格
拼写能力Spelling280Ktoken→字符映射

2. 数据来源的巨大变化

2.1 Base vs Mid 数据对比

Base Training 数据格式

原始文本(Parquet 文件):
"The quick brown fox jumps over the lazy dog. This is a test..."

Tokenize → 连续的 token 流:
[15496, 995, 831, 374, 264, 1296, ...]

分批次:
inputs:  [15496, 995, 831, 374, ...]
targets: [995, 831, 374, 264, ...]

特点:

  • 📄 纯文本,没有结构
  • 🔄 文档边界被 <|bos|> 分隔
  • 🎲 随机拼接文档

Mid Training 数据格式

结构化对话(来自任务数据集):
{
    "messages": [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "4"}
    ]
}

Tokenize → 带格式的 token 序列:
[<|bos|>, <|user_start|>, "What", "is", "2", "+", "2", "?", <|user_end|>, 
 <|assistant_start|>, "4", <|assistant_end|>]

分批次(保留对话结构):
# 多个完整对话拼接在一起

特点:

  • 💬 结构化对话
  • 🏷️ 明确的角色标记
  • 📚 来自精心策划的任务

2.2 任务数据集详解

2.2.1 SmolTalk - 通用对话

来源: HuggingFace SmolTalk 数据集

格式示例:

{
    "messages": [
        {"role": "user", "content": "Can you explain quantum computing?"},
        {"role": "assistant", "content": "Quantum computing uses quantum bits..."},
        {"role": "user", "content": "What are its applications?"},
        {"role": "assistant", "content": "Quantum computers can..."}
    ]
}

特点:

  • 多轮对话
  • 涵盖各种日常话题
  • 自然的对话风格
  • 训练集:460K 条,测试集:24K 条

作用: 让模型学会自然的多轮对话模式

2.2.2 MMLU - 知识问答

来源: MMLU (Massive Multitask Language Understanding)

格式示例:

{
    "messages": [
        {
            "role": "user", 
            "content": "What is the capital of France?\nA. London\nB. Paris\nC. Berlin\nD. Madrid"
        },
        {"role": "assistant", "content": "B"}
    ],
    "subject": "geography",
    "letters": ["A", "B", "C", "D"]
}

特点:

  • 选择题格式
  • 57 个学科领域
  • 辅助训练集:100K 条

学科分布:

STEM: 数学、物理、化学、生物、计算机...
人文: 历史、哲学、法律、心理学...
社科: 经济学、社会学、政治学...
其他: 医学、商业、艺术...

作用: 注入大量结构化知识

2.2.3 GSM8K - 数学推理

来源: GSM8K (Grade School Math 8K)

格式示例:

{
    "messages": [
        {
            "role": "user",
            "content": "Weng earns $12 an hour. Yesterday, she worked for 50 minutes. How much did she earn?"
        },
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": "First, let's calculate per minute wage:\n"},
                {"type": "python", "text": "12/60"},
                {"type": "python_output", "text": "0.2"},
                {"type": "text", "text": "\nSo she earns $0.2 per minute. For 50 minutes:\n"},
                {"type": "python", "text": "0.2*50"},
                {"type": "python_output", "text": "10"},
                {"type": "text", "text": "\n#### 10"}
            ]
        }
    ]
}

关键特性:

  • 🧮 包含工具调用(Python 计算器)
  • 📝 步骤化推理
  • ✅ 最终答案标记 #### 10

Token 化后的格式:

<|bos|>
<|user_start|> Weng earns $12 an hour... <|user_end|>
<|assistant_start|> 
    First, let's calculate per minute wage:
    <|python_start|> 12/60 <|python_end|>
    <|output_start|> 0.2 <|output_end|>
    So she earns $0.2 per minute. For 50 minutes:
    <|python_start|> 0.2*50 <|python_end|>
    <|output_start|> 10 <|output_end|>
    #### 10
<|assistant_end|>

作用:

  • 教会模型使用工具
  • 学习步骤化推理
  • 理解计算流程

2.2.4 Spelling Tasks - 拼写能力

为什么需要?

LLM 的一个常见弱点:不知道 token 是如何拼写的。

问题示例:
"How many 'r' are in 'strawberry'?"

Base 模型可能回答: "2"  ❌
正确答案: "3"  ✅

原因分析:

Token 化:
"strawberry" → [str, aw, berry]  # 可能被切成多个 token

模型看到的是 token IDs,不是字符:
[24871, 675, 15717]

要回答"有几个 r",模型需要知道:
- token 24871 包含哪些字母?
- token 675 包含哪些字母?
- token 15717 包含哪些字母?

这种 token → 字符 的映射需要专门训练!

解决方案:SimpleSpelling

{
    "messages": [
        {"role": "user", "content": "Spell the word 'apple'"},
        {"role": "assistant", "content": "a-p-p-l-e"}
    ]
}
  • 200K 个单词拼写任务
  • 让模型学会 token 的字符组成

解决方案:SpellingBee

{
    "messages": [
        {"role": "user", "content": "How many r are in strawberry"},
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": "Let me spell it:\n"},
                {"type": "python", "text": "'strawberry'.lower()"},
                {"type": "python_output", "text": "strawberry"},
                {"type": "text", "text": "\nNow count:\n"},
                {"type": "python", "text": "'strawberry'.count('r')"},
                {"type": "python_output", "text": "3"},
                {"type": "text", "text": "\n#### 3"}
            ]
        }
    ]
}
  • 80K 个字母计数任务
  • 结合拼写和工具使用

2.2.5 Identity Conversations - 身份设定

格式示例:

{
    "messages": [
        {"role": "user", "content": "Who are you?"},
        {"role": "assistant", "content": "I am an AI assistant created by..."}
    ]
}

作用:

  • 让模型知道自己是谁
  • 统一回答身份问题
  • 增强一致性

2.3 数据加载器对比

Base Training 数据加载

def tokenizing_distributed_data_loader(B, T, split):
    # 读取 Parquet 文件
    pf = pq.ParquetFile(filepath)
    batch = rg.column('text').to_pylist()
    
    # Tokenize
    token_lists = tokenizer.encode(batch, prepend=bos_token)
    
    # 拼接成连续流
    for tokens in token_lists:
        token_buffer.extend(tokens)
    
    # 取出固定长度
    needed = B * T + 1
    tokens = [token_buffer.popleft() for _ in range(needed)]
    
    # 切分
    inputs = tokens[:-1].view(B, T)
    targets = tokens[1:].view(B, T)

Mid Training 数据加载

def mid_data_generator(split):
    dataset = train_dataset  # TaskMixture
    
    # 遍历数据集
    for cursor in range(ddp_rank, len(dataset), ddp_world_size):
        # 获取一个对话
        conversation = dataset[cursor]
        
        # Tokenize(保留结构)
        ids, mask = tokenizer.render_conversation(conversation)
        # ids: [<|bos|>, <|user_start|>, ..., <|assistant_end|>]
        
        # 累积到缓冲区
        token_buffer.extend(ids)
        
        # 当缓冲区足够时,切分批次
        if len(token_buffer) >= needed_tokens:
            tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
            inputs = tokens[:-1].view(B, T)
            targets = tokens[1:].view(B, T)
            yield inputs, targets

关键差异:

维度Base TrainingMid Training
数据源Parquet 文件(纯文本)Task 对象(结构化对话)
分词方式encode(text)render_conversation(conv)
结构无结构带角色标记
边界只有 <|bos|>完整的对话标记
对话完整性文档可能被截断尽量保持对话完整

3. 训练流程详解

3.1 完整数据流(图解)

步骤 1: 加载任务数据
┌─────────────────────────────────────┐
│ SmolTalk: 460K 对话                  │
│ MMLU: 100K 问答                      │
│ GSM8K: 8K 数学题                     │
│ Spelling: 280K 拼写任务              │
│ Identity: 1K 身份对话                │
└─────────────────────────────────────┘
            ↓
步骤 2: TaskMixture 混合
┌─────────────────────────────────────┐
│ 所有任务混合打乱                      │
│ 总计: ~850K 条对话                   │
│ 确定性随机打乱 (seed=42)              │
└─────────────────────────────────────┘
            ↓
步骤 3: 按索引遍历
┌─────────────────────────────────────┐
│ 每个 GPU 负责不同的子集                │
│ GPU 0: indices [0, 8, 16, 24, ...]   │
│ GPU 1: indices [1, 9, 17, 25, ...]   │
│ ...                                  │
└─────────────────────────────────────┘
            ↓
步骤 4: 获取对话
┌─────────────────────────────────────┐
│ conversation = dataset[cursor]       │
│ {                                    │
│   "messages": [                      │
│     {"role": "user", "content": ...} │
│     {"role": "assistant", ...}       │
│   ]                                  │
│ }                                    │
└─────────────────────────────────────┘
            ↓
步骤 5: Tokenize(保留结构)
┌─────────────────────────────────────┐
│ ids, mask = tokenizer.render_conversation(conv) │
│                                      │
│ ids: [<|bos|>, <|user_start|>, ...] │
│ mask: [0, 0, 0, 1, 1, ...]  # 训练掩码 │
└─────────────────────────────────────┘
            ↓
步骤 6: 累积到缓冲区
┌─────────────────────────────────────┐
│ token_buffer.extend(ids)             │
│ [1, 2, 3, 4, ..., 10000, 10001, ...] │
│ (多个对话连接在一起)                  │
└─────────────────────────────────────┘
            ↓
步骤 7: 切分批次
┌─────────────────────────────────────┐
│ needed = B * T + 1  # 32*2048+1      │
│ tokens = buffer[:needed]             │
│                                      │
│ inputs:  (B, T) = (32, 2048)         │
│ targets: (B, T) = (32, 2048)         │
└─────────────────────────────────────┘
            ↓
步骤 8: 送入模型训练
┌─────────────────────────────────────┐
│ loss = model(inputs, targets)        │
│ loss.backward()                      │
│ optimizer.step()                     │
└─────────────────────────────────────┘

3.2 TokenMixture 的作用

问题: 不同任务的数据量差异巨大

SmolTalk: 460K  ████████████████████████
MMLU:     100K  █████
GSM8K:      8K  ▌
Identity:   2K  ▏

简单遍历的问题:

# 错误做法:依次训练每个任务
for task in tasks:
    for example in task:
        train(example)

# 结果:前期都在训练 SmolTalk,后期才见到 GSM8K
# 模型会遗忘早期学到的东西!

TaskMixture 的解决方案:

class TaskMixture:
    def __init__(self, tasks):
        # 1. 构建索引映射
        self.index_map = []
        for task_idx, task in enumerate(tasks):
            for local_idx in range(len(task)):
                self.index_map.append((task_idx, local_idx))
        
        # 2. 打乱索引(确定性)
        rng = random.Random(42)
        rng.shuffle(self.index_map)
        
        # 现在访问顺序是:
        # [SmolTalk_123, MMLU_45, GSM8K_7, SmolTalk_456, ...]

效果:

训练过程中看到的数据(混合后):
[S, S, M, S, S, G, S, M, S, I, S, S, M, S, G, ...]
 ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑  ↑
 SmolTalk, MMLU, GSM8K, Identity 混合出现

而不是:
[S, S, S, S, ..., M, M, M, ..., G, G, ...]
 ←── SmolTalk ───→  ←MMLU→  ←GSM8K→

优点:

  • ✅ 所有任务在整个训练过程中都有曝光
  • ✅ 避免灾难性遗忘
  • ✅ 不同能力同步提升

3.3 训练循环详解

3.3.1 进度追踪机制

Mid Training 的一个特殊之处:我们不提前知道要训练多少步

# Base Training: 明确的迭代次数
num_iterations = 10000  # 确定
for step in range(num_iterations):
    ...

# Mid Training: 基于数据集大小
dataset_size = 850000  # 对话数量
# 每个 step 消耗多少对话?不确定!(因为对话长度不一)
# 所以用进度百分比

进度计算:

# 全局变量
last_step = False       # 是否到达最后一步
approx_progress = 0.0   # 0 → 1

# 在数据生成器中更新
def mid_data_generator(split):
    cursor = ddp_rank
    it = 0
    
    while True:
        # 处理一个对话
        conversation = dataset[cursor]
        cursor += ddp_world_size
        
        # 更新进度
        if cursor >= dataset_size:
            cursor -= dataset_size  # 回绕(下一个 epoch)
            if split == "train":
                last_step = True  # 标记为最后一步
        
        # 计算近似进度
        approx_progress = cursor / dataset_size
        
        yield inputs, targets

多卡同步:

# 问题:不同 GPU 可能处理速度不同
# GPU 0: last_step = True
# GPU 1: last_step = False  ← 还没完成

# 解决:同步 last_step
if ddp:
    last_step_tensor = torch.tensor(last_step, device=device)
    dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
    # 只要有一个 GPU 到达终点,所有 GPU 都停止
    last_step = bool(last_step_tensor.item())

3.3.2 学习率调度

与 Base Training 不同:

# Base Training: 基于步数
def get_lr_multiplier(step):
    if step < warmup_iters:
        return step / warmup_iters
    elif step <= num_iterations - warmdown_iters:
        return 1.0
    else:
        return linear_decay(...)

# Mid Training: 基于进度百分比
def get_lr_multiplier(progress):
    # progress: 0.0 → 1.0
    if progress < 0.8:
        return 1.0  # 前 80% 保持不变
    else:
        # 后 20% 线性衰减到 0
        return 1 - (progress - 0.8) / 0.2

学习率曲线:

LR Multiplier
    1.0 |════════════════════════════════╲
        |                                 ╲
        |                                  ╲
    0.5 |                                   ╲
        |                                    ╲
    0.0 |                                     ═
        +───────────────────────────────────────→
        0%      20%     40%     60%     80%   100% progress
        ←────────── 保持 1.0 ──────────→ ← 衰减 →

为什么不同?

  • Base Training 知道总步数,可以精确规划
  • Mid Training 不确定总步数,用进度百分比更灵活

3.3.3 单步训练流程

# 预取第一批数据
x, y = next(train_loader)

while True:
    # === 评估阶段 ===
    if step % eval_every == 0:
        model.eval()
        val_bpb = evaluate_bpb(...)
        model.train()
    
    # === 保存检查点 ===
    if last_step:
        save_checkpoint(...)
        break
    
    # === 训练阶段 ===
    synchronize()
    t0 = time.time()
    
    # 梯度累积
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        
        loss = loss / grad_accum_steps
        loss.backward()
        
        # 异步预取下一批
        x, y = next(train_loader)
        progress = max(progress, approx_progress)
    
    # 更新学习率
    lrm = get_lr_multiplier(progress)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    
    # 更新参数
    for opt in optimizers:
        opt.step()
    model.zero_grad()
    
    synchronize()
    t1 = time.time()
    
    # === 日志记录 ===
    step += 1
    dt = t1 - t0
    print(f"step {step} ({progress*100:.2f}%) | loss: {loss:.6f} ...")

3.4 维度分析

假设配置:B=32, T=2048, vocab_size=32000, n_embd=768

单个对话的 Token 化

# 输入:结构化对话
conversation = {
    "messages": [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "4"}
    ]
}

# 输出:token IDs + mask
ids, mask = tokenizer.render_conversation(conversation)

# ids: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
# mask: [0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0]

# 展开看:
# token_id | token_str           | mask | 说明
# ---------|---------------------|------|------
# 1        | <|bos|>             | 0    | 文档开始
# 2        | <|user_start|>      | 0    | 用户开始
# 3-7      | "What is 2+2?"      | 0    | 用户消息(不训练)
# 8        | <|user_end|>        | 0    | 用户结束
# 9        | <|assistant_start|> | 0    | 助手开始
# 10-11    | "4"                 | 1    | 助手回复(训练!)
# 12       | <|assistant_end|>   | 1    | 助手结束(训练)

批次维度

# 多个对话拼接
conversation_1: 200 tokens
conversation_2: 150 tokens
conversation_3: 180 tokens
...

# 累积到缓冲区
token_buffer: [tok1, tok2, ..., tok_N]  # N 很大

# 切分批次
needed = 32 * 2048 + 1 = 65537

tokens = token_buffer[:65537]
# [tok_1, tok_2, ..., tok_65537]

# 重塑
inputs = tokens[:-1].view(32, 2048)   # (32, 2048)
targets = tokens[1:].view(32, 2048)   # (32, 2048)

模型前向传播

# 输入
inputs: (32, 2048)  # int32

# Embedding
x = model.transformer.wte(inputs)
# (32, 2048) → (32, 2048, 768)

# Transformer Blocks
for block in model.transformer.h:
    x = block(x, ...)
# (32, 2048, 768) → (32, 2048, 768)

# LM Head
logits = model.lm_head(x)
# (32, 2048, 768) @ (768, 32000) → (32, 2048, 32000)

# 损失计算
loss = F.cross_entropy(
    logits.view(-1, 32000),   # (65536, 32000)
    targets.view(-1)          # (65536,)
)
# loss: 标量

4. 与 Base 训练的对比

4.1 核心差异总结

维度Base TrainingMid Training
数据来源Parquet 文件(网页抓取)任务数据集(精心策划)
数据格式纯文本结构化对话
数据量数十亿 tokens数百万对话
训练目标通用语言理解特定能力注入
Token 格式连续文本流带角色标记
进度追踪固定步数数据集进度
学习率基于步数调度基于进度百分比
训练时长数天到数周数小时到1天
检查点位置base_checkpoints/mid_checkpoints/

4.2 代码层面对比

数据加载器

# ===== Base Training =====
# 从 Parquet 读取
pf = pq.ParquetFile(filepath)
batch = rg.column('text').to_pylist()

# 简单 tokenize
token_lists = tokenizer.encode(batch, prepend=bos_token)

# 无结构,直接拼接
for tokens in token_lists:
    token_buffer.extend(tokens)

# ===== Mid Training =====
# 从 Task 对象读取
conversation = dataset[cursor]

# 结构化 tokenize
ids, mask = tokenizer.render_conversation(conversation)
# 包含 <|user_start|>, <|assistant_start|> 等特殊 token

# 保留对话结构
token_buffer.extend(ids)

模型初始化

# ===== Base Training =====
# 从零开始
with torch.device("meta"):
    model_config = GPTConfig(...)
    model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()  # 随机初始化

# ===== Mid Training =====
# 从 Base 模型加载
model, tokenizer, meta = load_model(
    "base",  # 加载 base 模型
    device,
    phase="train",
    model_tag=model_tag,
    step=step
)
# 参数已经训练过,继续训练

训练循环

# ===== Base Training =====
for step in range(num_iterations):  # 固定次数
    loss = model(x, y)
    loss.backward()
    optimizer.step()

# ===== Mid Training =====
while True:  # 直到遍历完数据集
    loss = model(x, y)
    loss.backward()
    optimizer.step()
    
    if last_step:  # 数据集遍历完
        break

4.3 超参数对比

# ===== Base Training =====
base_config = {
    "total_batch_size": 524288,      # 512K tokens/step
    "device_batch_size": 32,
    "num_iterations": 50000,         # 明确指定
    "warmup_ratio": 0.0,             # 无 warmup
    "warmdown_ratio": 0.2,           # 最后 20% 衰减
    "embedding_lr": 0.2,
    "matrix_lr": 0.02,
    "unembedding_lr": 0.004,
    "init_lr_frac": 1.0,             # 从 100% 开始
}

# ===== Mid Training =====
mid_config = {
    "total_batch_size": 524288,      # 512K tokens/step (相同)
    "device_batch_size": 32,
    "num_iterations": -1,            # 自动推断
    "warmup_ratio": 0.0,             # 无 warmup
    "warmdown_ratio": 0.2,           # 最后 20% 衰减
    "embedding_lr": 0.2,
    "matrix_lr": 0.02,
    "unembedding_lr": 0.004,
    "init_lr_frac": 1.0,             # 从 100% 开始
}

相同点:

  • 批次大小相同
  • 学习率相同
  • 优化器配置相同

不同点:

  • Mid Training 从已训练的模型开始
  • 数据源完全不同
  • 训练目标不同

5. 实战案例分析

5.1 完整训练流程示例

假设我们训练一个 depth=12 的模型(约 200M 参数)。

步骤 1: Base Training

# 8 卡 A100,训练 3 天
torchrun --nproc_per_node=8 -m scripts.base_train \
    --depth=12 \
    --device_batch_size=32 \
    --total_batch_size=524288 \
    --num_iterations=50000 \
    --run=base_d12

# 输出:
# base_checkpoints/d12/step_00050000.pt

结果:

  • ✅ 模型学会了通用语言理解
  • ✅ 能够补全句子
  • ❌ 不懂对话格式
  • ❌ 不会使用工具

步骤 2: Mid Training

# 8 卡 A100,训练 6 小时
torchrun --nproc_per_node=8 -m scripts.mid_train \
    --device_batch_size=32 \
    --total_batch_size=524288 \
    --run=mid_d12

# 输出:
# mid_checkpoints/d12/step_XXXXX.pt

结果:

  • ✅ 理解对话格式(user/assistant)
  • ✅ 学会了基本对话能力
  • ✅ 能使用 Python 计算器
  • ✅ 拼写能力显著提升
  • ✅ 多领域知识增强
  • ❌ 还不能很好地遵循指令

5.2 训练数据流转示例

让我们追踪一个实际的训练 batch:

Batch 构成(B=4, T=512 为例)

# 假设 token_buffer 中有以下对话(简化):

# 对话 1: SmolTalk(200 tokens)
[<|bos|>, <|user_start|>, "How", "are", "you", "?", <|user_end|>,
 <|assistant_start|>, "I'm", "doing", "well", "!", <|assistant_end|>, ...]

# 对话 2: MMLU(50 tokens)
[<|bos|>, <|user_start|>, "What", "is", "2+2", "?", "\n", "A.", "2", ...,
 <|user_end|>, <|assistant_start|>, "D", <|assistant_end|>]

# 对话 3: GSM8K(300 tokens)
[<|bos|>, <|user_start|>, "Weng", "earns", ..., <|user_end|>,
 <|assistant_start|>, "First", ",", "let's", ..., <|python_start|>, "12/60", ...]

# 对话 4: SpellingBee(100 tokens)
[<|bos|>, <|user_start|>, "How", "many", "r", "in", "strawberry", <|user_end|>,
 <|assistant_start|>, <|python_start|>, "'strawberry'.count('r')", ...]

# 从 buffer 中取出 4*512+1 = 2049 个 token
# 这些 token 来自上述对话的混合

切分成 inputs 和 targets

tokens = token_buffer[:2049]

inputs = tokens[:-1]   # 前 2048 个
targets = tokens[1:]   # 后 2048 个

# 重塑
inputs = inputs.view(4, 512)
targets = targets.view(4, 512)

# 现在每个样本包含多个对话的片段
# 样本 0: 对话1的一部分 + 对话2的开头
# 样本 1: 对话2的剩余 + 对话3的一部分
# 样本 2: 对话3的中间部分
# 样本 3: 对话3的结尾 + 对话4

模型训练

# 前向传播
logits = model(inputs)  # (4, 512, 32000)

# 计算损失
loss = F.cross_entropy(
    logits.view(-1, 32000),
    targets.view(-1)
)

# 反向传播
loss.backward()

# 更新参数
optimizer.step()

5.3 训练日志解读

step 00100 (12.50%) | loss: 2.345678 | lrm: 1.00 | dt: 245.32ms | 
tok/sec: 2,137,856 | mfu: 45.67 | total time: 8.52m

step 00200 (25.00%) | loss: 2.123456 | lrm: 1.00 | dt: 243.11ms | 
tok/sec: 2,156,234 | mfu: 46.12 | total time: 17.21m

...

step 01600 (80.00%) | loss: 1.876543 | lrm: 1.00 | dt: 242.55ms | 
tok/sec: 2,161,345 | mfu: 46.23 | total time: 137.45m

step 01700 (85.00%) | loss: 1.865432 | lrm: 0.75 | dt: 242.88ms | 
tok/sec: 2,159,876 | mfu: 46.19 | total time: 145.78m
                                      ^^^^
                                      学习率开始衰减

step 02000 (100.00%) | loss: 1.854321 | lrm: 0.00 | dt: 243.21ms | 
tok/sec: 2,157,234 | mfu: 46.14 | total time: 170.12m

观察:

  • 进度从 0% → 100%
  • Loss 从 2.3 → 1.8(显著下降)
  • 在 80% 之后,lrm 从 1.0 开始衰减
  • 总训练时间约 2.8 小时(8 卡)

5.4 能力提升对比

测试任务:

prompts = [
    # 1. 对话能力
    "Hello! How are you?",
    
    # 2. 知识问答
    "What is the capital of France?",
    
    # 3. 数学推理
    "If I have 5 apples and buy 3 more, how many do I have?",
    
    # 4. 拼写能力
    "How many 'r' are in 'strawberry'?",
]

模型对比:

测试Base 模型Mid 模型
对话“I are you” (语法错误)“I’m doing well, thanks! How about you?” ✅
知识“Paris France city” (不连贯)“Paris” ✅
数学“8 apples total” (直接猜)“5 + 3 = 8” ✅ (使用计算)
拼写“2” ❌“3” ✅

6. 关键要点总结

6.1 Mid Training 的本质

Mid Training = Base Model + Structured Tasks
             = 通用能力 + 特定技能
             = 为 SFT 做准备

6.2 何时需要 Mid Training?

需要 Mid Training:

  • ✅ Base 模型在某些能力上很弱(如拼写、工具使用)
  • ✅ 有大量结构化任务数据
  • ✅ 想要注入特定领域知识
  • ✅ 数据量够大(数十万级别)

可以跳过 Mid Training:

  • ❌ Base 模型已经足够强
  • ❌ 只有少量 SFT 数据
  • ❌ 资源有限

6.3 Mid Training vs SFT

维度Mid TrainingSFT
数据量大(数十万)小(数千)
目标能力扩展指令对齐
学习率较高较低
训练轮数1 epoch1-3 epochs
数据来源公开数据集精心标注

6.4 实践建议

数据配比:

# 推荐配比
train_dataset = TaskMixture([
    SmolTalk(...),          # 50-60% 通用对话
    Knowledge_QA(...),      # 15-20% 知识问答
    Math_Reasoning(...),    # 1-5% 数学推理
    Tool_Use(...),          # 1-5% 工具使用
    Spelling(...),          # 20-30% 拼写/基础能力
    Identity(...),          # <1% 身份设定(过采样)
])

训练时长:

  • 小模型(<500M):1-2 小时(8卡)
  • 中模型(500M-2B):4-8 小时(8卡)
  • 大模型(>2B):1-2 天(8卡)

检查点保存:

  • 只保存最后的检查点
  • Mid Training 通常不需要中间检查点

附录

A. 完整训练命令

# 单 GPU 训练
python -m scripts.mid_train \
    --device_batch_size=16 \
    --run=mid_test

# 多 GPU 训练
torchrun --standalone --nproc_per_node=8 \
    -m scripts.mid_train \
    --device_batch_size=32 \
    --total_batch_size=524288 \
    --run=mid_production

B. 任务数据集详细统计

任务训练集测试集平均长度用途
SmolTalk460K24K~150 tokens对话能力
MMLU (aux)100K14K~100 tokens知识问答
GSM8K8K1.3K~250 tokens数学推理
SimpleSpelling200K-~30 tokens拼写基础
SpellingBee80K-~100 tokens字母计数
Identity1K-~80 tokens身份设定

C. 常见问题

Q1: Mid Training 必须做吗?

  • 不是必须,但强烈推荐
  • 特别是模型较小(<1B)时

Q2: 可以做多轮 Mid Training 吗?

  • 可以,但通常 1 轮就够
  • 更多轮可能过拟合

Q3: 如何选择任务?

  • 根据应用场景选择
  • 保证数据质量
  • 注意任务间的平衡

Q4: Mid Training 后 loss 应该是多少?

  • 通常比 Base Training 略低
  • 典型值:1.5 - 2.0
  • 重要的是相对下降,不是绝对值

9. 关键疑问解答:为什么 Mid 训练丢弃了 Mask?

9.1 Mask 机制的本质

在 tokenizer 中,render_conversation 函数会返回两个值:

ids, mask = tokenizer.render_conversation(conversation)
# ids: [token_id1, token_id2, ...]  完整的 token 序列
# mask: [0, 0, 1, 1, 0, ...]        标记哪些 token 需要训练

Mask 的含义

  • mask = 0:不计算 loss(如 user 的输入、特殊 token)
  • mask = 1:计算 loss(如 assistant 的回复)

Mask 的生成逻辑(来自 tokenizer.py):

def render_conversation(self, conversation):
    ids, mask = [], []
    
    def add_tokens(token_ids, mask_val):
        ids.extend(token_ids)
        mask.extend([mask_val] * len(token_ids))
    
    # BOS token: mask=0
    add_tokens(bos, 0)
    
    for message in messages:
        if message["role"] == "user":
            # User 的所有内容:mask=0
            add_tokens(user_start, 0)
            add_tokens(content_ids, 0)  # ← User 输入不训练
            add_tokens(user_end, 0)
        
        elif message["role"] == "assistant":
            add_tokens(assistant_start, 0)  # Start token 不训练
            add_tokens(content_ids, 1)      # ← Assistant 回复训练!
            add_tokens(assistant_end, 1)    # End token 也训练
    
    return ids, mask

9.2 Mid Training 丢弃 Mask 的真相

代码对比

# Mid Training (mid_train.py 第 133 行)
ids, _ = tokenizer.render_conversation(conversation)
#      ↑ 直接丢弃 mask!

# SFT Training (chat_sft.py 第 123 行)
ids, mask = tokenizer.render_conversation(doc)
#      ↑ 保留并使用 mask

为什么 Mid Training 丢弃 mask?

答案:Mid Training 在所有 token 上计算 loss,包括 User 的输入!

9.3 Mid vs SFT:训练目标的根本区别

Mid Training 的目标:学习对话流

# Mid Training 的数据处理
conversation = dataset[cursor]
ids, _ = tokenizer.render_conversation(conversation)
token_buffer.extend(ids)  # 直接拼接所有 token

# 生成 inputs, targets(没有 mask)
inputs = tokens[:-1]   # (B, T)
targets = tokens[1:]   # (B, T)

# 计算 loss(所有 token 都参与)
loss = model(inputs, targets)
# CrossEntropyLoss 会在所有 token 上计算梯度

训练的内容

<|bos|><|user_start|>What is 2+2?<|user_end|><|assistant_start|>4<|assistant_end|>
↑      ↑                              ↑          ↑                    ↑
所有这些 token 都参与 loss 计算!

学习的能力

  • ✅ 对话格式(什么时候该是 user,什么时候该是 assistant)
  • ✅ 上下文理解(user 说了什么)
  • ✅ 回复生成(assistant 应该怎么回)
  • ✅ 整体对话流(对话的开始、进行、结束)

SFT Training 的目标:只学习生成回复

# SFT Training 的数据处理
ids, mask = tokenizer.render_conversation(doc)

# 关键步骤:将 mask=0 的位置设为 -1
targets = ids[1:]
mask_tensor = mask[1:]
targets[mask_tensor == 0] = -1  # ← 设为 ignore_index

# 计算 loss(只在 mask=1 的 token 上)
loss = F.cross_entropy(
    logits.view(-1, vocab_size),
    targets.view(-1),
    ignore_index=-1  # ← targets=-1 的位置不计算 loss
)

训练的内容

<|bos|><|user_start|>What is 2+2?<|user_end|><|assistant_start|>4<|assistant_end|>
  ↓         ↓                        ↓             ↓            ↓      ↓
 -1        -1                       -1            -1            4      END
                                                               ↑       ↑
                                            只有这两个 token 参与 loss!

学习的能力

  • ✅ 只学习生成回复(assistant 的内容)
  • ❌ 不学习理解问题(user 的内容不参与梯度)
  • ❌ 不学习对话格式(特殊 token 不参与梯度)

9.4 为什么要这样设计?

Mid Training:为什么要训练所有 token?

原因 1:从零开始学习对话

在 Base 训练后,模型只见过纯文本:

The capital of France is Paris. It is a beautiful city...

从未见过对话格式:

<|user_start|>What is the capital of France?<|user_end|>
<|assistant_start|>The capital of France is Paris.<|assistant_end|>

如果在 Mid 阶段就使用 mask(只训练 assistant):

  • ❌ 模型不知道 <|user_start|> 后面应该是什么
  • ❌ 模型不知道 user 说完后应该接 <|user_end|>
  • ❌ 模型不知道 user 结束后应该接 <|assistant_start|>

使用所有 token 训练的好处

  • ✅ 学会对话的完整流程
  • ✅ 学会在正确的时机切换角色
  • ✅ 学会理解 user 的输入(通过预测下一个 token)

原因 2:建立上下文理解能力

# 训练目标:预测下一个 token
输入<|user_start|>What is 2+
目标2

输入<|user_start|>What is 2+2
目标:?

输入<|user_start|>What is 2+2?
目标<|user_end|>

通过预测 user 输入的每一个 token,模型学会了:

  • 理解问题的结构
  • 识别问题何时结束
  • 准备生成回复

原因 3:多任务混合的需要

Mid Training 的数据混合了多种任务:

TaskMixture([
    SmolTalk,        # 通用对话
    MMLU,            # 多选题
    GSM8K,           # 数学推理
    SpellingBee,     # 字符级任务
])

每种任务的格式可能不同,模型需要学会:

  • 识别当前是哪种任务
  • 理解不同任务的输入格式
  • 生成符合任务要求的输出

如果只训练 assistant 的回复

  • 模型可能学不会区分不同任务类型
  • 模型可能不理解输入的具体格式要求

SFT Training:为什么要使用 Mask?

原因 1:防止遗忘

在 Mid Training 后,模型已经学会了:

  • ✅ 对话格式
  • ✅ 任务类型识别
  • ✅ 基础推理能力

SFT 的目标不是重新学习这些,而是:

  • 🎯 精细调整输出风格
  • 🎯 提升指令遵循能力
  • 🎯 优化输出格式

如果 SFT 仍然训练所有 token

  • ❌ 模型可能”重新学习”对话格式(浪费计算)
  • ❌ 可能破坏 Mid 阶段学到的知识
  • ❌ 在小数据集上容易过拟合

使用 Mask 的好处

  • ✅ 只更新 assistant 的生成能力
  • ✅ 保留 Mid 阶段的对话理解能力
  • ✅ 配合小学习率(0.02x),防止遗忘

原因 2:数据效率

SFT 只有 23K 样本(vs Mid 的 850K):

  • 如果训练所有 token,数据量太少
  • 只训练 assistant 回复,提高数据利用效率

原因 3:避免记忆训练数据

# 问题:如果训练所有 token
输入<|user_start|>What is the capital of France?
目标<|user_end|>

# 风险:模型可能记住了这个问题
# 测试时看到类似问题,直接输出训练数据的答案

使用 Mask

  • 模型不会记住问题的具体内容
  • 只学习如何生成合适的回答
  • 泛化能力更强

9.5 数据流完整对比

Mid Training 数据流

# Step 1: Tokenize(返回 mask 但丢弃)
conversation = {
    "messages": [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "4"}
    ]
}

ids, _ = tokenizer.render_conversation(conversation)
# ids = [BOS, USER_START, W, h, a, t, ..., USER_END, 
#        ASST_START, 4, ASST_END]

# Step 2: 拼接到 token buffer
token_buffer.extend(ids)

# Step 3: 生成 batch(没有 mask)
inputs = [USER_START, W, h, a, t, ...]  # (B, T)
targets = [W, h, a, t, ..., USER_END, ASST_START, 4, ASST_END]

# Step 4: 计算 loss(所有 token)
logits = model(inputs)  # (B, T, V)
loss = F.cross_entropy(
    logits.view(-1, V),
    targets.view(-1)
)
# 每个 token 都有梯度:
# - W, h, a, t (user 的输入)
# - 4 (assistant 的输出)
# - 所有特殊 token

梯度更新的内容

∂Loss/∂θ = 梯度来自所有 token 的预测误差

包括:
- 预测 "What" 中的 "h" 的误差
- 预测 "?" 后应该接 USER_END 的误差
- 预测 USER_END 后应该接 ASST_START 的误差
- 预测 ASST_START 后应该接 "4" 的误差

SFT Training 数据流

# Step 1: Tokenize(保留 mask)
ids, mask = tokenizer.render_conversation(conversation)
# ids = [BOS, USER_START, W, h, a, t, ..., USER_END, 
#        ASST_START, 4, ASST_END]
# mask = [0, 0, 0, 0, 0, 0, ..., 0, 
#         0, 1, 1]
#                ↑  ↑  只有这两个是 1

# Step 2: 应用 mask(设置 ignore_index)
targets = ids[1:]
mask_tensor = mask[1:]
targets[mask_tensor == 0] = -1
# targets = [-1, -1, -1, -1, -1, ..., -1, 
#            -1, 4, ASST_END]
#                ↑   ↑  只有这两个保留

# Step 3: 计算 loss(只在 mask=1 的 token 上)
loss = F.cross_entropy(
    logits.view(-1, V),
    targets.view(-1),
    ignore_index=-1  # ← 关键!
)
# 只有 targets != -1 的位置有梯度

梯度更新的内容

∂Loss/∂θ = 梯度只来自 mask=1 的 token

包括:
- 预测 "4" 的误差(assistant 的回复)
- 预测 ASST_END 的误差(结束 token)

不包括:
- User 输入的任何 token(mask=0)
- 特殊 token 如 BOS, USER_START, USER_END(mask=0)

9.6 实验验证

假设实验 1:如果 Mid Training 也使用 Mask?

可能的结果:

  • ❌ 模型学不会对话格式(没见过 user 输入的训练信号)
  • ❌ 生成时可能在错误的位置停止(不知道什么时候该结束)
  • ❌ 无法处理多轮对话(不理解对话的流程)

假设实验 2:如果 SFT Training 不使用 Mask?

可能的结果:

  • ❌ 过拟合训练数据的问题(记住了具体问题)
  • ❌ 破坏 Mid 阶段的知识(重新学习对话格式)
  • ❌ 泛化能力下降(在新问题上表现差)

9.7 设计哲学总结

┌─────────────────────────────────────────────────────────┐
│                    训练阶段的演变                        │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  Base Training: 学习语言                                 │
│  ↓                                                      │
│  训练:所有 token(纯文本,无对话格式)                   │
│                                                         │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  Mid Training: 学习对话流                                │
│  ↓                                                      │
│  训练:所有 token(包括 user 和 assistant)              │
│  目标:理解对话格式 + 学会生成回复                        │
│                                                         │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  SFT Training: 精细调整回复                              │
│  ↓                                                      │
│  训练:只有 assistant 的 token(使用 mask)              │
│  目标:优化输出风格 + 指令遵循                           │
│                                                         │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  RL Training: 强化推理能力                               │
│  ↓                                                      │
│  训练:只有 assistant 的 token(使用 mask)              │
│  目标:多步推理 + 试错探索                               │
│                                                         │
└─────────────────────────────────────────────────────────┘

核心原则

  1. Mid 阶段:从零学习对话 → 需要完整的训练信号
  2. SFT/RL 阶段:优化已有能力 → 只需要针对性的训练信号

类比

  • Mid Training 像学开车:需要学习方向盘、油门、刹车、换挡…(所有操作)
  • SFT Training 像参加驾考:只需要优化驾驶技巧,不需要重新学基本操作

9.8 代码位置参考

# Mask 的生成:nanochat/tokenizer.py 第 258-358 行
def render_conversation(self, conversation):
    """返回 ids 和 mask"""
    ids, mask = [], []
    # ... 详细的 mask 生成逻辑

# Mid Training 丢弃 mask:scripts/mid_train.py 第 133 行
ids, _ = tokenizer.render_conversation(conversation)
#      ↑ 下划线表示丢弃返回值

# SFT Training 使用 mask:scripts/chat_sft.py 第 123 行
ids, mask = tokenizer.render_conversation(doc)
# 后续在 collate_and_yield 中将 mask=0 的位置设为 -1

总结

Mid Training 的三个关键作用:

  1. 能力扩展 - 从通用语言理解到特定任务能力
  2. 数据桥接 - 从无结构文本到结构化对话
  3. 技能注入 - 工具使用、推理、拼写等

关键设计原则(新增):

  • 📊 大规模混合(850K 对话)
  • 📊 确定性 shuffle(seed=42)
  • 📊 渐进式学习率调度
  • 📊 高质量数据混合
  • 📊 训练所有 token(不使用 mask) ← 核心差异!

Mask 的使用时机:

  • ❌ Base Training:无对话格式,不需要 mask
  • ❌ Mid Training:学习对话流,不使用 mask
  • ✅ SFT Training:精细调整,使用 mask
  • ✅ RL Training:强化学习,使用 mask

训练流程:

Base Model (通用能力)
    ↓
+ Structured Tasks (结构化任务)
    ↓ [训练所有 token,学习对话流]
Mid Model (扩展能力)
    ↓
准备好进行 SFT [只训练 assistant 回复]

下一步:

  • SFT:教模型遵循指令(使用 mask)
  • RL:通过强化学习优化行为(使用 mask)

本文档基于 nanochat 项目分析生成
适合 LLM 初学者理解 Mid Training 的完整流程
更新时间: 2025年12月22日

标签: