【KD】Generalized Knowledge Distillation和On-Policy Distillation
传统 KD:教师模型生成一些训练输出,生模型模仿这些输出。GKD(广义 KD):学生模型先自己生成一些序列,然后用教师模型对这些学生生成的序列进行打分或提供反馈,学生模型基于这些反馈进一步调整自己。这可以更好解决“训练时只学教师输出,而测试时要靠自己生成输出”之间的分布不一致问题。GKD = 用 teacher 的“软分布”监督 student,但 teacher 的数据来源可以混合,一部分来自真
note
- 传统 KD:教师模型生成一些训练输出,生模型模仿这些输出。
- GKD(广义 KD):
- 学生模型先自己生成一些序列,然后用教师模型对这些学生生成的序列进行打分或提供反馈,学生模型基于这些反馈进一步调整自己。
- 这可以更好解决“训练时只学教师输出,而测试时要靠自己生成输出”之间的分布不一致问题。
- GKD = 用 teacher 的“软分布”监督 student,但 teacher 的数据来源可以混合,一部分来自真实数据,一部分来自student自己采样的数据。
- 为什么 RL 探索成本高,而蒸馏可以避免?
- 强化学习(RL) 之所以训练成本高,是因为它需要在庞大的"策略空间"中通过大量随机尝试和试错来摸索有效策略,这个过程消耗了绝大多数计算资源。
- On-Policy 蒸馏 则完全不同——教师模型已经掌握了正确的策略,能够直接向学生模型展示"标准答案",让学生无需盲目探索,只需模仿学习即可。
文章目录
一、Generalized Knowledge Distillation
GKD(Generalized Knowledge Distillation,广义知识蒸馏)训练算法由论文 On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes 提出。该算法通过结合离线(off-policy)和在线(on-policy)学习策略,将教师模型的知识迁移到学生模型中。
1、损失函数

