Self-Forcing 与 Self-Forcing++ 演进之路

5 分钟阅读时长

发布时间:

self-forcing

self-forcing++

Self-Forcing & Self-Forcing++:从训练-推理鸿沟到分钟级实时流式视频生成(技术拆解)

核心关键词:Autoregressive Video Diffusion、Exposure Bias、KV Cache、Rolling KV、Few-step Diffusion、Distribution Matching Distillation、Minute-scale Generation、FPS / Latency


1. 背景:为什么“自回归视频扩散”会在长视频上崩?

自回归视频生成通常写成:

\[p_\theta(x_{1:N})=\prod_{i=1}^{N} p_\theta(x_i \mid x_{<i})\]

训练-推理鸿沟(train-test gap / exposure bias)来自于:

  • 训练阶段常用 Teacher Forcing (TF):条件上下文用真实历史帧 \(x_{<i}\);
  • 推理阶段必须用 模型自己生成的历史帧 \(\hat x_{<i}\) 当条件。

结果:误差会随时间累积,越往后越偏离“训练时见过的条件分布”,最终出现:

  • 画面冻结、漂移、闪烁
  • 语义/角色一致性断裂
  • 长程结构坍塌(超过训练片段长度尤其明显)

Self-Forcing 系列的中心思想是:训练时就让模型以“真实推理方式”rollout,并在 rollout 分布上做整体分布对齐,从根上缩小 train-test gap。


2. 统一符号与张量维度(工程实现对照)

为了能落到代码层面,先把维度讲清楚(论文通常不写死具体值,用符号化更通用)。

2.1 基础维度

  • batch:\(B\)
  • 视频长度(训练/推理生成的 latent 帧数):\(N\)
  • 每帧 RGB 分辨率:\(H \times W\)
  • VAE latent 分辨率:\(H_\ell \times W_\ell\)
  • latent 通道数:\(C_\ell\)
  • Transformer hidden dim:\(D\)
  • attention head 数:\(n_h\),每头维度:\(d_h\),满足 \(D=n_h\cdot d_h\)

2.2 张量表示

  • 像素空间视频:\(x^{pix}_{1:N} \in \mathbb{R}^{B\times N\times H\times W\times 3}\)
  • latent 空间视频:\(x_{1:N} \in \mathbb{R}^{B\times N\times C_\ell\times H_\ell\times W_\ell}\)

把每帧 latent patchify / flatten 成 token:

  • 每帧 token 数:\(P=H_\ell\cdot W_\ell\)
  • 单帧 token:\(\text{tok}(x_i) \in \mathbb{R}^{B\times P\times D}\)
  • 历史上下文 token 长度:\(S_{<i}=(i-1)\cdot P\)

2.3 KV Cache 形状(典型实现)

对第 \(l\) 层:

  • \(K^{(l)},V^{(l)} \in \mathbb{R}^{B\times n_h\times S\times d_h}\)
    其中 \(S\) 是当前 cache 中累计 token 数(rolling KV 会限制 \(S\le L\cdot P\))。

3. Base:Self-Forcing(核心机制 = 推理式 rollout + 视频级分布对齐)

Self-Forcing 的两个关键点:

1) 训练时真的按推理方式做自回归 rollout
每一步生成当前帧时,条件上下文来自模型自己的历史输出。

2) 损失从帧级 denoising(MSE)升级为视频级分布匹配(holistic distribution matching)
直接对齐“模型会生成的视频分布”与“真实数据分布”。


4. TF / DF:传统训练方式的数据流与损失(对比基线)

4.1 扩散前向加噪(每帧)

\(x_t = \alpha_t x + \sigma_t \epsilon,\quad \epsilon \sim \mathcal{N}(0,I)\)

维度不变:\(x_t,\epsilon \in \mathbb{R}^{B\times C_\ell\times H_\ell\times W_\ell}\)。

