首页IT科技gumbel分布的方差(Gumbel-Softmax完全解析)

gumbel分布的方差(Gumbel-Softmax完全解析)

时间2025-08-03 18:13:51分类IT科技浏览5402
导读:写在前面 本文对大部分人来说可能仅仅起到科普的作用,因为Gumbel-Max仅在部分领域会用到,例如GAN、VAE等。笔者是在研究EMNLP上的一篇论文时,看到其中有用Gumbel-Softmax公式解决对一个概率分布进行采样无法求导的问题,故想到对Gumbel-Softmax做一个总结,...

写在前面

本文对大部分人来说可能仅仅起到科普的作用                ,因为Gumbel-Max仅在部分领域会用到                    ,例如GAN                、VAE等                。笔者是在研究EMNLP上的一篇论文时        ,看到其中有用Gumbel-Softmax公式解决对一个概率分布进行采样无法求导的问题            ,故想到对Gumbel-Softmax做一个总结                    ,由此写下本文

为什么我们需要Gumbel-Softmax ?

假设现在我们有一个离散随机变量

Z

Z

Z

的分布

p

1

=

p

(

Z

=

1

)

=

π

1

p

2

=

p

(

Z

=

2

)

=

π

2

p

3

=

p

(

Z

=

3

)

=

π

3

.

.

.

p

x

=

p

(

Z

=

x

)

=

π

x

p_1 = p(Z=1)=\pi_1\\ p_2 = p(Z=2) = \pi_2\\ p_3 = p(Z=3) = \pi_3\\ ...\\ p_x = p(Z=x) = \pi_x\\

p1=p(Z=1)=π1p2=p(Z=2)=π2p3=p(Z=3)=π3...px=p(Z=x)=πx
其中           ,

i

π

i

=

1

\sum_i \pi_i=1

iπi=1
                   。我们想根据

p

1

,

p

2

,

.

.

.

,

p

x

p_1,p_2,...,p_x

p1,p2,...,px
的概率采样得到一系列离散

z

z

z
的值        。但是这么做有一个问题        ,我们采样出来的

z

z

z
只有值                     ,没有生成

z

z

z
的式子
            。例如我们要求

Z

Z

Z

的期望              ,那么就有公式

E

(

Z

)

=

p

1

+

2

p

2

+

+

x

p

x

\mathbb{E}(Z) = p_1 + 2p_2 + \cdots +xp_x

E(Z)=p1+2p2++xpx

Z

Z

Z

p

1

,

p

2

,

.

.

.

,

p

x

p_1,p_2,...,p_x

p1,p2,...,px
的导数都很清楚                    。但是现在我们的需求是采样一些具体的

z

z

z
值    ,采样这个操作没有任何公式                      ,因此也就无法求导           。于是一个很自然的想法就产生了                 ,我们能不能给一个

p

1

,

p

2

,

.

.

.

,

p

z

p_1,p_2,...,p_z

p1,p2,...,pz
为参数的公式,让这个公式返回的结果是

z

z

z
采样的结果呢?

Gumbel-Softmax

一般来说

π

i

\pi_i

πi是通过神经网络预测对于类别

i

i

i
的概率                   ,这在分类问题中非常常见                    ,假设我们将一个样本送入模型    ,最后输出的概率分布为

[

0.2

,

0.4

,

0.1

,

0.2

,

0.1

]

[0.2, 0.4,0.1,0.2,0.1]

[0.2,0.4,0.1,0.2,0.1]
                ,表明这是一个5分类问题                    ,其中概率最大的是第2类        ,到这一步            ,我们直接通过argmax就能获得结果了                    ,但现在我们不是预测问题           ,而是一个采样问题        。对于模型来说        ,直接取出概率最大的就可以了                     ,但对我们来说              ,每个类别都是有一定概率的    ,我们想根据这个概率来进行采样                      ,而不是直接简单无脑的输出概率最大的值

最常见的采样

z

\mathbf{z}

z

的onehot公式为

z

=

onehot

