扩散模型求解复杂反应(扩散模型 (Diffusion Model) 简要介绍与源码分析)
扩散模型 (Diffusion Model) 简要介绍与源码分析
前言
近期同事分享了 Diffusion Model, 这才发现生成模型的发展已经到了如此惊人的地步, OpenAI 推出的 Dall-E 2 可以根据文本描述生成极为逼真的图像, 质量之高直让人惊呼哇塞. 今早公众号给我推送了一篇关于 Stability AI 公司的报道, 他们推出的 AI 文生图扩散模型 Stable Diffusion 已开源, 能够在消费级显卡上实现 Dall-E 2 级别的图像生成, 效率提升了 30 倍.
于是找到他们的开源产品体验了一把, 在线体验地址在 https://huggingface.co/spaces/stabilityai/stable-diffusion (开源代码在 Github 上: https://github.com/CompVis/stable-diffusion), 在搜索框中输入 “A dog flying in the sky ” (一只狗在天空飞翔), 生成效果如下:
Amazing! 当然, 不是每一张图片都符合预期, 但好在可以生成无数张图片, 其中总有效果好的. 在震惊之余, 不免对 Diffusion Model (扩散模型) 背后的原理感兴趣, 就想看看是怎么实现的.
当时同事分享时, PPT 上那一堆堆公式扑面而来, 把我给整懵圈了, 但还是得撑起下巴, 表现出似有所悟 、深以为然的样子, 在讲到关键处不由暗暗点头以表示理解和赞许. 后面花了个周末专门学习了一下, 公式推导+代码分析, 感觉终于了解了基本概念, 于是记录下来形成此文, 不敢说自己完全懂了, 毕竟我不做这个方向, 但回过头去看 PPT 上的公式就不再发怵了.
广而告之
可以在微信中搜索 “珍妮的算法之路 ” 或者 “world4458 ” 关注我的微信公众号, 可以及时获取最新原创技术文章更新.
另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中.
总览
本文对 Diffusion Model 扩散模型的原理进行简要介绍, 然后对源码进行分析. 扩散模型的实现有多种形式, 本文关注的是 DDPM (denoising diffusion probabilistic models). 在介绍完基本原理后, 对作者释放的 Tensorflow 源码进行分析, 加深对各种公式的理解.
参考文章
在理解扩散模型的路上, 受到下面这些文章的启发, 强烈推荐阅读:
Lilian 的博客, 内容非常非常详实, 干货十足, 而且每篇文章都极其用心, 向大佬学习: What are Diffusion Models? ewrfcas 的知乎, 公式推导补充了更多的细节: 由浅入深了解Diffusion Model Lilian 的博客, 介绍变分自动编码器 VAE: From Autoencoder to Beta-VAE, Diffusion Model 需要从分布中随机采样样本, 该过程无法求导, 需要使用到 VAE 中介绍的重参数技巧. Denoising Diffusion Probabilistic Models 论文, 其 TF 源码位于: https://github.com/hojonathanho/diffusion, 源码介绍以该版本为主 PyTorch 的开源实现: https://github.com/lucidrains/denoising-diffusion-pytorch, 核心逻辑和上面 Tensorflow 版本是一致的, Stable Diffusion 参考的是 pytorch 版本的代码.扩散模型介绍
基本原理
Diffusion Model (扩散模型) 是一类生成模型, 和 VAE (Variational Autoencoder, 变分自动编码器), GAN (Generative Adversarial Network, 生成对抗网络) 等生成网络不同的是, 扩散模型在前向阶段对图像逐步施加噪声, 直至图像被破坏变成完全的高斯噪声, 然后在逆向阶段学习从高斯噪声还原为原始图像的过程.
具体来说, 前向阶段在原始图像
x
\mathbf{x}_0
x0 上逐步增加噪声, 每一步得到的图像x
t
\mathbf{x}_t
xt 只和上一步的结果x
t
−
1
\mathbf{x}_{t - 1}
xt−1 相关, 直至第T
T
T 步的图像x
T
\mathbf{x}_T
xT 变为纯高斯噪声. 前向阶段图示如下:而逆向阶段则是不断去除噪声的过程, 首先给定高斯噪声
x
T
\mathbf{x}_T
xT, 通过逐步去噪, 直至最终将原图像x
\mathbf{x}_0
x0 给恢复出来, 逆向阶段图示如下:模型训练完成后, 只要给定高斯随机噪声, 就可以生成一张从未见过的图像. 下面分别介绍前向阶段和逆向阶段, 只列出重要公式,
前向阶段
由于前向过程中图像
x
t
\mathbf{x}_t
xt 只和上一时刻的x
t
−
1
\mathbf{x}_{t - 1}
xt−1 有关, 该过程可以视为马尔科夫过程, 满足:q
(
x
1
:
T
∣
x
)
=
∏
t
=
1
T
q
(
x
t
∣
x
t
−
1
)
q
(
x
t
∣
x
t
−
1
)
=
N
(
x
t
;
1
−
β
t
x
t
−
1
,
β
t
I
)
,
\begin{align} q\left(x_{1: T} \mid x_0\right) &=\prod_{t=1}^T q\left(x_t \mid x_{t-1}\right) \\ q\left(x_t \mid x_{t-1}\right) &=\mathcal{N}\left(x_t ; \sqrt{1-\beta_t} x_{t-1}, \beta_t \mathbf{I}\right), \end{align}
q(x1:T∣x0)q(xt∣xt−1)=t=1∏Tq(xt∣xt−1)=N(xt;1−βtxt−1,βtI),其中
β
t
∈
(
,
1
)
\beta_t\in(0, 1)
βt∈(0,1) 为高斯分布的方差超参, 并满足β
1
<
β
2
<
…
<
β
T
\beta_1 < \beta_2 < \ldots < \beta_T
β1<β2<…<βT. 另外公式 (2) 中为何均值x
t
−
1
x_{t-1}
xt−1 前乘上系数1
−
β
t
x
t
−
1
\sqrt{1-\beta_t} x_{t-1}
1−βtxt−1 的原因将在后面的推导介绍. 上述过程的一个美妙性质是我们可以在任意 time step 下通过 重参数技巧 采样得到x
t
x_t
xt.重参数技巧 (reparameterization trick) 是为了解决随机采样样本这一过程无法求导的问题. 比如要从高斯分布
z
∼
N
(
z
;
μ
,
σ
2
I
)
z \sim \mathcal{N}(z; \mu, \sigma^2\mathbf{I})
z∼N(z;μ,σ2I) 中采样样本z
z
z, 可以通过引入随机变量ϵ
∼
N
(
,
I
)
\epsilon\sim\mathcal{N}(0, \mathbf{I})
ϵ∼N(0,I), 使得z
=
μ
+
σ
⊙
ϵ
z = \mu + \sigma\odot\epsilon
z=μ+σ⊙ϵ, 此时z
z
z 依旧具有随机性, 且服从高斯分布N
(
μ
,
σ
2
I
)
\mathcal{N}(\mu, \sigma^2\mathbf{I})
N(μ,σ2I), 同时μ
\mu
μ 与σ
\sigma
σ (通常由网络生成) 可导.简要了解了重参数技巧后, 再回到上面通过公式 (2) 采样
x
t
x_t
xt 的方法, 即生成随机变量ϵ
t
∼
N
(
,
I
)
\epsilon_t\sim\mathcal{N}(0, \mathbf{I})
ϵt∼N(0,I),
然后令α
t
=
1
−
β
t
\alpha_t = 1 - \beta_t
αt=1−βt, 以及α
t
‾
=
∏
i
=
1
T
α
t
\overline{\alpha_t} = \prod_{i=1}^{T}\alpha_t
αt=∏i=1Tαt, 从而可以得到:x
t
=
1
−
β
t
x
t
−
1
+
β
t
ϵ
1
where
ϵ
1
,
ϵ
2
,
…
∼
N
(
,
I
)
,
reparameter trick
;
=
a
t
x
t
−
1
+
1
−
α
t
ϵ
1
=
a
t
(
a
t
−
1
x
t
−
2
+
1
−
α
t
−
1
ϵ
2
)
+
1
−
α
t
ϵ
1
=
a
t
a
t
−
1
x
t
−
2
+
(
a
t
(
1
−
α
t
−
1
)
ϵ
2
+
1
−
α
t
ϵ
1
)
=
a
t
a
t
−
1
x
t
−
2
+
1
−
α
t
α
t
−
1
ϵ
ˉ
2
where
ϵ
ˉ
2
∼
N
(
,
I
)
;
=
…
=
α
ˉ
t
x
+
1
−
α
ˉ
t
ϵ
ˉ
t
.
\begin{align} x_t &= \sqrt{1 - \beta_t} x_{t-1}+\beta_t \epsilon_1 \quad \text { where } \; \epsilon_1, \epsilon_2, \ldots \sim \mathcal{N}(0, \mathbf{I}), \; \text{reparameter trick} ; \nonumber \\ &=\sqrt{a_t} x_{t-1}+\sqrt{1-\alpha_t} \epsilon_1\nonumber \\ &=\sqrt{a_t}\left(\sqrt{a_{t-1}} x_{t-2}+\sqrt{1-\alpha_{t-1}} \epsilon_2\right)+\sqrt{1-\alpha_t} \epsilon_1 \nonumber \\ &=\sqrt{a_t a_{t-1}} x_{t-2}+\left(\sqrt{a_t\left(1-\alpha_{t-1}\right)} \epsilon_2+\sqrt{1-\alpha_t} \epsilon_1\right) \tag{3-1} \\ &=\sqrt{a_t a_{t-1}} x_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}} \bar{\epsilon}_2 \quad \text { where } \quad \bar{\epsilon}_2 \sim \mathcal{N}(0, \mathbf{I}) ; \tag{3-2} \\ &=\ldots \nonumber \\ &=\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \bar{\epsilon}_t. \end{align}
xt=1−βtxt−1+βtϵ1 where ϵ1,ϵ2,…∼N(0,I),reparameter trick;=atxt−1+1−αtϵ1=at(at−1xt−2+1−αt−1ϵ2)+1−αtϵ1=atat−1xt−2+(at(1−αt−1)ϵ2+1−αtϵ1)=atat−1xt−2+1−αtαt−1ϵˉ2 where ϵˉ2∼N(0,I);=…=αˉtx0+1−αˉtϵˉt.(3-1)(3-2)其中公式 (3-1) 到公式 (3-2) 的推导是由于独立高斯分布的可见性, 有
N
(
,
σ
1
2
I
)
+
N
(
,
σ
2
2
I
)
∼
N
(
,
(
σ
1
2
+
σ
2
2
)
I
)
\mathcal{N}\left(0, \sigma_1^2\mathbf{I}\right) +\mathcal{N}\left(0,\sigma_2^2 \mathbf{I}\right)\sim\mathcal{N}\left(0, \left(\sigma_1^2 + \sigma_2^2\right)\mathbf{I}\right)
N(0,σ12I)+N(0,σ22I)∼N(0,(σ12创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!