Bili-Docs
技术工具AI 应用

【研2基本功 Score-based Diffusion 1】手搓Diffusion SDE,数学is all you need

深入讲解基于SDE的Score-based扩散模型理论,涵盖数学推导、得分函数及代码实现。

UP主: happy魇 · 时长: 22:53 · 🔗 B站原视频

发布: 2024-10-17 · 收录: 2025-01-11

标签: 扩散模型 · 深度学习 · SDE · 数学推导 · AI算法

代码链接与近况

今天要分享的代码我已经 debug 好了,把一些错误修正之后放到飞书文档里。飞书文档链接我放在视频下方简介里,大家可以点进去把整套代码自己跑一遍、debug 走一遍。

Hello,大家早上好,好久不见。今天早上更新一个视频。前段时间 NeurIPS 结果也出了:中了一篇,还有一篇均分还算不错,但还是被拒了。

Score-based Diffusion:朗之万动力学 vs SDE

今天我们来讲一个 Score-based 版本 Diffusion Model 的代码实现。Score-based 版本的 Diffusion Model 大体分为两种:朗之万动力学和随机微分方程(SDE)。

今天的关注点主要放在随机微分方程上。我们会使用宋飏博士的博客来看理论推导,然后用一个均值回归 SDE 来实现它的前向和反向过程 demo,以及它的损失函数。整体过程比较简单,我们开始。

SDE 的基本形式:连续时间加噪到高斯

先简单看一下 SDE。SDE 是一个连续时间的随机过程,相当于把原始图像/数据通过随机微分方程一直加噪,直到它变成一个高斯分布。

在宋飏博士博客里,SDE 具有类似这样的形式:前面是漂移项 (f),后面是扩散项 (g)。博客里也提到,SDE 的漂移项 (f) 和扩散项 (g) 都可以人工设置,也就是说 SDE 给 Diffusion 的设计提供了很大的空间。

反向过程与得分函数(score)

再看 SDE 的反向过程。数学上,对于刚才那种形式的 SDE,它存在一个反向的过程,形式大概是:漂移项里会出现一个很关键的量 (\nabla_x \log p_t(x)),也就是 (\log p_t(x)) 对 (x) 的偏导。

我们把这个东西称为得分函数(score function),也就是 score-based 方法里最核心的 score。只有知道得分函数,才能做反向的去噪过程。

Score-based 的损失函数和 DDPM 的损失到底怎么对应?

宋飏博士在博客里直接给出了一个数学表达式。我这里稍微展开讲一下,让大家更容易把它和 DDPM 的损失联系起来理解。

回顾 DDPM:从 VLB 到噪声预测的 MSE

先回顾一下之前讲过的 DDPM。DDPM 的目标是优化变分下界(VLB)。优化这个核心目标,最后可以变成:在已知 (x_0) 的情况下,优化从 (x_t) 生成 (x_{t-1}) 的概率。

用 Bayes 公式展开,再经过一系列推导,可以得到这个条件概率服从高斯分布:均值是某个 (\mu),方差是某个 (\sigma^2)。

得到高斯形式之后,训练时本质上就是让模型去拟合这个均值。这里 (x_0) 可以用 (x_t) 来表示:因为前向过程里我们写过 [ x_t = \alpha_t x_0 + \sigma_t \epsilon,\quad \epsilon \sim \mathcal{N}(0, I) ] 所以把它对 (x_0) 反解一下就行。

这样我们就能把均值的“真值”(label)写成含有 (\epsilon) 的形式。然后把其中的 (\epsilon) 用可学习的神经网络来代替,本质上就是做模仿学习:预测噪声 (\epsilon_\theta) 和真实噪声 (\epsilon) 做差,再做 MSE Loss。最后就得到 DDPM 常见的噪声预测损失形式。

得分函数的损失长什么样?

那 score 和这个损失有什么关系?score 的目标一般长得很像: [ \left|\nabla_x \log p_t(x) - s_\theta(x,t)\right|^2 ] 也就是 (\log p) 对 (x) 的偏导,减去神经网络学出来的 score,再做二范数(MSE)。

问题是:这个是怎么和 DDPM 那套噪声预测联系起来的?

Tweedie 公式:把均值估计和 (\nabla_x \log p) 连起来

这里要用到一个叫 Tweedie 的公式。它核心意思是:当你已知一个高斯变量的方差,并且你有一个带噪观测 (z),你想估计它真实的均值 (\mu(z)),可以得到一个和 (\nabla_z \log p(z)) 有关的表达式。

利用这个公式,我们就能建立得分函数和 DDPM 经典公式之间的关系。

在 DDPM 里最经典的一个式子是从 (x_0) 推到 (x_t): [ x_t = \alpha_t x_0 + \sigma_t \epsilon ] 对于 (x_t) 的均值,我们也可以用 Tweedie 公式写成另一个表达:大概是 (x_t) 加上“方差”乘上 (\nabla_{x_t}\log p_t(x_t)) 这一类形式。把它变形,就能得到一个关于 (\nabla_{x_t}\log p_t(x_t)) 的式子。

