note

  • 传统 KD:教师模型生成一些训练输出,生模型模仿这些输出。
  • GKD(广义 KD):
    • 学生模型先自己生成一些序列,然后用教师模型对这些学生生成的序列进行打分或提供反馈,学生模型基于这些反馈进一步调整自己。
    • 这可以更好解决“训练时只学教师输出,而测试时要靠自己生成输出”之间的分布不一致问题。
  • GKD = 用 teacher 的“软分布”监督 student,但 teacher 的数据来源可以混合,一部分来自真实数据,一部分来自student自己采样的数据。
  • 为什么 RL 探索成本高,而蒸馏可以避免?
    • 强化学习(RL) 之所以训练成本高,是因为它需要在庞大的"策略空间"中通过大量随机尝试和试错来摸索有效策略,这个过程消耗了绝大多数计算资源。
    • On-Policy 蒸馏 则完全不同——教师模型已经掌握了正确的策略,能够直接向学生模型展示"标准答案",让学生无需盲目探索,只需模仿学习即可。

一、Generalized Knowledge Distillation

GKD(Generalized Knowledge Distillation,广义知识蒸馏)训练算法由论文 On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes 提出。该算法通过结合离线(off-policy)和在线(on-policy)学习策略,将教师模型的知识迁移到学生模型中。

1、损失函数

在这里插入图片描述

  • 给定techer模型、student模型、SFT数据
  • 超参设置:λ比例、Divergence D、学习率
  • 走老师路线(u < λ):从输入X中抽一个prompt,学生先自己生成一个答案,组成batch B
  • 走真实数据路线(u ≥ λ):就是普通的SFT
  • 核心:θ ← θ − η * ∇θ D(π_T || π_S^θ)(y|x)
    • 比较老师和学生在 (x, y) 上的输出分布

当给定输入序列 x x x 与输出序列 y y y,GKD 的损失函数可以写为:

L GKD ( x , y ) = ∑ t = 1 ∣ y ∣ D ( P teacher ( ⋅ ∣ x , y < t ) , P student ( ⋅ ∣ x , y < t ) ) \mathcal{L}_{\text{GKD}}(x, y) = \sum_{t=1}^{|y|} D(P_{\text{teacher}}(\cdot | x, y_{<t}), P_{\text{student}}(\cdot | x, y_{<t})) LGKD(x,y)=t=1yD(Pteacher(x,y<t),Pstudent(x,y<t))

其中:

  • y < t = ( y 1 , y 2 , … , y t − 1 ) y_{<t} = (y_1, y_2, \ldots, y_{t-1}) y<t=(y1,y2,,yt1):前 t − 1 t-1 t1 个 token 的序列
  • P teacher ( ⋅ ∣ x , y < t ) P_{\text{teacher}}(\cdot | x, y_{<t}) Pteacher(x,y<t):教师模型在给定上下文 x , y < t x, y_{<t} x,y<t 时的输出概率分布
  • P student ( ⋅ ∣ x , y < t ) P_{\text{student}}(\cdot | x, y_{<t}) Pstudent(x,y<t):学生模型在给定上下文 x , y < t x, y_{<t} x,y<t 时的输出概率分布
  • D ( ⋅ , ⋅ ) D(\cdot, \cdot) D(,):散度函数,用于度量两个概率分布之间的差异性

2、散度度量函数

(1)KL 散度(Kullback-Leibler Divergence)

KL 散度是衡量两个概率分布 P P P Q Q Q 之间差异的非对称度量:

KL ( P ∥ Q ) = ∑ v P ( v ) log ⁡ P ( v ) Q ( v ) = E v ∼ P [ log ⁡ P ( v ) Q ( v ) ] \text{KL}(P \| Q) = \sum_v P(v) \log \frac{P(v)}{Q(v)} = \mathbb{E}_{v \sim P}\left[\log \frac{P(v)}{Q(v)}\right] KL(PQ)=vP(v)logQ(v)P(v)=EvP[logQ(v)P(v)]

(2)Forward KL 与 Reverse KL

在知识蒸馏中,根据 KL 散度中两个分布的顺序不同,有两种选择:

Forward KL(前向 KL)

KL ( P student ∥ P teacher ) = ∑ v P student ( v ) log ⁡ P student ( v ) P teacher ( v ) \text{KL}(P_{\text{student}} \| P_{\text{teacher}}) = \sum_v P_{\text{student}}(v) \log \frac{P_{\text{student}}(v)}{P_{\text{teacher}}(v)} KL(PstudentPteacher)=vPstudent(v)logPteacher(v)Pstudent(v)

特性:Mode-seeking(寻模)

  • 期望在学生分布下计算
  • 学生模型倾向于集中在教师模型的峰值区域(高概率区域)
Reverse KL(反向 KL)