(

max

{

i

π

1

+

π

2

+

+

π

i

1

u

}

)

(1)

\mathbf{z} = \text{onehot}(\max \{i\mid \pi_1 + \pi_2+\cdots +\pi_{i-1} \leq u\})\tag{1}

z=onehot(max{iπ1+π2++πi1u})(1)
其中

i

=

1

,

2

,

.

.

,

x

i=1,2,..,x

i=1,2,..,x
是类别的下标                 ,随机变量

u

u

u
服从均匀分布

U

(

,

1

)

U(0,1)

U(0,1)

上面这个过程实际上是很巧妙的,我们将概率分布从前往后不断加起来                   ,当加到

π

i

\pi_i

πi时超过了某个随机值$ 0\leq u \leq 1

                    ,

    ,

                ,那么这一次随机采样过程                    ,

        ,            ,
z

就被随机采样为第

i$类                    ,最后通过一个onehot变换

但是上述公式存在一个致命的问题:max函数是不可导的

Gumbel-Max Trick

Gumbel-Max技巧就是解决max函数不可导问题的           ,我们可以用argmax替换max        ,即

z

=

onehot

(

argmax

i

{

g

i

+

log

π

i

}

)

(2)

\mathbf{z} = \text{onehot}(\mathop{\text{argmax}}\limits_{i} \{g_i + \log \pi_i\})\tag{2}

z=onehot(iargmax{gi+logπi})(2)
其中                     ,

g

i

=

log

(

log

(

u

i

)

)

,

u

i

U

(

,

1

)

g_i=-\log(-\log(u_i)), u_i \sim U(0,1)

gi=log(log(ui)),uiU(0,1)
              ,这一项名为Gumbel噪声    ,或者叫Gumbel分布                      ,目的是使得

z

\mathbf{z}

z
的返回结果不固定

可以看到式

(

2

)

(2)

(2)的整个过程中                 ,不可导的部分只有argmax,实际上我们可以用可导的softmax函数                   ,在参数

τ

\tau

τ
的控制下逼近argmax                    ,最终

z

i

z_i

zi

的公式为

z

i

=

exp

(

g

i

+

log

π

i

τ

)

j

x

exp

(

g

j

+

log

π

j

τ

)

(3)

z_i = \frac{\exp(\frac{g_i + \log \pi_i}{\tau})}{\sum_{j}^x\exp(\frac{g_j + \log \pi_j}{\tau})}\tag{3}

zi=jxexp(τgj+logπj)exp(τgi+logπi)(3)
其中    ,

τ

\tau

τ
越小

(

τ

)

(\tau \to 0)

(τ0)
                ,整个softmax越光滑逼近argmax                    ,并且

z

=

{

z

i

i

=

1

,

2

,

.

.

.

,

x

}

\mathbf{z} = \{z_i\mid i=1,2,...,x\}

z={zii=1,2,...,x}
也越接近onehot向量;

τ

\tau

τ
越大

(

τ

)

(\tau \to \infty)

(τ)
        ,

z

\mathbf{z}

z
向量越接近于均匀分布

总结

整个过程相当于我们把不可导的取样过程            ,从

z

\mathbf{z}

z本身转移到了求

z

\mathbf{z}

z
的公式中的一项

g

i

g_i

gi
中                    ,而

g

i

g_i

gi
本身不依赖

p

1

,

.

.

,

p

x

p_1,..,p_x

p1,..,px
           ,所以

z

z

z

p

1

,

.

.

.

,

p

x

p_1,...,p_x

p1,...,px
就可以到了        ,而且我们得到的

z

\mathbf{z}

z
仍然是离散概率分布的采样                     。这种采样过程转嫁的技巧有一个专有名词                     ,叫重参数化技巧(Reparameterization Trick)

References

What is Gumbel-Softmax Gumbel-Softmax Trick和Gumbel分布

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
织梦自定义模型调用(织梦自由列表freelist调用增加排序方法) pytorch gpu利用率低(关于CPU和GPU版本共存下的安装Pytorch(跑YOLO模型))