4.2 Teacher Forcing (TF) 与 Diffusion Forcing (DF)

  • TF:条件上下文 \(c = x_{<i}\)(干净 GT 历史帧)
  • DF:条件上下文 \(c = \{x^j_{t_j}\}_{j<i}\)(带噪历史帧)

4.3 帧级 denoising 损失(典型)

模型预测噪声 \(\hat\epsilon_\theta\),损失:

\[\mathcal{L}^{DM}_\theta=\mathbb{E}_{x,t,\epsilon}\left[w_t\left\|\hat\epsilon_\theta-\epsilon\right\|_2^2\right]\]

痛点:即便 DF 用带噪上下文,也仍然不等价于推理时“模型自生成历史帧”分布;训练-推理分布仍然错位。


5. Self-Forcing 训练:推理式 rollout + 可训练的资源控制

5.1 训练核心:从模型分布 rollout

训练时生成整段:

\[\{x^\theta_{1:N}\}\sim p_\theta(x_{1:N})\]

即每一帧(或 chunk)的条件来自前面模型输出,并用 KV cache 累积上下文。

5.2 few-step diffusion:把每帧扩散步数压到很少

对每帧生成不走长链路(几十/上百步),而走少步时间网格:

\[0=t_0<t_1<\dots<t_T=1\]

其中 \(T\) 通常很小(工程上可做到 4 steps)。

每帧内部 denoise loop(概念):

  1. 初始化 \(x_{t_T}\sim\mathcal{N}(0,I)\)
  2. for \(j=T..1\):
    • \[\hat x_0 \leftarrow G_\theta(x_{t_j}, t_j, KV)\]
    • 若 \(j>1\):重加噪得到 \(x_{t_{j-1}}\)

5.3 训练能跑起来:梯度截断 + 随机截断步

直接对“AR rollout × diffusion steps”全图反传会爆显存/爆算力。

常用做法(自回归 + few-step 场景很有效):

  • 随机采样一个截断步 \(s\in\{1..T\}\);
  • 只在 \(j=s\) 的那一次前向开启梯度,其余 denoise steps 用 no_grad
  • 同时对历史帧的 KV cache 做 stop_grad,避免跨帧反传。

工程直觉:让每个时间步都有机会被监督,但每次只训练一小段图


6. Self-Forcing 的“视频级损失”:三种可选目标(DMD / SiD / GAN)

Self-Forcing 的关键升级:不再只优化逐帧 MSE,而是直接对齐整体视频分布。

典型做法:对真实视频与生成视频都加同一层 forward diffusion 噪声,比较噪声域分布 \(p_{*,t}\)。

6.1 DMD:反向 KL(Reverse KL)

目标形式:

\[\mathbb{E}_t\left[D_{KL}\left(p_{\theta,t}\,\|\,p_{data,t}\right)\right]\]

实现上常通过 teacher / score 估计,把“分布对齐梯度”转为可训练的方向信号。

6.2 SiD:Fisher divergence(score matching 风格)

\(\mathbb{E}_{t,\,x\sim p_{\theta,t}}\left[\left\|\nabla\log p_{\theta,t}(x)-\nabla\log p_{data,t}(x)\right\|_2^2\right]\)

6.3 GAN:JS 近似的对抗分布对齐

训练判别器区分真实/生成视频,生成器(rollout 的 AR diffusion 模型)通过对抗学习逼近数据分布。

三者本质:都在“推理分布上的样本”层面做 holistic distribution alignment,而不是只做帧级 denoising。


7. 推理:如何做到“实时流式”?

Self-Forcing 的推理阶段就是:自回归 + few-step + KV cache

7.1 推理数据流(frame-wise / chunk-wise)

  • 输入:prompt conditioning(文本 embedding 等)
  • 初始化 KV cache 为空
  • 对 i=1..N:
    • 初始化当前帧/当前 chunk 的噪声 latent:\(x_{t_T}\)
    • 运行 \(T\) 次 denoise(T 很小,如 4)
    • 得到 \(\hat x_0\)(当前输出)
    • 计算当前输出 token 的 K/V 并 append 到 cache
    • 立即解码输出(可流式边生成边播放)