这两个式子有个核心相同的部分:某一项和后面那一整项是对应得上的。把它们做等号,再化简,就能得到 (x_0) 可以用 score 来反表示。

回到前面:DDPM 里我们用 (\epsilon) 表示 (x_0) 的反表示;这里用 score 也能反表示 (x_0)。把它带回去、再作差,就能得到博客里的那一部分:score function 的 loss,本质上就是对 (\nabla_x \log p) 的估计做 MSE,也就是宋飏博士博客给出的公式。

均值回归 SDE(Mean-Reverting SDE)的建模思路

接下来我们看一个简单代码实现。以某个工作的 SDE 实现为例,它引入了所谓的均值回归 SDE,把 SDE 建模成这种形式:

  • 漂移项是 (-\theta(t) x_t)
  • 扩散项是 (\gamma(t))(也可以记成 (\sigma(t)))
  • 并且指定关系:(\sigma(t)^2 = 2\theta(t))

在指定这个关系之后,可以推导出:从任意时刻的 (x_t) 都能得到另一个时刻的高斯分布表达式,并且可以把式子进一步简化成一个更紧凑的公式。

它的反向过程也可以类比写出来:把对应的 (f, g) 带进反向 SDE 的通式里直接写就行。

这类 SDE 还有一些性质,比如反向过程可以直接用高斯分布表达,而不一定要用反向 SDE 一步步推。但这里我们不 care 这个性质,我们只把注意力放在:前向 SDE 公式、反向 SDE 公式,以及怎么实现训练和采样。

代码实现:SDEBase 抽象类(标准写法)

先写一个标准的 SDE 实现。其他对 SDE 的变体,基本都可以基于标准 SDE 来扩展。

我们先构建一个 SDE base 的标准类,一般会用 abc 去写抽象类。

初始化输入主要有一个扩散步数 T。虽然 SDE 是连续时间微分方程,但实现上还是要离散化,所以我们设置 dt = 1 / T

在 SDE 里最重要的两个东西:

  1. 漂移项 drift(x, t)(对应 (f))
  2. 扩散项 diffusion(x, t)(对应 (g))

因为不同 SDE 的实现方式不同,所以在抽象类里先 pass,你在继承这个类的时候必须把这两个方法实现出来。

前向过程:forward + forward_step

接下来写前向过程 forward

前向过程输入是原始数据 (x_0)。然后不断迭代 (T) 步,把 (x_0) 更新为噪声更大的 (x_t)。每一步调用一个 forward_step

forward_step 的输入是当前的 (x) 和当前时间步 (t)。

数学公式是: [ dx = f(x,t),dt + g(x,t),dw ]

照着实现就行:

  • dw 是扩散噪声项,一般就是高斯噪声再乘上 (\sqrt{dt}) 这一类的处理(实现上按你定义的离散形式来)
  • fg 是抽象函数,直接调用 self.drift(x,t)self.diffusion(x,t)

然后返回 x + dx。最终 forward 返回得到的 (x_T)。

反向采样:reverse(SDE 版 vs ODE 版)

再写反向过程 reverse。反向过程一般有两种:

  • 反向 SDE
  • 反向 ODE(probability flow ODE)

ODE 版和 SDE 版对比,区别主要在两点:

  1. ODE 版没有扩散项(没有随机噪声)
  2. 漂移项里 score 的系数会多一个 (1/2)

reverse 的输入与循环

反向过程输入一般需要:

  • 初始噪声 (x_T)
  • score_function(神经网络)
  • 模式选择:用 SDE 还是 ODE

每一步先算 score: score_value = score_function(x_t, t)

然后根据模式选择:

  • SDE:x = reverse_sde(x, t, score_value)
  • ODE:x = reverse_ode(x, t, score_value)

循环 (T) 步后返回 (x_0) 的估计。

reverse_sde / reverse_ode 的实现要点

reverse_sde

反向 SDE 的形式是:漂移项会变成 [ f(x,t) - g(x,t)^2 \cdot \text{score}(x,t) ] 扩散项还是 (g(x,t),dw)。

实现时注意一个点:当 (t=0) 的时候不应该再加扩散噪声项(这个和 DDPM 一样),所以实现里要加一个判断 t > 0 再加 g * dw

最后返回 x - dx(反向是去噪,方向相反)。

reverse_ode

ODE 版输入一样。唯一的区别:

  • 没有扩散项
  • 漂移项里 score 那部分系数变成 (0.5 \cdot g^2 \cdot \text{score})

算出 dx = (f - 0.5 * g^2 * score) * dt,然后同样返回 x - dx

到这里,一个标准的 SDEBase 的前向和反向过程就写完了。

On this page