Bili-Docs
技术工具AI 应用

耗时两天半,完全从零开始实现大模型知识蒸馏(Qwen2.5系列模型),从原理讲解、代码实现到效果测试,绝对让你搞懂模型蒸馏

视频深入讲解了大模型知识蒸馏的原理,并基于 Qwen2.5 系列模型演示了从代码实现到不同 KL 散度算法效果测试的全过程。

UP主: 偷星九月333 · 时长: 28:14 · 🔗 B站原视频

发布: 2025-01-17 · 收录: 2025-01-18

标签: 大模型 · 知识蒸馏 · Qwen2.5 · KL散度 · 深度学习

开场与主题:大模型知识蒸馏要讲清楚

Hello,大家好,我是九月。今天给大家分享一下大模型的知识蒸馏相关的一些知识。我们先来看一下它的大致原理,还有一些方法。

知识蒸馏的基本原理:教师模型与学生模型对齐

知识蒸馏就是有一个教师模型,还有一个学生模型。我们需要把教师模型的知识压缩到学生模型里面,然后让学生模型在教师模型的监督下不断优化。

这个优化过程,本质上就是对齐学生模型和教师模型输出的概率分布。这个概率分布的对齐通常通过 KL 散度来控制。它和传统 CV、NLP 里的知识蒸馏可能有一点区别。

大模型蒸馏的两种路线:黑盒与白盒

大模型的知识蒸馏主要有两种:

第一种是黑盒知识蒸馏。很多人都用过,就是用大模型生成一些数据,然后用这些数据去微调更小的模型,通过这种方式达到蒸馏目的。实现简单,但蒸馏效率不高。

第二种是白盒知识蒸馏,也是今天要讲的:对齐学生模型和教师模型输出的概率分布,也可以对齐中间层、隐藏层的分布。对齐主要依赖 KL 散度。

KL 散度的几种实现:前向、反向与偏向版本

KL 散度有几种不同实现。经常说的 KL 散度一般指前向 KL 散度,其他很多都是基于它的变体。

前向 KL 散度的问题:可能高估低概率区域

MiniLM 论文里提到,前向 KL 散度可能会让学生模型高估教师模型中概率比较低的部分。

从公式角度看:当 (p(x)) 增大时,要让整体最小,(q(x)) 也需要增大,这个没问题。但当 (p(x)) 趋于 0 时,整体大小主要受 (p(x)) 影响,几乎不管 (q(x)) 取什么值,这部分 KL 都会很小,导致对 (q) 的优化不充分,可能让学生模型去高估教师模型中低概率的位置。

从拟合图的直观感觉上,前向 KL 更倾向拟合多峰。比如 (p(x)) 是双峰时,(q(x)) 可能会拟合成两个峰的综合,有点像平均的样子。

反向 KL 散度:更偏向拟合单峰

MiniLM 提出了反向 KL 散度。它和前向的区别就是把公式里的 (p) 和 (q) 互换。

当 (p \to 0) 时,为了使 KL 小,(q) 也需要趋近 0。论文里说在大模型蒸馏场景下,反向 KL 优于前向 KL。

但也有其他论文的说法是,反向 KL 不一定比前向 KL 更优。我自己用数据集做蒸馏的结论是:前向 KL 比反向 KL 效果更好。所以这个可能和数据集有关,属于实验驱动的结论。

反向 KL 更倾向于拟合单个峰。

偏向前向/反向的 KL:加权综合

还有偏向前向 KL、偏向反向 KL,本质是把前向 KL 和反向 KL 做加权综合。

一种是把学生分布和教师分布加权后作为“学生分布”来算,另一种是加权后作为“教师分布”来算。实验里我得到的结论是效果也不好,可能也跟我数据集有关。

实验设置:教师/学生模型与数据量

下面讲测试过程,再看代码实现。

  • 教师模型:Qwen2.5 3B
  • 学生模型:Qwen2.5 0.5B
  • 我先在指定数据集上对教师模型微调:训练 500 条数据
  • 测试数据:1000 条
  • 教师模型微调后准确度:81.1%

教师微调完就开始做知识蒸馏。

三种蒸馏方案对比:只 KL、先微调学生、KL+CE

我主要探索了三种方案:

方案 1:不微调学生,只用前向 KL 对齐分布

不微调学生模型,只通过 KL 散度损失对齐学生模型和教师模型(KL 用前向 KL),蒸馏 2 个 epoch,准确度 73%。

这个效果比“直接微调学生模型”的效果还差一点,但还算可以。

方案 2:先微调学生,再加 KL 蒸馏(结果更差)

如果我直接在这个数据集上微调学生模型,不做蒸馏,准确度能到 80.3%。

但我拿这个微调后的学生模型,再加 KL 散度做蒸馏 2 个 epoch,准确度反而更低,和想象不太一样。

我原来以为教师模型已经在训练集上拟合过,分布比较接近数据分布;学生模型也微调过,分布也应该接近数据分布。这时候再用 KL 去对齐微调后的学生与微调后的教师,至少不应该下降这么多。

但我看到网上也有人做实验得出类似结论:对大模型来说,先微调学生再做知识蒸馏,效果可能更差。

方案 3:不微调学生,KL + 交叉熵(50%/50%)