KL ( P teacher ∥ P student ) = ∑ v P teacher ( v ) log ⁡ P teacher ( v ) P student ( v ) \text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_v P_{\text{teacher}}(v) \log \frac{P_{\text{teacher}}(v)}{P_{\text{student}}(v)} KL(PteacherPstudent)=vPteacher(v)logPstudent(v)Pteacher(v)

特性:Mode-covering(覆模)

  • 期望在教师分布下计算
  • 学生模型倾向于覆盖教师的整个分布(包括低概率区域)

(3)广义 Jensen-Shannon 散度(Generalized JSD)

GKD 使用广义 JSD 作为核心度量,通过参数 β ∈ [ 0 , 1 ] \beta \in [0, 1] β[0,1] 在 Forward KL 和 Reverse KL 之间进行平滑插值

对于两个概率分布 P P P Q Q Q,广义 JSD 定义为:

D JSD ( β ) ( P , Q ) = β ⋅ KL ( P ∥ M ) + ( 1 − β ) ⋅ KL ( Q ∥ M ) D_{\text{JSD}(\beta)}(P, Q) = \beta \cdot \text{KL}(P \| M) + (1-\beta) \cdot \text{KL}(Q \| M) DJSD(β)(P,Q)=βKL(PM)+(1β)KL(QM)

其中混合分布 M M M 定义为:

M = β ⋅ P + ( 1 − β ) ⋅ Q M = \beta \cdot P + (1-\beta) \cdot Q M=βP+(1β)Q

  • β = 0.5 \beta = 0.5 β=0.5 时,退化为标准的对称 JSD
  • 通过调节 β \beta β,可以在 Mode-seeking 和 Mode-covering 之间权衡

在 GKD 中,我们令 P = P teacher P = P_{\text{teacher}} P=Pteacher Q = P student Q = P_{\text{student}} Q=Pstudent,因此:

D JSD ( β ) ( P teacher , P student ) = β ⋅ KL ( P teacher ∥ M ) + ( 1 − β ) ⋅ KL ( P student ∥ M ) D_{\text{JSD}(\beta)}(P_{\text{teacher}}, P_{\text{student}}) = \beta \cdot \text{KL}(P_{\text{teacher}} \| M) + (1-\beta) \cdot \text{KL}(P_{\text{student}} \| M) DJSD(β)(Pteacher,Pstudent)=βKL(PteacherM)+(1β)KL(PstudentM)

其中 M = β ⋅ P teacher + ( 1 − β ) ⋅ P student M = \beta \cdot P_{\text{teacher}} + (1-\beta) \cdot P_{\text{student}} M=βPteacher+(1β)Pstudent

对极端情况( β = 0 \beta = 0 β=0 β = 1 \beta = 1 β=1),直接计算单个 KL 散度:

  • β = 0 \beta = 0 β=0 时:直接定义 D = KL ( P teacher ∥ P student ) D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}}) D=KL(PteacherPstudent)(Reverse KL,Mode-covering)
  • β = 1 \beta = 1 β=1 时:直接定义 D = KL ( P student ∥ P teacher ) D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}}) D=KL(PstudentPteacher)(Forward KL,Mode-seeking)
  • 0 < β < 1 0 < \beta < 1 0<β<1 时:使用上述混合分布公式进行插值

通过调节 β \beta β 参数,可以在不同的散度度量之间进行插值,当 β = 0.5 \beta = 0.5 β=0.5 时,散度为标准的对称 JSD。

3、三种训练模式

GKD训练具有三种训练模式,区别在于输出序列 y y y 的来源。

模式选择逻辑

训练时,每个样本按照以下优先级选择模式:

# 伪代码:模式选择逻辑
if random() < lmbda:
    # Mode 1: On-Policy 学习,由学生模型采样输出序列
    y = student.generate(x)
    source = "student"
elif seq_kd:
    # Mode 2: Sequential KD,由教师模型采样输出序列
    y = teacher.generate(x)
    source = "teacher"
else:
    # Mode 3: Off-Policy 学习,使用数据集中的输出序列
    y = y_ground_truth
    source = "dataset"

# 相同的损失函数
loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y))

Mode 1: On-Policy 学习

设置参数lambda, 以概率 λ \lambda λ 触发,使用学生模型采样 y ∼ P student ( ⋅ ∣ x ) y \sim P_{\text{student}}(\cdot | x) yPstudent(x)

  • 学生模型从自己生成的序列中学习
  • 暴露在自己可能犯的错误中,学会自我纠正和错误恢复
  • 对齐训练分布与推理分布
  • 提升模型的鲁棒性和实际应用表现

适用场景

  • 学生模型已有一定生成能力
  • 希望提升模型在真实推理场景下的表现

Mode 2: Sequential KD(seq_kd=True 且未触发 on-policy)

设置参数 seq_kd=True, 当未触发 on-policy 时,使用教师模型采样