- 给定techer模型、student模型、SFT数据
- 超参设置:λ比例、Divergence D、学习率
- 走老师路线(u < λ):从输入X中抽一个prompt,学生先自己生成一个答案,组成batch B
- 走真实数据路线(u ≥ λ):就是普通的SFT
- 核心:θ ← θ − η * ∇θ D(π_T || π_S^θ)(y|x)
- 比较老师和学生在 (x, y) 上的输出分布
当给定输入序列 x x x 与输出序列 y y y,GKD 的损失函数可以写为:
L GKD ( x , y ) = ∑ t = 1 ∣ y ∣ D ( P teacher ( ⋅ ∣ x , y < t ) , P student ( ⋅ ∣ x , y < t ) ) \mathcal{L}_{\text{GKD}}(x, y) = \sum_{t=1}^{|y|} D(P_{\text{teacher}}(\cdot | x, y_{<t}), P_{\text{student}}(\cdot | x, y_{<t})) LGKD(x,y)=t=1∑∣y∣D(Pteacher(⋅∣x,y<t),Pstudent(⋅∣x,y<t))
其中:
- y < t = ( y 1 , y 2 , … , y t − 1 ) y_{<t} = (y_1, y_2, \ldots, y_{t-1}) y<t=(y1,y2,…,yt−1):前 t − 1 t-1 t−1 个 token 的序列
- P teacher ( ⋅ ∣ x , y < t ) P_{\text{teacher}}(\cdot | x, y_{<t}) Pteacher(⋅∣x,y<t):教师模型在给定上下文 x , y < t x, y_{<t} x,y<t 时的输出概率分布
- P student ( ⋅ ∣ x , y < t ) P_{\text{student}}(\cdot | x, y_{<t}) Pstudent(⋅∣x,y<t):学生模型在给定上下文 x , y < t x, y_{<t} x,y<t 时的输出概率分布
- D ( ⋅ , ⋅ ) D(\cdot, \cdot) D(⋅,⋅):散度函数,用于度量两个概率分布之间的差异性
2、散度度量函数
(1)KL 散度(Kullback-Leibler Divergence)
KL 散度是衡量两个概率分布 P P P 和 Q Q Q 之间差异的非对称度量:
KL ( P ∥ Q ) = ∑ v P ( v ) log P ( v ) Q ( v ) = E v ∼ P [ log P ( v ) Q ( v ) ] \text{KL}(P \| Q) = \sum_v P(v) \log \frac{P(v)}{Q(v)} = \mathbb{E}_{v \sim P}\left[\log \frac{P(v)}{Q(v)}\right] KL(P∥Q)=v∑P(v)logQ(v)P(v)=Ev∼P[logQ(v)P(v)]
(2)Forward KL 与 Reverse KL
在知识蒸馏中,根据 KL 散度中两个分布的顺序不同,有两种选择:
Forward KL(前向 KL)
KL ( P student ∥ P teacher ) = ∑ v P student ( v ) log P student ( v ) P teacher ( v ) \text{KL}(P_{\text{student}} \| P_{\text{teacher}}) = \sum_v P_{\text{student}}(v) \log \frac{P_{\text{student}}(v)}{P_{\text{teacher}}(v)} KL(Pstudent∥Pteacher)=v∑Pstudent(v)logPteacher(v)Pstudent(v)
特性:Mode-seeking(寻模)
- 期望在学生分布下计算
- 学生模型倾向于集中在教师模型的峰值区域(高概率区域)
Reverse KL(反向 KL)
KL ( P teacher ∥ P student ) = ∑ v P teacher ( v ) log P teacher ( v ) P student ( v ) \text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_v P_{\text{teacher}}(v) \log \frac{P_{\text{teacher}}(v)}{P_{\text{student}}(v)} KL(Pteacher∥Pstudent)=v∑Pteacher(v)logPstudent(v)Pteacher(v)
特性:Mode-covering(覆模)
- 期望在教师分布下计算
- 学生模型倾向于覆盖教师的整个分布(包括低概率区域)
(3)广义 Jensen-Shannon 散度(Generalized JSD)
GKD 使用广义 JSD 作为核心度量,通过参数 β ∈ [ 0 , 1 ] \beta \in [0, 1] β∈[0,1] 在 Forward KL 和 Reverse KL 之间进行平滑插值。
对于两个概率分布 P P P 和 Q Q Q,广义 JSD 定义为:
D JSD ( β ) ( P , Q ) = β ⋅ KL ( P ∥ M ) + ( 1 − β ) ⋅ KL ( Q ∥ M ) D_{\text{JSD}(\beta)}(P, Q) = \beta \cdot \text{KL}(P \| M) + (1-\beta) \cdot \text{KL}(Q \| M) DJSD(β)(P,Q)=β⋅KL(P∥M)+(1−β)⋅KL(Q∥M)
其中混合分布 M M M 定义为:
M = β ⋅ P + ( 1 − β ) ⋅ Q M = \beta \cdot P + (1-\beta) \cdot Q M=β⋅P+(1−β)⋅Q
- 当 β = 0.5 \beta = 0.5 β=0.5 时,退化为标准的对称 JSD
- 通过调节 β \beta β,可以在 Mode-seeking 和 Mode-covering 之间权衡
在 GKD 中,我们令 P = P teacher P = P_{\text{teacher}} P=Pteacher, Q = P student Q = P_{\text{student}} Q=Pstudent,因此:
D JSD ( β ) ( P teacher , P student ) = β ⋅ KL ( P teacher ∥ M ) + ( 1 − β ) ⋅ KL ( P student ∥ M ) D_{\text{JSD}(\beta)}(P_{\text{teacher}}, P_{\text{student}}) = \beta \cdot \text{KL}(P_{\text{teacher}} \| M) + (1-\beta) \cdot \text{KL}(P_{\text{student}} \| M) DJSD(β)(Pteacher,Pstudent)=β⋅KL(Pteacher∥M)+(1−β)⋅KL(Pstudent∥M)
其中 M = β ⋅ P teacher + ( 1 − β ) ⋅ P student M = \beta \cdot P_{\text{teacher}} + (1-\beta) \cdot P_{\text{student}} M=β⋅Pteacher+(1−β)⋅Pstudent
对极端情况( β = 0 \beta = 0 β=0 或 β = 1 \beta = 1 β=1),直接计算单个 KL 散度:
- 当 β = 0 \beta = 0 β=0 时:直接定义 D = KL ( P teacher ∥ P student ) D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}}) D=KL(Pteacher∥Pstudent)(Reverse KL,Mode-covering)
- 当 β = 1 \beta = 1 β=1 时:直接定义 D = KL ( P student ∥ P teacher ) D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}}) D=KL(Pstudent∥Pteacher)(Forward KL,Mode-seeking)
- 当 0 < β < 1 0 < \beta < 1 0<β<1 时:使用上述混合分布公式进行插值
通过调节 β \beta β 参数,可以在不同的散度度量之间进行插值,当 β = 0.5 \beta = 0.5 β=0.5 时,散度为标准的对称 JSD。
3、三种训练模式
GKD训练具有三种训练模式,区别在于输出序列 y y y 的来源。
模式选择逻辑
训练时,每个样本按照以下优先级选择模式:
# 伪代码:模式选择逻辑
if random() < lmbda:
# Mode 1: On-Policy 学习,由学生模型采样输出序列
y = student.generate(x)
source = "student"
elif seq_kd:
# Mode 2: Sequential KD,由教师模型采样输出序列
y = teacher.generate(x)
source = "teacher"
else:
# Mode 3: Off-Policy 学习,使用数据集中的输出序列
y = y_ground_truth
source = "dataset"
# 相同的损失函数
loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y))
Mode 1: On-Policy 学习
设置参数lambda, 以概率 λ \lambda λ 触发,使用学生模型采样 y ∼ P student ( ⋅ ∣ x ) y \sim P_{\text{student}}(\cdot | x) y∼Pstudent(⋅∣x)
- 学生模型从自己生成的序列中学习
- 暴露在自己可能犯的错误中,学会自我纠正和错误恢复
- 对齐训练分布与推理分布
- 提升模型的鲁棒性和实际应用表现
适用场景:
- 学生模型已有一定生成能力
- 希望提升模型在真实推理场景下的表现
Mode 2: Sequential KD(seq_kd=True 且未触发 on-policy)
设置参数 seq_kd=True, 当未触发 on-policy 时,使用教师模型采样
数据来源: y ∼ P teacher ( ⋅ ∣ x ) y \sim P_{\text{teacher}}(\cdot | x) y∼Pteacher(⋅∣x)
Mode 3: Off-Policy 学习(其他情况)
数据来源: y = y ∗ ∼ Dataset y = y^* \sim \text{Dataset} y=y∗∼Dataset
- 学生模型从数据集的标注序列中学习
4、参数设置
可以通过设置以下参数进行 GKD 训练:
| 参数 | 类型 | 默认值 | 取值范围 | 说明 |
|---|---|---|---|---|
--teacher_model |
str | 必需 | - | 教师模型路径或模型 ID |
--beta |
float | 0.5 | [0.0, 1.0] | 散度插值系数 • 0.0: Reverse KL (覆模,更多样) • 0.5: JSD (平衡,推荐) • 1.0: Forward KL (寻模,更专注) |
--lmbda |
float | 0.5 | [0.0, 1.0] | On-Policy 学习触发概率 • 0.0: 纯 Off-Policy • 0.5: 混合策略 (推荐) • 1.0: 纯 On-Policy |
--seq_kd |
bool | False | True/False | 是否使用教师生成序列 • False: 非 on-policy 时使用数据集 • True: 非 on-policy 时使用教师生成 |
--temperature |
float | 0.9 | > 0 | 生成采样温度,控制随机性 |
--max_completion_length |
int | 512 | > 0 | 生成时的最大 token 数 |
5、采样加速
在 GKD 训练中,涉及到两种在线采样的情况:
- 学生模型采样(当
lmbda > 0):以 λ \lambda λ 概率触发学生模型采样 - 教师模型采样(当
seq_kd=True):以 1 − λ 1-\lambda 1−λ 概率触发教师模型采样
由于采样过程会显著减慢训练速度,可参考以下两种加速方案:
方案 1:学生模型采样加速
要求:swift >= 3.10.dev
使用 vLLM 作为推理后端来加速学生模型采样,支持两种部署模式,与 GRPO 一致,参考GRPO文档, 相关参数参考GRPO vLLM 参数
注意:vLLM 加速仅适用于学生模型的 on-policy 采样(
lmbda > 0)。教师模型的 sequential KD 采样(seq_kd=True)目前仍使用 PyTorch,建议使用预采样方案。
训练脚本参考这里
# 4 * 45GiB, 10.29s/it
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
MASTER_PORT=29501 \
NPROC_PER_NODE=4 \
swift rlhf \
--rlhf_type gkd \
--model OpenGVLab/InternVL3-2B-Pretrained \
--teacher_model OpenGVLab/InternVL3-8B \
--dataset 'modelscope/coco_2014_caption:validation#2000' \
--load_from_cache_file true \
--split_dataset_ratio 0.01 \
--train_type full \
--seq_kd true \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--learning_rate 1e-5 \
--freeze_vit true \
--gradient_accumulation_steps 1 \
--eval_steps 50 \
--save_steps 50 \
--save_total_limit 2 \
--deepspeed zero2 \
--attn_impl flash_attn \
--logging_steps 5 \
--max_length 4096 \
--max_completion_length 512 \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--dataset_num_proc 4 \
--save_only_model true
方案 2:教师模型预采样
对于教师模型采样(seq_kd=True),推荐使用 预采样 方式:先用教师模型离线生成高质量数据,再进行训练。
步骤 1:使用教师模型生成数据
export teacher_model='OpenGVLab/InternVL3-8B'
NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
swift infer \
--model $teacher_model \
--infer_backend vllm \
--val_dataset 'modelscope/coco_2014_caption:validation#5000' \
--vllm_gpu_memory_utilization 0.9 \
--vllm_max_model_len 8192 \
--max_new_tokens 2048 \
--write_batch_size 1000 \
--result_path teacher_generated_data.jsonl
步骤 2:使用预生成数据训练
swift rlhf \
--rlhf_type gkd \
--model OpenGVLab/InternVL3-2B-Pretrained \
--teacher_model $teacher_model \
--dataset 'teacher_generated_data.jsonl' \
--seq_kd false \
...
训练脚本参考这里
二、On-Policy Distillation
1、训练方法
Off-Policy(离策略):向别人学习
On-Policy(同策略):从自己的经验学习
Thinking Machines Lab 提出的 On-Policy Distillation(同策略蒸馏):保持 On-Policy 的优势(学生在自己的状态中学习),但用 Off-Policy 的方式提供密集反馈(教师对每一步打分),即让学生模型自己做题(on-policy),但请教师模型对每一步都打分(dense reward)。
具体而言:
1、由学生自己生成答案:用学生模型(如 Qwen3-8B)解数学题,生成完整的推理过程
2、教师逐步评分:用强大的教师模型(如 Qwen3-32B)对学生的每个动作(token)进行评分
3、精准学习纠错:学生模型知道自己哪一步对、哪一步错,针对性改进
On-Policy 蒸馏使用反向 KL 散度作为损失函数:
K L ( π θ ∥ π teacher ) = E x ∼ π θ [ log π θ ( x t + 1 ∣ x 1 , . t ) − log π teacher ( x t + 1 ∣ x 1 , . t ) ] \mathrm{KL}\left(\pi_\theta \| \pi_{\text {teacher }}\right)=\mathbb{E}_{x \sim \pi_\theta}\left[\log \pi_\theta\left(x_{t+1} \mid x_{1, . t}\right)-\log \pi_{\text {teacher }}\left(x_{t+1} \mid x_{1, . t}\right)\right] KL(πθ∥πteacher )=Ex∼πθ[logπθ(xt+1∣x1,.t)−logπteacher (xt+1∣x1,.t)]
注意事项:
从学生模型自己的分布中采样(on-policy 关键)
让学生在每个时间步都向教师模型靠拢(提供密集反馈信号)
学生模型学习的是"在我这种情况下,老师会怎么做"
2、实验结果
训练 Qwen3-8B-Base(学生模型)解决数学竞赛题(AIME’24)时,使用 Qwen3-32B 作为教师模型。
第一阶段:使用 OpenThoughts 数据集 https://www.modelscope.cn/datasets/open-thoughts/OpenThoughts3-1.2M 进行全参数 SFT,在训练 40 万条数据后,Qwen3-8B-Base 模型的 AIME’24 得分达到了 60 分。
第二阶段:基于该模型检查点,对比三种训练方法将模型分数提升到 70 分所需的成本:结论:On-Policy 蒸馏相比 RL 只需要约 十分之一的计算量。
实验发现即使只用一道题反复训练,On-Policy 蒸馏也能让学生模型学会解题能力,而不会简单地死记硬背答案。这是因为它学的是"分布"而不是"答案"。
| 方法 | 所需数据量/计算量 |
|---|---|
| SFT(异策略) | 根据训练趋势估算,约需 160 万条数据 |
| RL(强化学习) | 根据 Qwen3 技术报告,需要 17,920 GPU 小时 估算约等同于 200 万条 SFT 数据的训练成本 |
| On-Policy 蒸馏 | 只需要 1,800 GPU 小时 |
3、相关挑战和解决方法
1、与 SFT 使用离线数据不同,On-Policy 蒸馏需要学生模型实时生成数据:将学生模型的在线采样交给 vLLM 处理,可以大幅缓解性能瓶颈。
2、在 On-Policy 蒸馏中,教师模型和学生模型扮演不同角色,资源需求差异显著:
教师模型:参数量大(如 32B),主要用于推理打分,需要更高的显存优化等级(如 Deepspeed ZeRO-3)
学生模型:参数量小(如 8B),需要频繁更新梯度,使用较低的优化等级(如 Deepspeed ZeRO-2)反而训练更快
结论:理想的方案是为教师和学生模型分别配置 DeepSpeed 策略:教师使用高等级显存优化保证能加载,学生使用低等级配置加速训练。
# On-Policy Distillation https://thinkingmachines.ai/blog/on-policy-distillation/
# CUDA_VISIBLE_DEVICES=7 \
# swift rollout \
# --model Qwen/Qwen3-8B-Base \
# --vllm_max_model_len 24192
NPROC_PER_NODE=7 \
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 \
swift rlhf \
--rlhf_type gkd \
--model Qwen/Qwen3-8B-Base \
--teacher_model Qwen/Qwen3-32B \
--train_type full \
--dataset open-thoughts/OpenThoughts3-1.2M#10000 \
--seq_kd false \
--lmbda 1 \
--beta 1 \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 1 \
--save_steps 1000 \
--save_total_limit 2 \
--logging_steps 1 \
--max_length 16000 \
--max_completion_length 8192 \
--output_dir output \
--warmup_ratio 0.05 \
--save_only_model true \
--dataloader_num_workers 64 \
--dataset_num_proc 4 \
--deepspeed zero2 \
--teacher_deepspeed zero3 \
--attn_impl flash_attn \
--use_vllm true \
--vllm_mode server \
--vllm_server_host 127.0.0.1 \
--vllm_server_port 8000
Reference
[1] GKD(Generalized Knowledge Distillation,广义知识蒸馏)
[2] https://swift.readthedocs.io/zh-cn/latest/Instruction/GKD.html
[3] Thinking Machines Lab最新研究结果如何复现?On-Policy Distillation让训练成本直降10倍
[4] https://thinkingmachines.ai/blog/on-policy-distillation/
[5] https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/gkd/full.sh
更多推荐


所有评论(0)