【研2基本功 Score-based Diffusion 1】手搓Diffusion SDE,数学is all you need
深入讲解基于SDE的Score-based扩散模型理论,涵盖数学推导、得分函数及代码实现。
UP主: happy魇 · 时长: 22:53 · 🔗 B站原视频
发布: 2024-10-17 · 收录: 2025-01-11
标签: 扩散模型 · 深度学习 · SDE · 数学推导 · AI算法
代码链接与近况
今天要分享的代码我已经 debug 好了,把一些错误修正之后放到飞书文档里。飞书文档链接我放在视频下方简介里,大家可以点进去把整套代码自己跑一遍、debug 走一遍。
Hello,大家早上好,好久不见。今天早上更新一个视频。前段时间 NeurIPS 结果也出了:中了一篇,还有一篇均分还算不错,但还是被拒了。
Score-based Diffusion:朗之万动力学 vs SDE
今天我们来讲一个 Score-based 版本 Diffusion Model 的代码实现。Score-based 版本的 Diffusion Model 大体分为两种:朗之万动力学和随机微分方程(SDE)。
今天的关注点主要放在随机微分方程上。我们会使用宋飏博士的博客来看理论推导,然后用一个均值回归 SDE 来实现它的前向和反向过程 demo,以及它的损失函数。整体过程比较简单,我们开始。
SDE 的基本形式:连续时间加噪到高斯
先简单看一下 SDE。SDE 是一个连续时间的随机过程,相当于把原始图像/数据通过随机微分方程一直加噪,直到它变成一个高斯分布。
在宋飏博士博客里,SDE 具有类似这样的形式:前面是漂移项 (f),后面是扩散项 (g)。博客里也提到,SDE 的漂移项 (f) 和扩散项 (g) 都可以人工设置,也就是说 SDE 给 Diffusion 的设计提供了很大的空间。
反向过程与得分函数(score)
再看 SDE 的反向过程。数学上,对于刚才那种形式的 SDE,它存在一个反向的过程,形式大概是:漂移项里会出现一个很关键的量 (\nabla_x \log p_t(x)),也就是 (\log p_t(x)) 对 (x) 的偏导。
我们把这个东西称为得分函数(score function),也就是 score-based 方法里最核心的 score。只有知道得分函数,才能做反向的去噪过程。
Score-based 的损失函数和 DDPM 的损失到底怎么对应?
宋飏博士在博客里直接给出了一个数学表达式。我这里稍微展开讲一下,让大家更容易把它和 DDPM 的损失联系起来理解。
回顾 DDPM:从 VLB 到噪声预测的 MSE
先回顾一下之前讲过的 DDPM。DDPM 的目标是优化变分下界(VLB)。优化这个核心目标,最后可以变成:在已知 (x_0) 的情况下,优化从 (x_t) 生成 (x_{t-1}) 的概率。
用 Bayes 公式展开,再经过一系列推导,可以得到这个条件概率服从高斯分布:均值是某个 (\mu),方差是某个 (\sigma^2)。
得到高斯形式之后,训练时本质上就是让模型去拟合这个均值。这里 (x_0) 可以用 (x_t) 来表示:因为前向过程里我们写过 [ x_t = \alpha_t x_0 + \sigma_t \epsilon,\quad \epsilon \sim \mathcal{N}(0, I) ] 所以把它对 (x_0) 反解一下就行。
这样我们就能把均值的“真值”(label)写成含有 (\epsilon) 的形式。然后把其中的 (\epsilon) 用可学习的神经网络来代替,本质上就是做模仿学习:预测噪声 (\epsilon_\theta) 和真实噪声 (\epsilon) 做差,再做 MSE Loss。最后就得到 DDPM 常见的噪声预测损失形式。
得分函数的损失长什么样?
那 score 和这个损失有什么关系?score 的目标一般长得很像: [ \left|\nabla_x \log p_t(x) - s_\theta(x,t)\right|^2 ] 也就是 (\log p) 对 (x) 的偏导,减去神经网络学出来的 score,再做二范数(MSE)。
问题是:这个是怎么和 DDPM 那套噪声预测联系起来的?
Tweedie 公式:把均值估计和 (\nabla_x \log p) 连起来
这里要用到一个叫 Tweedie 的公式。它核心意思是:当你已知一个高斯变量的方差,并且你有一个带噪观测 (z),你想估计它真实的均值 (\mu(z)),可以得到一个和 (\nabla_z \log p(z)) 有关的表达式。
利用这个公式,我们就能建立得分函数和 DDPM 经典公式之间的关系。
在 DDPM 里最经典的一个式子是从 (x_0) 推到 (x_t): [ x_t = \alpha_t x_0 + \sigma_t \epsilon ] 对于 (x_t) 的均值,我们也可以用 Tweedie 公式写成另一个表达:大概是 (x_t) 加上“方差”乘上 (\nabla_{x_t}\log p_t(x_t)) 这一类形式。把它变形,就能得到一个关于 (\nabla_{x_t}\log p_t(x_t)) 的式子。
这两个式子有个核心相同的部分:某一项和后面那一整项是对应得上的。把它们做等号,再化简,就能得到 (x_0) 可以用 score 来反表示。
回到前面:DDPM 里我们用 (\epsilon) 表示 (x_0) 的反表示;这里用 score 也能反表示 (x_0)。把它带回去、再作差,就能得到博客里的那一部分:score function 的 loss,本质上就是对 (\nabla_x \log p) 的估计做 MSE,也就是宋飏博士博客给出的公式。
均值回归 SDE(Mean-Reverting SDE)的建模思路
接下来我们看一个简单代码实现。以某个工作的 SDE 实现为例,它引入了所谓的均值回归 SDE,把 SDE 建模成这种形式:
- 漂移项是 (-\theta(t) x_t)
- 扩散项是 (\gamma(t))(也可以记成 (\sigma(t)))
- 并且指定关系:(\sigma(t)^2 = 2\theta(t))
在指定这个关系之后,可以推导出:从任意时刻的 (x_t) 都能得到另一个时刻的高斯分布表达式,并且可以把式子进一步简化成一个更紧凑的公式。
它的反向过程也可以类比写出来:把对应的 (f, g) 带进反向 SDE 的通式里直接写就行。
这类 SDE 还有一些性质,比如反向过程可以直接用高斯分布表达,而不一定要用反向 SDE 一步步推。但这里我们不 care 这个性质,我们只把注意力放在:前向 SDE 公式、反向 SDE 公式,以及怎么实现训练和采样。
代码实现:SDEBase 抽象类(标准写法)
先写一个标准的 SDE 实现。其他对 SDE 的变体,基本都可以基于标准 SDE 来扩展。
我们先构建一个 SDE base 的标准类,一般会用 abc 去写抽象类。
初始化输入主要有一个扩散步数 T。虽然 SDE 是连续时间微分方程,但实现上还是要离散化,所以我们设置 dt = 1 / T。
在 SDE 里最重要的两个东西:
- 漂移项
drift(x, t)(对应 (f)) - 扩散项
diffusion(x, t)(对应 (g))
因为不同 SDE 的实现方式不同,所以在抽象类里先 pass,你在继承这个类的时候必须把这两个方法实现出来。
前向过程:forward + forward_step
接下来写前向过程 forward。
前向过程输入是原始数据 (x_0)。然后不断迭代 (T) 步,把 (x_0) 更新为噪声更大的 (x_t)。每一步调用一个 forward_step。
forward_step 的输入是当前的 (x) 和当前时间步 (t)。
数学公式是: [ dx = f(x,t),dt + g(x,t),dw ]
照着实现就行:
dw是扩散噪声项,一般就是高斯噪声再乘上 (\sqrt{dt}) 这一类的处理(实现上按你定义的离散形式来)f和g是抽象函数,直接调用self.drift(x,t)和self.diffusion(x,t)
然后返回 x + dx。最终 forward 返回得到的 (x_T)。
反向采样:reverse(SDE 版 vs ODE 版)
再写反向过程 reverse。反向过程一般有两种:
- 反向 SDE
- 反向 ODE(probability flow ODE)
ODE 版和 SDE 版对比,区别主要在两点:
- ODE 版没有扩散项(没有随机噪声)
- 漂移项里 score 的系数会多一个 (1/2)
reverse 的输入与循环
反向过程输入一般需要:
- 初始噪声 (x_T)
score_function(神经网络)- 模式选择:用 SDE 还是 ODE
每一步先算 score:
score_value = score_function(x_t, t)
然后根据模式选择:
- SDE:
x = reverse_sde(x, t, score_value) - ODE:
x = reverse_ode(x, t, score_value)
循环 (T) 步后返回 (x_0) 的估计。
reverse_sde / reverse_ode 的实现要点
reverse_sde
反向 SDE 的形式是:漂移项会变成 [ f(x,t) - g(x,t)^2 \cdot \text{score}(x,t) ] 扩散项还是 (g(x,t),dw)。
实现时注意一个点:当 (t=0) 的时候不应该再加扩散噪声项(这个和 DDPM 一样),所以实现里要加一个判断 t > 0 再加 g * dw。
最后返回 x - dx(反向是去噪,方向相反)。
reverse_ode
ODE 版输入一样。唯一的区别:
- 没有扩散项
- 漂移项里 score 那部分系数变成 (0.5 \cdot g^2 \cdot \text{score})
算出 dx = (f - 0.5 * g^2 * score) * dt,然后同样返回 x - dx。
到这里,一个标准的 SDEBase 的前向和反向过程就写完了。