首页IT科技扩散模型求解复杂反应(扩散模型 (Diffusion Model) 简要介绍与源码分析)

扩散模型求解复杂反应(扩散模型 (Diffusion Model) 简要介绍与源码分析)

时间2025-05-05 16:42:47分类IT科技浏览5159
导读:扩散模型 (Diffusion Model 简要介绍与源码分析 前言...

扩散模型 (Diffusion Model) 简要介绍与源码分析

前言

近期同事分享了 Diffusion Model, 这才发现生成模型的发展已经到了如此惊人的地步, OpenAI 推出的 Dall-E 2 可以根据文本描述生成极为逼真的图像, 质量之高直让人惊呼哇塞. 今早公众号给我推送了一篇关于 Stability AI 公司的报道, 他们推出的 AI 文生图扩散模型 Stable Diffusion 已开源, 能够在消费级显卡上实现 Dall-E 2 级别的图像生成, 效率提升了 30 倍.

于是找到他们的开源产品体验了一把, 在线体验地址在 https://huggingface.co/spaces/stabilityai/stable-diffusion (开源代码在 Github 上: https://github.com/CompVis/stable-diffusion), 在搜索框中输入 “A dog flying in the sky           ” (一只狗在天空飞翔), 生成效果如下:

Amazing! 当然, 不是每一张图片都符合预期, 但好在可以生成无数张图片, 其中总有效果好的. 在震惊之余, 不免对 Diffusion Model (扩散模型) 背后的原理感兴趣, 就想看看是怎么实现的.

当时同事分享时, PPT 上那一堆堆公式扑面而来, 把我给整懵圈了, 但还是得撑起下巴, 表现出似有所悟            、深以为然的样子, 在讲到关键处不由暗暗点头以表示理解和赞许. 后面花了个周末专门学习了一下, 公式推导+代码分析, 感觉终于了解了基本概念, 于是记录下来形成此文, 不敢说自己完全懂了, 毕竟我不做这个方向, 但回过头去看 PPT 上的公式就不再发怵了.

广而告之

可以在微信中搜索 “珍妮的算法之路                 ” 或者 “world4458      ” 关注我的微信公众号, 可以及时获取最新原创技术文章更新.

另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中.

总览

本文对 Diffusion Model 扩散模型的原理进行简要介绍, 然后对源码进行分析. 扩散模型的实现有多种形式, 本文关注的是 DDPM (denoising diffusion probabilistic models). 在介绍完基本原理后, 对作者释放的 Tensorflow 源码进行分析, 加深对各种公式的理解.

参考文章

在理解扩散模型的路上, 受到下面这些文章的启发, 强烈推荐阅读:

Lilian 的博客, 内容非常非常详实, 干货十足, 而且每篇文章都极其用心, 向大佬学习: What are Diffusion Models? ewrfcas 的知乎, 公式推导补充了更多的细节: 由浅入深了解Diffusion Model Lilian 的博客, 介绍变分自动编码器 VAE: From Autoencoder to Beta-VAE, Diffusion Model 需要从分布中随机采样样本, 该过程无法求导, 需要使用到 VAE 中介绍的重参数技巧. Denoising Diffusion Probabilistic Models 论文, 其 TF 源码位于: https://github.com/hojonathanho/diffusion, 源码介绍以该版本为主 PyTorch 的开源实现: https://github.com/lucidrains/denoising-diffusion-pytorch, 核心逻辑和上面 Tensorflow 版本是一致的, Stable Diffusion 参考的是 pytorch 版本的代码.

扩散模型介绍

基本原理

Diffusion Model (扩散模型) 是一类生成模型, 和 VAE (Variational Autoencoder, 变分自动编码器), GAN (Generative Adversarial Network, 生成对抗网络) 等生成网络不同的是, 扩散模型在前向阶段对图像逐步施加噪声, 直至图像被破坏变成完全的高斯噪声, 然后在逆向阶段学习从高斯噪声还原为原始图像的过程.

具体来说, 前向阶段在原始图像

x

\mathbf{x}_0

x0 上逐步增加噪声, 每一步得到的图像

x

t

\mathbf{x}_t

xt
只和上一步的结果

x

t

1

\mathbf{x}_{t - 1}

xt1
相关, 直至第

T

T

T
步的图像

x

T

\mathbf{x}_T

xT
变为纯高斯噪声. 前向阶段图示如下:

而逆向阶段则是不断去除噪声的过程, 首先给定高斯噪声

x

T

\mathbf{x}_T

xT, 通过逐步去噪, 直至最终将原图像

x

\mathbf{x}_0

x0
给恢复出来, 逆向阶段图示如下:

模型训练完成后, 只要给定高斯随机噪声, 就可以生成一张从未见过的图像. 下面分别介绍前向阶段和逆向阶段, 只列出重要公式,

前向阶段

由于前向过程中图像

x

t

\mathbf{x}_t

xt 只和上一时刻的

x

t

1

\mathbf{x}_{t - 1}

xt1
有关, 该过程可以视为马尔科夫过程, 满足:

q

(

x

1

:

T

x

)

=

t

=

1

T

q

(

x

t

x

t

1

)

q

(

x

t

x

t

1

)

=

N

(

x

t

;

1

β

t

x

t

1

,

β

t

I

)

,

\begin{align} q\left(x_{1: T} \mid x_0\right) &=\prod_{t=1}^T q\left(x_t \mid x_{t-1}\right) \\ q\left(x_t \mid x_{t-1}\right) &=\mathcal{N}\left(x_t ; \sqrt{1-\beta_t} x_{t-1}, \beta_t \mathbf{I}\right), \end{align}

q(x1:Tx0)q(xtxt1)=t=1Tq(xtxt1)=N(xt;1βtxt1,βtI),

其中

β

t

(

,

1

)

\beta_t\in(0, 1)

βt(0,1) 为高斯分布的方差超参, 并满足

β

1

<

β

2

<

<

β

T

\beta_1 < \beta_2 < \ldots < \beta_T

β1<β2<<βT
. 另外公式 (2) 中为何均值

x

t

1

x_{t-1}

xt1
前乘上系数

1

β

t

x

t

1

\sqrt{1-\beta_t} x_{t-1}

1βtxt1
的原因将在后面的推导介绍. 上述过程的一个美妙性质是我们可以在任意 time step 下通过 重参数技巧 采样得到

x

t

x_t

xt
.

重参数技巧 (reparameterization trick) 是为了解决随机采样样本这一过程无法求导的问题. 比如要从高斯分布

z

N

(

z

;

μ

,

σ

2

I

)

z \sim \mathcal{N}(z; \mu, \sigma^2\mathbf{I})

zN(z;μ,σ2I) 中采样样本

z

z

z
, 可以通过引入随机变量

ϵ

N

(

,

I

)

\epsilon\sim\mathcal{N}(0, \mathbf{I})

ϵN(0,I)
, 使得

z

=

μ

+

σ

ϵ

z = \mu + \sigma\odot\epsilon

z=μ+σϵ
, 此时

z

z

z
依旧具有随机性, 且服从高斯分布

N

(

μ

,

σ

2

I

)

\mathcal{N}(\mu, \sigma^2\mathbf{I})

N(μ,σ2I)
, 同时

μ

\mu

μ

σ

\sigma

σ
(通常由网络生成) 可导.

简要了解了重参数技巧后, 再回到上面通过公式 (2) 采样

x

t

x_t

xt 的方法, 即生成随机变量

ϵ

t

N

(

,

I

)

\epsilon_t\sim\mathcal{N}(0, \mathbf{I})

ϵtN(0,I)

,

然后令

α

t

=

1

β

t

\alpha_t = 1 - \beta_t

αt=1βt
, 以及

α

t

=

i

=

1

T

α

t

\overline{\alpha_t} = \prod_{i=1}^{T}\alpha_t

αt=i=1Tαt
, 从而可以得到:

x

t

=

1

β

t

x

t

1

+

β

t

ϵ

1

 where 

ϵ

1

,

ϵ

2

,

N

(

,

I

)

,

reparameter trick

;

=

a

t

x

t

1

+

1

α

t

ϵ

1

=

a

t

(

a

t

1

x

t

2

+

1

α

t

1

ϵ

2

)

+

1

α

t

ϵ

1

=

a

t

a

t

1

x

t

2

+

(

a

t

(

1

α

t

1

)

ϵ

2

+

1

α

t

ϵ

1

)

=

a

t

a

t

1

x

t

2

+

1

α

t

α

t

1

ϵ

ˉ

2

 where 

ϵ

ˉ

2

N

(

,

I

)

;

=

=

α

ˉ

t

x

+

1

α

ˉ

t

ϵ

ˉ

t

.

\begin{align} x_t &= \sqrt{1 - \beta_t} x_{t-1}+\beta_t \epsilon_1 \quad \text { where } \; \epsilon_1, \epsilon_2, \ldots \sim \mathcal{N}(0, \mathbf{I}), \; \text{reparameter trick} ; \nonumber \\ &=\sqrt{a_t} x_{t-1}+\sqrt{1-\alpha_t} \epsilon_1\nonumber \\ &=\sqrt{a_t}\left(\sqrt{a_{t-1}} x_{t-2}+\sqrt{1-\alpha_{t-1}} \epsilon_2\right)+\sqrt{1-\alpha_t} \epsilon_1 \nonumber \\ &=\sqrt{a_t a_{t-1}} x_{t-2}+\left(\sqrt{a_t\left(1-\alpha_{t-1}\right)} \epsilon_2+\sqrt{1-\alpha_t} \epsilon_1\right) \tag{3-1} \\ &=\sqrt{a_t a_{t-1}} x_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}} \bar{\epsilon}_2 \quad \text { where } \quad \bar{\epsilon}_2 \sim \mathcal{N}(0, \mathbf{I}) ; \tag{3-2} \\ &=\ldots \nonumber \\ &=\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} \bar{\epsilon}_t. \end{align}

