gumbel分布的方差(Gumbel-Softmax完全解析)
写在前面
本文对大部分人来说可能仅仅起到科普的作用 ,因为Gumbel-Max仅在部分领域会用到 ,例如GAN 、VAE等 。笔者是在研究EMNLP上的一篇论文时 ,看到其中有用Gumbel-Softmax公式解决对一个概率分布进行采样无法求导的问题 ,故想到对Gumbel-Softmax做一个总结 ,由此写下本文
为什么我们需要Gumbel-Softmax ?
假设现在我们有一个离散随机变量
Z
Z
Z的分布
p
1
=
p
(
Z
=
1
)
=
π
1
p
2
=
p
(
Z
=
2
)
=
π
2
p
3
=
p
(
Z
=
3
)
=
π
3
.
.
.
p
x
=
p
(
Z
=
x
)
=
π
x
p_1 = p(Z=1)=\pi_1\\ p_2 = p(Z=2) = \pi_2\\ p_3 = p(Z=3) = \pi_3\\ ...\\ p_x = p(Z=x) = \pi_x\\
p1=p(Z=1)=π1p2=p(Z=2)=π2p3=p(Z=3)=π3...px=p(Z=x)=πx 其中 ,∑
i
π
i
=
1
\sum_i \pi_i=1
∑iπi=1 。我们想根据p
1
,
p
2
,
.
.
.
,
p
x
p_1,p_2,...,p_x
p1,p2,...,px的概率采样得到一系列离散z
z
z的值 。但是这么做有一个问题 ,我们采样出来的z
z
z只有值 ,没有生成z
z
z的式子 。例如我们要求Z
Z
Z的期望 ,那么就有公式
E
(
Z
)
=
p
1
+
2
p
2
+
⋯
+
x
p
x
\mathbb{E}(Z) = p_1 + 2p_2 + \cdots +xp_x
E(Z)=p1+2p2+⋯+xpxZ
Z
Z对p
1
,
p
2
,
.
.
.
,
p
x
p_1,p_2,...,p_x
p1,p2,...,px的导数都很清楚 。但是现在我们的需求是采样一些具体的z
z
z值 ,采样这个操作没有任何公式 ,因此也就无法求导 。于是一个很自然的想法就产生了 ,我们能不能给一个以p
1
,
p
2
,
.
.
.
,
p
z
p_1,p_2,...,p_z
p1,p2,...,pz为参数的公式,让这个公式返回的结果是z
z
z采样的结果呢?Gumbel-Softmax
一般来说
π
i
\pi_i
πi是通过神经网络预测对于类别i
i
i的概率 ,这在分类问题中非常常见 ,假设我们将一个样本送入模型,最后输出的概率分布为[
0.2
,
0.4
,
0.1
,
0.2
,
0.1
]
[0.2, 0.4,0.1,0.2,0.1]
[0.2,0.4,0.1,0.2,0.1] ,表明这是一个5分类问题 ,其中概率最大的是第2类 ,到这一步 ,我们直接通过argmax就能获得结果了 ,但现在我们不是预测问题 ,而是一个采样问题 。对于模型来说 ,直接取出概率最大的就可以了 ,但对我们来说 ,每个类别都是有一定概率的 ,我们想根据这个概率来进行采样 ,而不是直接简单无脑的输出概率最大的值最常见的采样
z
\mathbf{z}
z的onehot公式为
z
=
onehot
(
max
{
i
∣
π
1
+
π
2
+
⋯
+
π
i
−
1
≤
u
}
)
(1)
\mathbf{z} = \text{onehot}(\max \{i\mid \pi_1 + \pi_2+\cdots +\pi_{i-1} \leq u\})\tag{1}
z=onehot(max{i∣π1+π2+⋯+πi−1≤u})(1) 其中i
=
1
,
2
,
.
.
,
x
i=1,2,..,x
i=1,2,..,x是类别的下标 ,随机变量u
u
u服从均匀分布U
(
,
1
)
U(0,1)
U(0,1)上面这个过程实际上是很巧妙的,我们将概率分布从前往后不断加起来 ,当加到
π
i
\pi_i
πi时超过了某个随机值$ 0\leq u \leq 1,
那
么
这
一
次
随
机
采
样
过
程
,
,那么这一次随机采样过程 ,
,那么这一次随机采样过程 ,z就
被
随
机
采
样
为
第
就被随机采样为第
就被随机采样为第i$类 ,最后通过一个onehot变换但是上述公式存在一个致命的问题:max函数是不可导的
Gumbel-Max Trick
Gumbel-Max技巧就是解决max函数不可导问题的 ,我们可以用argmax替换max ,即
z
=
onehot
(
argmax
i
{
g
i
+
log
π
i
}
)
(2)
\mathbf{z} = \text{onehot}(\mathop{\text{argmax}}\limits_{i} \{g_i + \log \pi_i\})\tag{2}
z=onehot(iargmax{gi+logπi})(2) 其中 ,g
i
=
−
log
(
−
log
(
u
i
)
)
,
u
i
∼
U
(
,
1
)
g_i=-\log(-\log(u_i)), u_i \sim U(0,1)
gi=−log(−log(ui)),ui∼U(0,1) ,这一项名为Gumbel噪声 ,或者叫Gumbel分布 ,目的是使得z
\mathbf{z}
z的返回结果不固定可以看到式
(
2
)
(2)
(2)的整个过程中 ,不可导的部分只有argmax,实际上我们可以用可导的softmax函数 ,在参数τ
\tau
τ的控制下逼近argmax ,最终z
i
z_i
zi的公式为
z
i
=
exp
(
g
i
+
log
π
i
τ
)
∑
j
x
exp
(
g
j
+
log
π
j
τ
)
(3)
z_i = \frac{\exp(\frac{g_i + \log \pi_i}{\tau})}{\sum_{j}^x\exp(\frac{g_j + \log \pi_j}{\tau})}\tag{3}
zi=∑jxexp(τgj+logπj)exp(τgi+logπi)(3) 其中,τ
\tau
τ越小(
τ
→
)
(\tau \to 0)
(τ→0) ,整个softmax越光滑逼近argmax ,并且z
=
{
z
i
∣
i
=
1
,
2
,
.
.
.
,
x
}
\mathbf{z} = \{z_i\mid i=1,2,...,x\}
z={zi∣i=1,2,...,x}也越接近onehot向量;τ
\tau
τ越大(
τ
→
∞
)
(\tau \to \infty)
(τ→∞) ,z
\mathbf{z}
z向量越接近于均匀分布总结
整个过程相当于我们把不可导的取样过程 ,从
z
\mathbf{z}
z本身转移到了求z
\mathbf{z}
z的公式中的一项g
i
g_i
gi中 ,而g
i
g_i
gi本身不依赖p
1
,
.
.
,
p
x
p_1,..,p_x
p1,..,px ,所以z
z
z对p
1
,
.
.
.
,
p
x
p_1,...,p_x
p1,...,px就可以到了 ,而且我们得到的z
\mathbf{z}
z仍然是离散概率分布的采样 。这种采样过程转嫁的技巧有一个专有名词 ,叫重参数化技巧(Reparameterization Trick)References
What is Gumbel-Softmax Gumbel-Softmax Trick和Gumbel分布创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!