第三种方案是不微调学生,损失由两部分组成:

  • KL 散度损失(学生分布 vs 教师分布)占 50%
  • 交叉熵损失(学生输出 vs 真实标签)占 50%

同样蒸馏 2 个 epoch,准确度 70.5%。在我这个数据集上,它没有“只用 KL”效果好。

所以我的结论是:在这个实验里,只使用 KL 散度效果最好。

KL 变种效果测试:反向 KL 很差,偏向版本也不行

在确认“只用 KL”最好后,我又测试了 KL 的变种,仍然只用 KL 作为损失:

  • 反向 KL:准确度只有 54%,降得很低
  • 偏向前向 KL:效果也很差,模型会不断重复输出,感觉把模型调坏了
  • 前向 KL:损失下降更正常。我之前蒸馏 1 个 epoch 准确度是 70.5%,再继续蒸馏 1 个 epoch 后准确度有所提升

损失曲线方面:

  • 反向 KL:前面阶段损失震荡,可能和初始学习率偏大有关,后面虽收敛但效果不好
  • 偏向前向 KL:收敛存在问题,损失下降异常
  • 前向 KL:下降相对正常

这里我没有针对不同 KL 变种去调超参,因为训练一次耗时不短,所以超参全部保持一致。反向 KL 和偏向前向 KL 可能需要重新调超参才能得到更正常的效果。

代码实现思路:核心就是改 Trainer 的损失函数

原理和实验大致就这些,下面看代码实现。实现并不难,主要就是改损失函数。

我这次只保留了蒸馏 2 个 epoch 的结果,因为后面在这个基础上继续训练,效果反而更差。

KL 散度计算:对照公式实现前向 KL

实现用 Transformers 库很简单,需要加载教师模型(LoRA 微调后)和学生模型。

以“前向 KL”为例:

  • (p(x)):教师模型概率分布(logits)
  • (q(x)):学生模型概率分布(logits)
  • 还会除以温度系数(知识蒸馏里很常见)

实现上会用 log_softmax 得到 log 概率,然后组合成 KL 的形式。最后要对最后一个维度求和,相当于做积分/求和。

另外还要做 mask:训练时只关注模型生成那部分的分布,输入部分和 padding 部分不需要,所以要把输入部分以及 padding 的 token mask 掉,只对输出部分计算 KL。

反向 KL、偏向前向这些实现和前向 KL 差不多,就不展开了。

自定义 Trainer:教师模型不更新,只做前向推理

训练代码里我继承 Transformers 的 Trainer。

需要传入:

  • 学生模型
  • 教师模型
  • 一个参数控制是否加入交叉熵损失

核心是重写 compute_loss

  1. 输入数据先跑学生模型,得到输出
  2. 教师模型不参与更新,用 torch.no_grad(),不反向传播
  3. 学生模型交叉熵损失:学生输出 vs 真实标签
  4. 拿到学生 logits 和教师 logits,计算 KL 损失
  5. 如果需要 KL + 交叉熵,就加权求和;否则直接用 KL 作为最终 loss 返回

教师/学生输出维度不一致的问题:padding 或截断

我一开始用 Qwen2.5 3B 蒸馏 0.5B,两者输出维度一样,不需要处理。

但如果用 Qwen2.5 7B 去蒸馏 0.5B,或者用 14B、32B 去蒸馏 0.5B/3B,它们 logits 维度可能不一样,需要处理。

两种方法:

  1. 对学生 logits 做 padding,填充到和教师一样的形状
  2. 对教师 logits 做截断,截断到和学生一样的形状

padding 用 -100,是因为 Transformers 里默认 -100 不计算损失,这里也沿用这个逻辑。

LoRA 加载:学生加 LoRA,教师加载微调后的 LoRA 权重

学生模型用 LoRA 微调方式:加载 LoRA 配置,然后得到加入 LoRA 的模型。

教师模型是微调后的,我没有把 LoRA 合并到 base model,而是直接用 LoRA 加载方式,把 LoRA 权重和原始权重一起加载进模型。

后面就是训练参数、数据加载,开始蒸馏。

数据集格式与 label 处理:只算答案部分,prompt 用 -100 mask

数据集处理是标准 SFT 格式:

输入由 prompt 和答案组成。标签只需要计算模型生成那部分的损失,prompt 那部分不需要算损失,所以用 -100 把这部分填成 -100。

batch 里需要长度一致:

  • 超过最大长度就截断
  • 小于最大长度就 padding
  • padding 部分同样不计算损失,也用 -100 填充

小结:我的数据集上蒸馏不如直接微调,但不代表通用结论

整体流程就是这样。这次尝试的蒸馏方法只对齐了两个模型输出的概率分布,没有尝试对齐中间层。以后有机会会尝试其他蒸馏方法。

从我自己的数据集来看,知识蒸馏效果不如直接微调,但这可能和数据集本身有关,并不是绝对结论。有兴趣的小伙伴可以用自己的数据集测试。

数据集只要处理成 LLaMA-Factory 微调那种格式就行,三个字段:instructioninputoutput

我这个测试数据集不方便公开,大家可以用自己的数据集做测试。

结束语

OK,今天的视频就先到这里。喜欢的小伙伴可以点个关注,也欢迎在评论区交流和讨论,拜拜。

On this page