xt=1βtxt1+βtϵ1 where ϵ1,ϵ2,N(0,I),reparameter trick;=atxt1+1αtϵ1=at(at1xt2+1αt1ϵ2)+1αtϵ1=atat1xt2+(at(1αt1)ϵ2+1αtϵ1)=atat1xt2+1αtαt1ϵˉ2 where ϵˉ2N(0,I);==αˉtx0+1αˉtϵˉt.(3-1)(3-2)

其中公式 (3-1) 到公式 (3-2) 的推导是由于独立高斯分布的可见性, 有

N

(

,

σ

1

2

I

)

+

N

(

,

σ

2

2

I

)

N

(

,

(

σ

1

2

+

σ

2

2

)

I

)

\mathcal{N}\left(0, \sigma_1^2\mathbf{I}\right) +\mathcal{N}\left(0,\sigma_2^2 \mathbf{I}\right)\sim\mathcal{N}\left(0, \left(\sigma_1^2 + \sigma_2^2\right)\mathbf{I}\right)

N(0,σ12I)+N(0,σ22I)N(0,(σ12

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
选择歌曲林子祥叶倩文(選擇 ohmyzsh 讓 Terminal 更好用 » 社区 | Ruby China 選擇 ohmyzsh 讓 Terminal 更好用 » 社区 | Ruby China 選擇 ohmyzs)