流式关键:每生成一帧/一段,就能立刻输出,不需要等完整视频生成结束。


8. 为什么可以跑到 FPS≈17:三个增益叠加

把“视频扩散很慢”的主要开销拆成两类: 1) 扩散步数(网络前向次数) 2) 注意力计算(上下文长、滑窗重算等)

Self-Forcing 的 17 FPS 主要来自:

8.1 Few-step diffusion:把网络前向次数从“几十/上百”降到“4”

扩散推理时间大体随步数线性增长。把 \(T\) 压到 4,吞吐直接提升一个数量级。

8.2 因果注意力 + KV cache:历史不重算

每次只需要为“新 token”计算 Q,并与 cache 中的 K/V 做 attention;历史帧的 K/V 不再重复计算。

8.3 Chunk-wise:一次生成多个 latent 帧,摊薄固定开销

如果每次生成的是一个 chunk(例如 3 个 latent frames),那么:

  • 调度器开销、cache 管理开销、部分 embedding 开销被摊薄;
  • 吞吐更高、首帧延迟仍可控。

9. Base 的长视频问题:Rolling KV 的训练/推理不一致

即使推理用 rolling KV 来保持长视频复杂度稳定,如果训练阶段使用的是不同的 cache 策略(例如固定窗口或不同的 cache 行为),模型依然会遇到 cache-state mismatch:推理时遇到的“滚动上下文状态”在训练时没见过,长视频仍可能闪烁或退化。

这正是升级版 Self-Forcing++ 要解决的核心之一。


10. 升级版:Self-Forcing++(分钟级高质量长视频的关键改动)

Self-Forcing++ 的目标:把“只会 5 秒”升级为“分钟级”(100s 甚至更长)仍保持质量与一致性。

核心改动可以概括为三件套: 1) 长 rollout(N ≫ K) 2) Extended-DMD:从长 rollout 中随机抽窗做分布蒸馏 3) 训练阶段也使用 rolling KV cache 并辅以一个可选增强: 4) GRPO + optical-flow reward(长期平滑性)


11. Self-Forcing++ 训练数据流(含维度)

11.1 关键长度标量

  • \(N\):student rollout 总长度(很长,分钟级)
  • \(K\):teacher 能覆盖的窗口长度(短,通常对应训练上限,比如 ~5s 的 latent 长度)
  • \(L\):rolling KV 的 cache 长度上限(以帧或 token 计)
  • chunk size:一次生成 \(m\) 个 latent frames(如 m=3)

11.2 Step-by-step(核心流程)

Step A:student 长 rollout(使用 rolling KV)

生成:

\[V \in \mathbb{R}^{B\times N\times C_\ell\times H_\ell\times W_\ell}\]

并且 KV cache 在整个 rollout 中滚动维护,保持推理一致的 cache-state 分布。

Step B:从长轨迹中均匀采样一个连续窗口

采样起点 \(i\sim\{1,\dots,N-K+1\}\):

\[W = V[i:i+K-1] \in \mathbb{R}^{B\times K\times C_\ell\times H_\ell\times W_\ell}\]

Step C:Backward Noise Initialization(后向噪声初始化)

直觉:窗口蒸馏不能从“纯随机噪声”起步,否则和长上下文解耦。
做法:对窗口 \(W\) 重新注入噪声,构造与扩散时间步一致的输入 \(x_t(W)\):

\[x_t(W)\in \mathbb{R}^{B\times K\times C_\ell\times H_\ell\times W_\ell}\]

这样 teacher / student 在同一“长上下文一致”的带噪窗口上对齐。

Step D:Teacher / Student 在窗口上输出 score / noise 预测

  • teacher 输出:\(s^T(x_t(W),t)\)
  • student 输出:\(s^S_\theta(x_t(W),t)\)
    二者维度同 \(x_t(W)\)。

Step E:Extended-DMD 更新

