首页IT科技损失函数loss(关于Focal loss损失函数的代码实现)

损失函数loss(关于Focal loss损失函数的代码实现)

时间2025-05-05 11:44:42分类IT科技浏览3703
导读:引言 场景:使用Bert做一个违规样本分类模型,数据呈现正负样本不均衡,难易样本不均衡等问题,尝试使用Focal loss替换Bert中后半部分的交叉熵损失函数。...

引言

场景:使用Bert做一个违规样本分类模型           ,数据呈现正负样本不均衡               ,难易样本不均衡等问题      ,尝试使用Focal loss替换Bert中后半部分的交叉熵损失函数           。

初衷:由于使用的Bert模型中使用的损失函数为交叉熵损失函数         ,torch.nn.CrossEntropyLoss               ,那么如果能理解实现原理         ,将focal loss在该api基础上实现      ,就可以尽可能少修改原始代码

Focal loss的公式:其中用到的交叉熵损失函数表达式是(3)

F

L

p

t

=

(

1

p

t

)

γ

log

p

t

(1)

FL(p_{t}) = - (1 - p_{t})^{\gamma}\log{p_{t}}\tag{1}

FLpt=(1pt)γlogpt(1)

其中:

p

t

=

{

p

i

f

y

=

1

1

p

o

t

h

e

r

w

i

s

e

(1.1)

p_{t}=\begin{cases} p& if & y = 1 \\ 1-p && otherwise \end{cases}\tag{1.1}

pt={p1pify=1otherwise(1.1)

辅助理解:首先要了解的是交叉熵损失函数可以是二分类也可以是多分类               ,核心就是sigmoid还是softmax作为激活函数            ,那么对应的就是CE和BCE   ,从而focal loss根据CE BCE也就有两种表达                 。

1 focal loss的公式推导过程理解可以参考:寻找解决样本不均衡方法之Focal Loss与GHM - 知乎 (zhihu.com)

2 交叉熵损失函数的推导过程可以参考:交叉熵损失函数 - 知乎 (zhihu.com)

3 CE与BCE的区别:CE Loss 与 BCE Loss 区别 - 知乎 (zhihu.com)

BCE:二分类

L

=

i

=

1

N

(

y

i

log

y

^

i

+

(

1

y

i

)

log

(

1

y

^

i

)

)

(2)

L = -\sum^N_{i=1}(y_{i}\log{\hat{y}_{i}} + (1-y_{i})\log{(1-\hat{y}}_{i}))\tag{2}

L=i=1N(yilogyi+(1yi)log(1yi))(2)

CE:多分类               ,当其是二分类时候与BCE有什么区别可见上面的链接

L

=

i

=

1

N

(

y

i

log

y

^

i

)

(3)

L = -\sum^N_{i=1}(y_{i}\log{\hat{y}_{i}} )\tag{3}

L=i=1N(yilogyi)(3)

pytorch中具体实现方法可以查看:[CrossEntropyLoss — PyTorch 1.12 documentation]

softmax              ,log_softmax,nllloss的表达式: 关于nllloss专门整理一篇介绍    。

σ

(

z

)

j

=

e

z

j

k

=

1

n

e

z

k

(softmax)

\sigma(z)_{j} = \frac{e^{z_{j}}}{\sum_{k=1}^ne^{z_{k}}}\tag{softmax}

σ(z)j=k=1nezkezj(softmax)

l

o

g

s

o

f

t

m

a

x

=

ln

σ

(

z

)

j

logsoftmax = \ln{\sigma(z)_{j}}

logsoftmax=lnσ(z)j

n

l

l

l

o

s

s

=

1

N

k

=

1

N

y

k

(

l

o

g

s

o

f

t

m

a

x

)

nllloss = - \frac{1}{N}\sum_{k=1}^Ny_{k}(logsoftmax)

nllloss=N1k=1Nyk(logsoftmax) 使用pytorch实现focal loss源码如下:(个人觉得比较简练的一个) import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, Dataset import torchvision import torchvision.transforms as F from IPython.display import display class FocalLoss(nn.Module): def __init__(self, weight=None, reduction=mean, gamma=0, eps=1e-7): super(FocalLoss, self).__init__() self.gamma = gamma self.eps = eps self.ce = torch.nn.CrossEntropyLoss(weight=weight, reduction=reduction) def forward(self, input, target): logp = self.ce(input, target) p = torch.exp(-logp) loss = (1 - p) ** self.gamma * logp return loss.mean()

代码来源:Focal Loss代码分析(公式修改版-知乎公式坑) - 知乎 (zhihu.com)

代码实现的原理如下:

pytorch中交叉熵损失函数所有表达式             ,类比(3)

l

o

s

s

(

x

,

c

l

a

s

s

)

=

log

e

x

c

l

a

s

s

j

e

x

j

=

x

c

l

a

s

s

+

log

j

e

x

j

(3)

loss(x,class) = -\log{\frac{e^{x_{class}}}{\sum_{j}e^{x_j}}}= -x_{class} + \log{\sum_{j}e^{x_j}}\tag{3}

loss(x,class)=logjexjexclass=xclass+logjexj(3)

α-balanced交叉熵结合表达式

l

o

s

s

(

x

,

c

l

a

s

s

)

=

α

c

l

a

s

s

(

x

c

l

a

s

s

+

log

j

e

x

j

)

(4)

loss(x,class)= \alpha_{class}*(-x_{class} + \log{\sum_{j}e^{x_j}})\tag{4}

loss(x,class)=αclass(xclass+logjexj)(4)

focal loss表达式:

l

o

s

s

(

x

,

c

l

a

s

s

)

=

(

1

e

x

c

l

a

s

s

j

e

x

j

)

γ

log

e

x

c

l

a

s

s

j

e

x

j

=

(

1

e

x

c

l

a

s

s

j

e

x

j

)

γ

(

x

c

l

a

s

s

+

log

j

e

x

j

)

=

(

1

p

t

)

γ

log

(

p

t

)

(5)

loss(x,class) =(1 - \frac{e^{x_{class}}}{\sum_{j}e^{x_j}})^{\gamma} -\log{\frac{e^{x_{class}}}{\sum_{j}e^{x_j}}} =(1 - \frac{e^{x_{class}}}{\sum_{j}e^{x_j}})^{\gamma}(-x_{class} + \log{\sum_{j}e^{x_j}}) = -(1-p_{t})^{\gamma} \log{(p_{t})}\tag{5}

loss(x,class)=(1jexjexclass)γlogjexjexclass=(1jexjexclass)γ(xclass+logjexj)=(1pt)γlog(pt)(5)

带有alpha平衡参数的focal loss表达式:

l

o

s

s

(

x

,

c

l

a

s

s

)

=

α

t

(

1

p

t

)

γ

log

(

p

t

)

(6)

loss(x,class) = -\alpha_{t}(1-p_{t})^{\gamma} \log{(p_{t})}\tag{6}

loss(x,class)=αt(1pt)γlog(pt)(6)

将CrossEntropyLoss改成Focal Loss

log

p

t

=

n

n

.

C

r

o

s

s

E

n

t

r

o

p

y

L

o

s

s

(

i

n

p

u

t

,

t

a

r

g

e

t

)

(7)

-\log{p_{t}} = nn.CrossEntropyLoss(input, target)\tag{7}

logpt=nn.CrossEntropyLoss(input,target)(7)

那么:

p

t

=

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
流程图的转折箭头怎么连接(流程图拖拽视觉编程–概述)