数据来源 y ∼ P teacher ( ⋅ ∣ x ) y \sim P_{\text{teacher}}(\cdot | x) yPteacher(x)

Mode 3: Off-Policy 学习(其他情况)

数据来源 y = y ∗ ∼ Dataset y = y^* \sim \text{Dataset} y=yDataset

  • 学生模型从数据集的标注序列中学习

4、参数设置

可以通过设置以下参数进行 GKD 训练:

参数 类型 默认值 取值范围 说明
--teacher_model str 必需 - 教师模型路径或模型 ID
--beta float 0.5 [0.0, 1.0] 散度插值系数
• 0.0: Reverse KL (覆模,更多样)
• 0.5: JSD (平衡,推荐)
• 1.0: Forward KL (寻模,更专注)
--lmbda float 0.5 [0.0, 1.0] On-Policy 学习触发概率
• 0.0: 纯 Off-Policy
• 0.5: 混合策略 (推荐)
• 1.0: 纯 On-Policy
--seq_kd bool False True/False 是否使用教师生成序列
• False: 非 on-policy 时使用数据集
• True: 非 on-policy 时使用教师生成
--temperature float 0.9 > 0 生成采样温度,控制随机性
--max_completion_length int 512 > 0 生成时的最大 token 数

5、采样加速

在 GKD 训练中,涉及到两种在线采样的情况:

  1. 学生模型采样(当 lmbda > 0):以 λ \lambda λ 概率触发学生模型采样
  2. 教师模型采样(当 seq_kd=True):以 1 − λ 1-\lambda 1λ 概率触发教师模型采样

由于采样过程会显著减慢训练速度,可参考以下两种加速方案:

方案 1:学生模型采样加速

要求:swift >= 3.10.dev

使用 vLLM 作为推理后端来加速学生模型采样,支持两种部署模式,与 GRPO 一致,参考GRPO文档, 相关参数参考GRPO vLLM 参数

注意:vLLM 加速仅适用于学生模型的 on-policy 采样(lmbda > 0)。教师模型的 sequential KD 采样(seq_kd=True)目前仍使用 PyTorch,建议使用预采样方案。

训练脚本参考这里

# 4 * 45GiB, 10.29s/it
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
MASTER_PORT=29501 \
NPROC_PER_NODE=4 \
swift rlhf \
    --rlhf_type gkd \
    --model OpenGVLab/InternVL3-2B-Pretrained \
    --teacher_model OpenGVLab/InternVL3-8B \
    --dataset 'modelscope/coco_2014_caption:validation#2000' \
    --load_from_cache_file true \
    --split_dataset_ratio 0.01 \
    --train_type full \
    --seq_kd true \
    --torch_dtype bfloat16 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --learning_rate 1e-5 \
    --freeze_vit true \
    --gradient_accumulation_steps 1 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 2 \
    --deepspeed zero2 \
    --attn_impl flash_attn \
    --logging_steps 5 \
    --max_length 4096 \
    --max_completion_length 512 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --dataset_num_proc 4 \
    --save_only_model true

方案 2:教师模型预采样

对于教师模型采样(seq_kd=True),推荐使用 预采样 方式:先用教师模型离线生成高质量数据,再进行训练。

步骤 1:使用教师模型生成数据

export teacher_model='OpenGVLab/InternVL3-8B'

NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
swift infer \
    --model $teacher_model \
    --infer_backend vllm \
    --val_dataset 'modelscope/coco_2014_caption:validation#5000' \
    --vllm_gpu_memory_utilization 0.9 \
    --vllm_max_model_len 8192 \
    --max_new_tokens 2048 \
    --write_batch_size 1000 \
    --result_path teacher_generated_data.jsonl

步骤 2:使用预生成数据训练

swift rlhf \
    --rlhf_type gkd \
    --model OpenGVLab/InternVL3-2B-Pretrained \
    --teacher_model $teacher_model \
    --dataset 'teacher_generated_data.jsonl' \
    --seq_kd false \
    ...

训练脚本参考这里

二、On-Policy Distillation

1、训练方法

Off-Policy(离策略):向别人学习
On-Policy(同策略):从自己的经验学习

Thinking Machines Lab 提出的 On-Policy Distillation(同策略蒸馏):保持 On-Policy 的优势(学生在自己的状态中学习),但用 Off-Policy 的方式提供密集反馈(教师对每一步打分),即让学生模型自己做题(on-policy),但请教师模型对每一步都打分(dense reward)。

具体而言:
1、由学生自己生成答案:用学生模型(如 Qwen3-8B)解数学题,生成完整的推理过程
2、教师逐步评分:用强大的教师模型(如 Qwen3-32B)对学生的每个动作(token)进行评分
3、精准学习纠错:学生模型知道自己哪一步对、哪一步错,针对性改进