在随机窗口上做分布匹配蒸馏(本质是把 teacher 的“局部纠错能力”教给 student,使其能从长 rollout 的退化态恢复)。


12. Self-Forcing++ 损失函数:Extended-DMD +(可选)GRPO

12.1 Extended-DMD(核心监督)

在随机窗口上最小化 student 与 teacher 的分布差异(可写成 KL 形式的目标),并用 DMD 的梯度估计实现可训练更新。

关键点:监督覆盖的是“长 rollout 中任意位置的短窗口”,因此 student 会学到:

  • 错误累积态下的恢复能力
  • 分钟级生成中局部片段仍像“真实短视频分布”

12.2 GRPO + Optical-flow reward(可选增强)

分钟级生成仍可能出现长期一致性衰减(物体突然出现/消失、突兀切场)。
引入 RL 微调,用光流相关的 reward 作为时间平滑 proxy,进一步优化长程稳定性。


13. Base vs ++:对比总结(关键差异点)

维度Self-Forcing (base)Self-Forcing++
训练覆盖时间短(teacher 覆盖长度内)长 rollout 到 \(N\gg K\),从长序列抽窗蒸馏
蒸馏/对齐区域生成段整体(但主要是短域)任意位置窗口(覆盖退化态)
噪声初始化常规扩散起点Backward Noise Init 保持窗口与长上下文一致
KV 策略推理 rolling,但训练可能不同训练与推理都 rolling KV(消 cache-state mismatch)
长视频退化处理依赖 base 对齐能力显式学习“从退化态纠错恢复” + 可选 RL 平滑

一句话:
Self-Forcing++ = “先让学生在真实推理形态下长跑出错,再用 teacher 在短窗口里纠错式蒸馏,把恢复能力学回去”,并让训练/推理 cache 行为完全一致。


14. 工程实现要点(能跑到流式实时的关键细节)

14.1 推理循环(伪代码)

KV = init_empty_cache()

for i in range(num_steps_over_time):  # frame-wise or chunk-wise
    x_t = randn_latent()  # [B, Cℓ, Hℓ, Wℓ] or [B, m, Cℓ, Hℓ, Wℓ]
    for j in reversed(range(T)):      # T is small, e.g. 4
        x0_hat = model(x_t, t_j, KV)  # predict x0 (or eps)
        if j > 0:
            x_t = add_noise(x0_hat, t_{j-1})  # one-step scheduler
    out = x0_hat
    KV.append(compute_kv(out))
    KV = rolling_evict_if_needed(KV, L)
    stream_decode_and_send(out)  # optional: decode immediately

14.2 训练的资源控制(伪代码骨架)

    # rollout phase (no_grad for most)
    KV = init_empty_cache()
    V = []
    for i in range(N):
        out = autoregressive_generate_one_step(KV, T=few_steps, grad=False)
        V.append(out)
        KV.append(compute_kv(out))
        KV = rolling_evict_if_needed(KV, L)

    # sample window
    i0 = uniform(0, N-K)
    W = stack(V[i0:i0+K])

    # backward noise init
    t = sample_timestep()
    Wt = add_noise(W, t)

    # distillation update (enable grad only here)
    loss = extended_dmd_loss(student(Wt,t), teacher(Wt,t))
    loss.backward()
    opt.step()

14.3 为什么 rolling KV 能让长视频不降速?

不 rolling:上下文长度随时间增长,注意力开销越来越大;

rolling:上下文 token 长度上限固定为 \(L\cdot P\),每步只做“固定上限”的 attention,长视频复杂度稳定。

15. 结语:一条清晰的路线图

要实时流式:few-step(如 4 steps) + causal attention KV cache +(可选)chunk-wise

要长视频不崩:rolling KV + 在训练中也使用 rolling KV(避免 cache-state mismatch)

要分钟级高质量:长 rollout + 随机抽窗 Extended-DMD + backward noise init(保持上下文一致)

要更平滑:在分钟级后训练阶段加上 GRPO / flow reward 等时间一致性奖励

标签: