
How to Use YOLO for Object Detection (YOLOv7 Improvement: Adding the CBAM Attention Mechanism)

Published: 2025-06-20 14:13:31 | Category: IT Tech

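For a project, I tried adding the CBAM attention mechanism to YOLOv7 to see whether it would improve performance. I had already added CBAM to YOLOv5 before, so I simply ported the YOLOv5 CBAM code over. Enough talk, let's go straight to the code!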

CBAM attention mechanism

First, a brief introduction to the CBAM attention mechanism:

Paper: https://arxiv.org/pdf/1807.06521.pdf

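The Convolutional Block Attention Module (CBAM) consists of two submodules: a channel attention module (CAM) and a spatial attention module (SAM). CAM reweights the feature channels so that the network focuses on the image foreground, i.e. the meaningful ground-truth regions, while SAM reweights spatial positions so that the network attends to the locations in the whole image that carry rich contextual information. Both modules are plug-and-play, and it is recommended to apply them in series, channel first and then spatial (in the paper, the serial arrangement outperforms the parallel one; on my own dataset the difference between serial and parallel was negligible. In my view this kind of feature fusion has no strict rules and depends on the dataset, so wire it whichever way works best). The code below uses the serial arrangement.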
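For reference, the CBAM paper defines the two attention maps for an input feature map F ∈ ℝ^{C×H×W} roughly as follows. Here σ is the sigmoid, W₀ and W₁ are the weights of the shared MLP (ReLU after W₀, reduction ratio r, which is `ratio=16` in `ChannelAttention`), F^c_avg and F^c_max are the spatially average- and max-pooled descriptors, F^s_avg and F^s_max are the average and max taken along the channel axis, f^{7×7} is a 7×7 convolution, and ⊗ is element-wise multiplication with broadcasting. The last line is the serial, channel-first order that the `CBAM` class implements.

```latex
\begin{aligned}
M_c(F) &= \sigma\big(W_1(W_0(F^c_{avg})) + W_1(W_0(F^c_{max}))\big) \in \mathbb{R}^{C \times 1 \times 1} \\
M_s(F) &= \sigma\big(f^{7 \times 7}([F^s_{avg};\, F^s_{max}])\big) \in \mathbb{R}^{1 \times H \times W} \\
F' &= M_c(F) \otimes F, \qquad F'' = M_s(F') \otimes F'
\end{aligned}
```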

Code

Add the CBAM module to common.py

This code is taken directly from the YOLOv5 version and can be used as-is.

```python
import torch
import torch.nn as nn
# autopad() is already defined in models/common.py


class ChannelAttention(nn.Module):
    # Channel attention: a shared 1x1-conv MLP applied to avg- and max-pooled descriptors
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)  # (N, C, 1, 1) channel weights


class SpatialAttention(nn.Module):
    # Spatial attention: a conv over the channel-wise mean and max maps
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)  # (N, 1, H, W) spatial weights


class CBAM(nn.Module):
    # Conv + BN + activation, then channel attention followed by spatial attention (serial)
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(CBAM, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        # act=True -> Hardswish; an nn.Module from the cfg (e.g. nn.LeakyReLU(0.1)) is used as-is;
        # anything else -> Identity
        self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
        self.ca = ChannelAttention(c2)
        self.sa = SpatialAttention()

    def forward(self, x):
        x = self.act(self.bn(self.conv(x)))
        x = self.ca(x) * x
        x = self.sa(x) * x
        return x

    def fuseforward(self, x):
        # Forward pass after conv/BN fusion; the attention is still applied
        x = self.act(self.conv(x))
        x = self.ca(x) * x
        return self.sa(x) * x
```
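To get a feel for what the attention adds, the sketch below counts the learnable parameters of the CBAM block against a plain Conv + BN layer with the same settings. The helper names are hypothetical; the counts assume bias-free convolutions, an affine BatchNorm2d (weight + bias; running statistics are buffers, not parameters), `ratio=16` and a 7×7 spatial kernel, i.e. the defaults in the code above.

```python
def conv_bn_params(c1: int, c2: int, k: int, g: int = 1) -> int:
    """Conv2d(bias=False) weights plus BatchNorm2d weight and bias."""
    return k * k * (c1 // g) * c2 + 2 * c2


def cbam_params(c1: int, c2: int, k: int, g: int = 1,
                ratio: int = 16, sa_kernel: int = 7) -> int:
    """Conv + BN + ChannelAttention + SpatialAttention (all convs bias-free)."""
    hidden = c2 // ratio
    channel_att = c2 * hidden + hidden * c2      # f1 and f2 (1x1 convs)
    spatial_att = 2 * 1 * sa_kernel * sa_kernel  # 2 -> 1 channel conv
    return conv_bn_params(c1, c2, k, g) + channel_att + spatial_att


if __name__ == "__main__":
    # Layers 0 and 1 of the yolov7-tiny backbone after replacing Conv with CBAM
    for c1, c2, k in [(3, 32, 3), (32, 64, 3)]:
        base = conv_bn_params(c1, c2, k)
        extra = cbam_params(c1, c2, k) - base
        print(f"CBAM({c1}, {c2}, {k}): conv+bn={base}, attention overhead={extra}")
```

For these two layers the attention overhead works out to 226 and 610 parameters respectively, small compared with the 3×3 convolution itself (928 and 18560 parameters including BN).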

Register the CBAM module name in yolo.py

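Go to line 459 of yolo.py (inside parse_model; the exact line number may differ between versions of the repo) and add CBAM to the module list. yolo.py already imports models.common with `from models.common import *`, so no extra import is needed.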

```python
if m in [nn.Conv2d, Conv, RobustConv, RobustConv2, DWConv, GhostConv, RepConv, RepConv_OREPA, DownC,
         SPP, SPPF, SPPCSPC, GhostSPPCSPC, MixConv2d, Focus, Stem, GhostStem, CrossConv,
         Bottleneck, BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
         RepBottleneck, RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
         Res, ResCSPA, ResCSPB, ResCSPC,
         RepRes, RepResCSPA, RepResCSPB, RepResCSPC,
         ResX, ResXCSPA, ResXCSPB, ResXCSPC,
         RepResX, RepResXCSPA, RepResXCSPB, RepResXCSPC,
         Ghost, GhostCSPA, GhostCSPB, GhostCSPC,
         SwinTransformerBlock, STCSPA, STCSPB, STCSPC,
         SwinTransformer2Block, ST2CSPA, ST2CSPB, ST2CSPC,
         CBAM]:  # CBAM added here
    c1, c2 = ch[f], args[0]
    if c2 != no:  # if not output
        c2 = make_divisible(c2 * gw, 8)

    args = [c1, c2, *args[1:]]
    if m in [DownC, SPPCSPC, GhostSPPCSPC, BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
             RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
             ResCSPA, ResCSPB, ResCSPC, RepResCSPA, RepResCSPB, RepResCSPC,
             ResXCSPA, ResXCSPB, ResXCSPC, RepResXCSPA, RepResXCSPB, RepResXCSPC,
             GhostCSPA, GhostCSPB, GhostCSPC, STCSPA, STCSPB, STCSPC,
             ST2CSPA, ST2CSPB, ST2CSPC]:
        args.insert(2, n)  # number of repeats
        n = 1
```
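To see what this branch does with a CBAM entry, here is a standalone sketch of its channel bookkeeping. `build_args` is a hypothetical helper written for illustration; `make_divisible` mirrors the rounding helper in YOLOv7's utils/general.py, and `no` defaults to 255 on the assumption of 3 anchors and 80 classes (3 × (80 + 5)). `ch` (output channels of the layers built so far) and `gw` (width multiple) play the same roles as in parse_model.

```python
import math


def make_divisible(x: float, divisor: int) -> int:
    # Round x up to the nearest multiple of divisor
    return math.ceil(x / divisor) * divisor


def build_args(ch: list, f: int, args: list, gw: float, no: int = 255) -> list:
    """Hypothetical helper mimicking the branch above for a CBAM/Conv cfg entry."""
    c1, c2 = ch[f], args[0]  # input channels come from the 'from' layer
    if c2 != no:  # the detection output layer keeps its exact channel count
        c2 = make_divisible(c2 * gw, 8)
    return [c1, c2, *args[1:]]  # prepend c1, replace c2, keep k, s, p, g, act


# Layer 0 of the cfg: [-1, 1, CBAM, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]]
# ("act" stands in for the evaluated activation module)
print(build_args([3], -1, [32, 3, 2, None, 1, "act"], gw=1.0))
print(build_args([3], -1, [32, 3, 2, None, 1, "act"], gw=0.5))
```

If CBAM is left out of the list, the entry falls through to a branch that does not prepend `c1`, so `CBAM(*args)` would receive the output channel count as `c1` and every later argument shifted by one position.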

Add CBAM to the cfg file

Taking the backbone as an example (here the yolov7-tiny cfg), simply replace Conv with CBAM; the same replacement can also be made in the FPN.

```yaml
# parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple

# anchors
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

backbone:
  # [from, number, module, args] c2, k=1, s=1, p=None, g=1, act=True
  # [[-1, 1, Conv, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 0-P1/2
  [[-1, 1, CBAM, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 0-P1/2
  # [-1, 1, Conv, [64, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 1-P2/4
   [-1, 1, CBAM, [64, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 1-P2/4

   [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 7

   [-1, 1, MP, []],  # 8-P3/8
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 14

   [-1, 1, MP, []],  # 15-P4/16
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 21

   [-1, 1, MP, []],  # 22-P5/32
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 28
  ]

# yolov7-tiny head
head:
  [[-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, SP, [5]],
   [-2, 1, SP, [9]],
   [-3, 1, SP, [13]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -7], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 37

   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [21, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # route backbone P4
   [[-1, -2], 1, Concat, [1]],

   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 47

   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [14, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # route backbone P3
   [[-1, -2], 1, Concat, [1]],

   [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 57

   [-1, 1, Conv, [128, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, 47], 1, Concat, [1]],

   [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 65

   [-1, 1, Conv, [256, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, 37], 1, Concat, [1]],

   [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [[-1, -2, -3, -4], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 73

   [57, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [65, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
   [73, 1, Conv, [512, 3, 1, None, 1, nn.LeakyReLU(0.1)]],

   [[74,75,76], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
```
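When editing a cfg like this, it helps to resolve each `from` field to absolute layer indices. The sketch below assumes the convention used by YOLOv7's parse_model and forward pass: a negative value -k at layer i refers to layer i - k, and a non-negative value is already an absolute index. `resolve_from` is a hypothetical helper, and the example entries are copied from the head above with their layer numbers counted by hand.

```python
from typing import List, Union


def resolve_from(layer_idx: int, f: Union[int, List[int]]) -> List[int]:
    """Turn a cfg 'from' field into absolute layer indices (assumed convention)."""
    fs = f if isinstance(f, list) else [f]
    return [x if x >= 0 else layer_idx + x for x in fs]


# A few head entries of the cfg above, keyed by their layer index
examples = {
    36: [-1, -7],      # [[-1, -7], 1, Concat, [1]]
    40: 21,            # [21, 1, Conv, ...]  route backbone P4
    59: [-1, 47],      # [[-1, 47], 1, Concat, [1]]
    77: [74, 75, 76],  # Detect(P3, P4, P5)
}
for i, f in examples.items():
    print(i, resolve_from(i, f))
```

Layer 36, for example, concatenates layer 35 with layer 29, the first Conv of the head. Because CBAM replaces layers 0 and 1 one-for-one, the layer count is unchanged and none of these indices move; adding CBAM as an extra layer instead would shift every absolute index that points past it (14, 21, 37, 47, 57, 65, 73 and the Detect inputs 74, 75, 76) by one.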