首页IT科技deepmodling(DeepLabV3+:Mobilenetv2的改进以及浅层特征和深层特征的融合)

deepmodling(DeepLabV3+:Mobilenetv2的改进以及浅层特征和深层特征的融合)

时间2025-06-15 14:28:14分类IT科技浏览3971
导读:目录...

目录

Mobilenetv2的改进

浅层特征和深层特征的融合

完整代码

参考资料

Mobilenetv2的改进

在DeeplabV3当中            ,一般不会5次下采样                  ,可选的有3次下采样和4次下采样            。因为要进行五次下采样的话会损失较多的信息                  。

在这里mobilenetv2会从之前写好的模块中得到      ,但注意的是            ,我们在这里获得的特征是[-1]                  ,也就是最后的1x1卷积不取      ,只取循环完后的模型      。

down_idx是InvertedResidual进行的次数      。

# t, c, n, s [1, 16, 1, 1],  [6, 24, 2, 2],    2 [6, 32, 3, 2],    4 [6, 64, 4, 2],    7   [6, 96, 3, 1], [6, 160, 3, 2],   14 [6, 320, 1, 1], 

根据下采样的不同      ,当downsample_factor=8时                  ,进行3次下采样            ,对倒数两次      ,步长为2的InvertedResidual进行参数的修改                  ,让步长变为1            ,膨胀系数为2                  。

当downsample_factor=16时,进行4次下采样                  ,只需对最后一次进行参数的修改            。

import torch import torch.nn as nn import torch.nn.functional as F from functools import partial from net.mobilenetv2 import mobilenetv2 from net.ASPP import ASPP class MobileNetV2(nn.Module): def __init__(self, downsample_factor=8, pretrained=True): super(MobileNetV2, self).__init__() model = mobilenetv2(pretrained) self.features = model.features[:-1] self.total_idx = len(self.features) self.down_idx = [2, 4, 7, 14] if downsample_factor == 8: for i in range(self.down_idx[-2], self.down_idx[-1]): self.features[i].apply( partial(self._nostride_dilate, dilate=2) ) for i in range(self.down_idx[-1], self.total_idx): self.features[i].apply( partial(self._nostride_dilate, dilate=4) ) elif downsample_factor == 16: for i in range(self.down_idx[-1], self.total_idx): self.features[i].apply( partial(self._nostride_dilate, dilate=2) ) def _nostride_dilate(self, m, dilate): classname = m.__class__.__name__ if classname.find(Conv) != -1: if m.stride == (2, 2): m.stride = (1, 1) if m.kernel_size == (3, 3): m.dilation = (dilate//2, dilate//2) m.padding = (dilate//2, dilate//2) else: if m.kernel_size == (3, 3): m.dilation = (dilate, dilate) m.padding = (dilate, dilate) def forward(self, x): low_level_features = self.features[:4](x) x = self.features[4:](low_level_features) return low_level_features, x

forward当中                  ,会输出两个特征层,一个是浅层特征层            ,具有浅层的语义信息;另一个是深层特征层                  ,具有深层的语义信息      。

浅层特征和深层特征的融合

 具有高语义信息的部分先进行上采样      ,低语义信息的特征层进行1x1卷积            ,二者进行特征融合                  ,再进行3x3卷积进行特征提取

self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16//downsample_factor)

这一步就是获得那个绿色的特征层;

low_level_features = self.shortcut_conv(low_level_features)

从这里将是对浅层特征的初步处理(1x1卷积);

x = F.interpolate(x, size=(low_level_features.size(2), low_level_features.size(3)), mode=bilinear, align_corners=True) x = self.cat_conv(torch.cat((x, low_level_features), dim=1))

上采样后进行特征融合      ,这样我们输入和输出的大小才相同      ,每一个像素点才能进行预测;

完整代码

# deeplabv3plus.py import torch import torch.nn as nn import torch.nn.functional as F from functools import partial from net.xception import xception from net.mobilenetv2 import mobilenetv2 from net.ASPP import ASPP class MobileNetV2(nn.Module): def __init__(self, downsample_factor=8, pretrained=True): super(MobileNetV2, self).__init__() model = mobilenetv2(pretrained) self.features = model.features[:-1] self.total_idx = len(self.features) self.down_idx = [2, 4, 7, 14] if downsample_factor == 8: for i in range(self.down_idx[-2], self.down_idx[-1]): self.features[i].apply( partial(self._nostride_dilate, dilate=2) ) for i in range(self.down_idx[-1], self.total_idx): self.features[i].apply( partial(self._nostride_dilate, dilate=4) ) elif downsample_factor == 16: for i in range(self.down_idx[-1], self.total_idx): self.features[i].apply( partial(self._nostride_dilate, dilate=2) ) def _nostride_dilate(self, m, dilate): classname = m.__class__.__name__ if classname.find(Conv) != -1: if m.stride == (2, 2): m.stride = (1, 1) if m.kernel_size == (3, 3): m.dilation = (dilate//2, dilate//2) m.padding = (dilate//2, dilate//2) else: if m.kernel_size == (3, 3): m.dilation = (dilate, dilate) m.padding = (dilate, dilate) def forward(self, x): low_level_features = self.features[:4](x) x = self.features[4:](low_level_features) return low_level_features, x class DeepLab(nn.Module): def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16): super(DeepLab, self).__init__() if backbone=="xception": # 获得两个特征层:浅层特征 主干部分 self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained) in_channels = 2048 low_level_channels = 256 elif backbone=="mobilenet": # 获得两个特征层:浅层特征 主干部分 self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained) in_channels = 320 low_level_channels = 24 else: raise ValueError(Unsupported backbone - `{}`, Use mobilenet, xception..format(backbone)) # ASPP特征提取模块 # 利用不同膨胀率的膨胀卷积进行特征提取 self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16//downsample_factor) # 浅层特征边 self.shortcut_conv = nn.Sequential( nn.Conv2d(low_level_channels, 48, 1), nn.BatchNorm2d(48), nn.ReLU(inplace=True) ) self.cat_conv = nn.Sequential( nn.Conv2d(48+256, 256, kernel_size=(3,3), stride=(1,1), padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.5), nn.Conv2d(256, 256, kernel_size=(3,3), stride=(1,1), padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.1), ) self.cls_conv = nn.Conv2d(256, num_classes, kernel_size=(1,1), stride=(1,1)) def forward(self, x): H, W = x.size(2), x.size(3) # 获得两个特征层                  ,low_level_features: 浅层特征-进行卷积处理 # x : 主干部分-利用ASPP结构进行加强特征提取 low_level_features, x = self.backbone(x) x = self.aspp(x) low_level_features = self.shortcut_conv(low_level_features) # 将加强特征边上采样            ,与浅层特征堆叠后利用卷积进行特征提取 x = F.interpolate(x, size=(low_level_features.size(2), low_level_features.size(3)), mode=bilinear, align_corners=True) x = self.cat_conv(torch.cat((x, low_level_features), dim=1)) x = self.cls_conv(x) x = F.interpolate(x, size=(H, W), mode=bilinear, align_corners=True) return x

参考资料

DeepLabV3-/论文精选 at main · Auorui/DeepLabV3- (github.com)

(6条消息) 憨批的语义分割重制版9——Pytorch 搭建自己的DeeplabV3+语义分割平台_Bubbliiiing的博客-CSDN博客

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
热力图在哪里看(YOLOv7、YOLOv5改进之打印热力图可视化:适用于自定义模型,丰富实验数据) 自动生成文案的软件有哪些(自动生成文案的软件有哪些)