On-Policy 蒸馏使用反向 KL 散度作为损失函数:
K L ( π θ ∥ π teacher  ) = E x ∼ π θ [ log ⁡ π θ ( x t + 1 ∣ x 1 , . t ) − log ⁡ π teacher  ( x t + 1 ∣ x 1 , . t ) ] \mathrm{KL}\left(\pi_\theta \| \pi_{\text {teacher }}\right)=\mathbb{E}_{x \sim \pi_\theta}\left[\log \pi_\theta\left(x_{t+1} \mid x_{1, . t}\right)-\log \pi_{\text {teacher }}\left(x_{t+1} \mid x_{1, . t}\right)\right] KL(πθπteacher )=Exπθ[logπθ(xt+1x1,.t)logπteacher (xt+1x1,.t)]

注意事项:
从学生模型自己的分布中采样(on-policy 关键)
让学生在每个时间步都向教师模型靠拢(提供密集反馈信号)
学生模型学习的是"在我这种情况下,老师会怎么做"

2、实验结果

训练 Qwen3-8B-Base(学生模型)解决数学竞赛题(AIME’24)时,使用 Qwen3-32B 作为教师模型。

第一阶段:使用 OpenThoughts 数据集 https://www.modelscope.cn/datasets/open-thoughts/OpenThoughts3-1.2M 进行全参数 SFT,在训练 40 万条数据后,Qwen3-8B-Base 模型的 AIME’24 得分达到了 60 分。

第二阶段:基于该模型检查点,对比三种训练方法将模型分数提升到 70 分所需的成本:结论:On-Policy 蒸馏相比 RL 只需要约 十分之一的计算量。
实验发现即使只用一道题反复训练,On-Policy 蒸馏也能让学生模型学会解题能力,而不会简单地死记硬背答案。这是因为它学的是"分布"而不是"答案"。

方法 所需数据量/计算量
SFT(异策略) 根据训练趋势估算,约需 160 万条数据
RL(强化学习) 根据 Qwen3 技术报告,需要 17,920 GPU 小时
估算约等同于 200 万条 SFT 数据的训练成本
On-Policy 蒸馏 只需要 1,800 GPU 小时

3、相关挑战和解决方法

1、与 SFT 使用离线数据不同,On-Policy 蒸馏需要学生模型实时生成数据:将学生模型的在线采样交给 vLLM 处理,可以大幅缓解性能瓶颈。

2、在 On-Policy 蒸馏中,教师模型和学生模型扮演不同角色,资源需求差异显著:
教师模型:参数量大(如 32B),主要用于推理打分,需要更高的显存优化等级(如 Deepspeed ZeRO-3)
学生模型:参数量小(如 8B),需要频繁更新梯度,使用较低的优化等级(如 Deepspeed ZeRO-2)反而训练更快
结论:理想的方案是为教师和学生模型分别配置 DeepSpeed 策略:教师使用高等级显存优化保证能加载,学生使用低等级配置加速训练。

# On-Policy Distillation https://thinkingmachines.ai/blog/on-policy-distillation/

# CUDA_VISIBLE_DEVICES=7 \
# swift rollout \
#     --model Qwen/Qwen3-8B-Base \
#     --vllm_max_model_len 24192

NPROC_PER_NODE=7 \
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 \
swift rlhf \
    --rlhf_type gkd \
    --model Qwen/Qwen3-8B-Base \
    --teacher_model Qwen/Qwen3-32B \
    --train_type full \
    --dataset open-thoughts/OpenThoughts3-1.2M#10000 \
    --seq_kd false \
    --lmbda 1 \
    --beta 1 \
    --torch_dtype bfloat16 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --learning_rate 1e-5 \
    --gradient_accumulation_steps 1 \
    --save_steps 1000 \
    --save_total_limit 2 \
    --logging_steps 1 \
    --max_length 16000 \
    --max_completion_length 8192 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --save_only_model true \
    --dataloader_num_workers 64 \
    --dataset_num_proc 4 \
    --deepspeed zero2 \
    --teacher_deepspeed zero3 \
    --attn_impl flash_attn \
    --use_vllm true \
    --vllm_mode server \
    --vllm_server_host 127.0.0.1 \
    --vllm_server_port 8000

Reference

[1] GKD(Generalized Knowledge Distillation,广义知识蒸馏)
[2] https://swift.readthedocs.io/zh-cn/latest/Instruction/GKD.html
[3] Thinking Machines Lab最新研究结果如何复现?On-Policy Distillation让训练成本直降10倍
[4] https://thinkingmachines.ai/blog/on-policy-distillation/
[5] https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd/full.sh

Logo

助力广东及东莞地区开发者,代码托管、在线学习与竞赛、技术交流与分享、资源共享、职业发展,成为松山湖开发者首选的工作与学习平台

更多推荐