首页IT科技狗能看懂什么电视(狗都能看懂的CenterNet讲解及代码复现)

狗能看懂什么电视(狗都能看懂的CenterNet讲解及代码复现)

时间2025-10-21 17:00:11分类IT科技浏览6537
导读:论文: https://arxiv.org/abs/1904.07850...

论文: https://arxiv.org/abs/1904.07850

前言

之前博文介绍的目标检测算法如:Yolo系列                    、Faster RCNN等                 ,这些基于先验框去进行预测的目标框的算法                            ,我们称为anchor-base                    。即使是anchor-base中的one-stage算法          ,因为其复杂后处理             ,也常常被人诟病不是真正的end2end算法                           。在目标检测领域中                           ,还有另一种不用基于先验框的模式               ,我们称之为anchor-free        。

anchor-free的定义就很简单了         ,输入一张图片                          ,输出则是一堆检测框的坐标                           、类别和置信度                    ,实现了真正意义上的端到端               。那这篇文章就来介绍一下比较有名的Objects as Points(CenterNet)

网络结构

网络主要分成三个部分Backbone        、DecoderHead

Backbone

在论文中只提了一下     ,是Hourglass                          ,没有详细介绍                            。我个人复现采用的是resnet50                         ,各位读者有兴趣也可以自己替换下            。有关于ResNet的介绍在之前的博客已经讲解了,还没看的同学可以点这里          。Backbone部分我们只取其中最后一个feature map                     ,resnet50经过5次下采样后                             ,最后一个feature map的宽高维度为为原来的1/32     ,通道维度为2048                             。

Decoder

Decoder中采用UpSample + BN + Activation作为一个block                 ,以此堆叠三次作为一个Decoder                。其中CenterNet的UpSample为反卷积                            ,激活函数为ReLU     。需要注意的是          ,三个反卷积的核大小都为4x4             ,卷积核的数目分别为256                           ,128               ,64                              。那么经过Decoder之后         ,feature map的宽高维度则变为原来1/4(比较重要                          ,后面会反复用到)                    ,通道维度为64                     。

对应的代码是:

class CenterNetDecoder(nn.Module): def __init__(self, in_channels, bn_momentum=0.1): super(CenterNetDecoder, self).__init__() self.bn_momentum = bn_momentum self.in_channels = in_channels self.deconv_with_bias = False # h/32, w/32, 2048 -> h/16, w/16, 256 -> h/8, w/8, 128 -> h/4, w/4, 64 self.deconv_layers = self._make_deconv_layer( num_layers=3, num_filters=[256, 128, 64], num_kernels=[4, 4, 4], ) def _make_deconv_layer(self, num_layers, num_filters, num_kernels): layers = [] for i in range(num_layers): kernel = num_kernels[i] num_filter = num_filters[i] layers.append( nn.ConvTranspose2d( in_channels=self.in_channels, out_channels=num_filter, kernel_size=kernel, stride=2, padding=1, output_padding=0, bias=self.deconv_with_bias)) layers.append(nn.BatchNorm2d(num_filter, momentum=self.bn_momentum)) layers.append(nn.ReLU(inplace=True)) self.in_channels = num_filter return nn.Sequential(*layers) def forward(self, x): return self.deconv_layers(x)

Head

CenterNet的Head部分是值得我们说道一下的     ,分成三个组件HeatMap               、WidthHeight以及Offset。三个组件都需要经过64维的Conv + BN + ReLU                          ,然后分别用对应的卷积层输出                         。每个组件的输出都是一个feature map                         ,Head部分是不会改变feature map的尺寸的,所以feature map宽高维度还是输入的1/4                          。物体的中心落在了feature map中那个格点                     ,这个格点就负责存储预测信息    。

HeatMap的最后一个卷积层通道维度为分类数量                             ,卷积核大小为1x1     ,最后需要用sigmoid激活函数处理一下                    。其输出的形式和解码类似于语义分割                           。在物体的中心                 ,它的响应很强                            ,接近于1          ,在背景部分为0        。我们解码的时候             ,在通道维度上进行Argmax                           ,即可得到最终的分类index               。 WidthHeight对应的是检测框宽高               ,因为宽高信息为2个         ,所以其最后一层卷积通道输出维度为2                          ,卷积核大小为1x1                            。 Offset                    ,由于HeatMap的到的响应是基于物体中心的     ,而且相当于输入来说是下采样四倍的                          ,从HeatMap中的到的物体中心是有一点误差的            。所以需要用Offset的结果对物体中心点进行修正                         ,如下图所示,其最后一层的卷积通道维度为2                     ,卷积核大小为1x1          。

Head的对应的代码是:

class CenterNetHead(nn.Module): def __init__(self, num_classes=80, channel=64, bn_momentum=0.1): super(CenterNetHead, self).__init__() # heatmap self.cls_head = nn.Sequential( nn.Conv2d(64, channel, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(64, momentum=bn_momentum), nn.ReLU(inplace=True), nn.Conv2d(channel, num_classes, kernel_size=1, stride=1, padding=0), nn.Sigmoid() ) # bounding boxes height and width self.wh_head = nn.Sequential( nn.Conv2d(64, channel, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(64, momentum=bn_momentum), nn.ReLU(inplace=True), nn.Conv2d(channel, 2, kernel_size=1, stride=1, padding=0)) # center point offset self.offset_head = nn.Sequential( nn.Conv2d(64, channel, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(64, momentum=bn_momentum), nn.ReLU(inplace=True), nn.Conv2d(channel, 2, kernel_size=1, stride=1, padding=0)) def forward(self, x): hm = self.cls_head(x) wh = self.wh_head(x) offset = self.offset_head(x) return hm, wh, offset

数据读取

CenterNet的数据读取比较简单                             。首先                             ,无论是预测还是训练都需要做的就是:resize输入图像     ,最常见的是保持图像宽高比                 ,然后将短边不足的部分进行pad                。至于Ground Truth的坐标信息                            ,只需要将它转换为HeatMap                            、WidthHeight和Offset三个组件即可     。三个组件也是由原图宽高的1/4(对应Decoder部分)大小的feature map进行存储                              。

HeatMap顾名思义就是热力图          ,也即物体中心的响应为1             ,其余地方为0                     。但在训练时                           ,这么做是会让整个数据变得非常稀疏               ,正负样本严重不平衡。所以输入时         ,我们会用高斯函数处理成物体中心为1                          ,其余部分数值慢慢递减                    ,呈正态分布     ,如下图所示                         。这么做的好处可以使得输出更平滑                          ,容易在卷积结构中建模                          。至于高斯核的超参数等如何设置这里就不一一阐述了                         ,有兴趣的读者可以自行查阅    。

WidthHeight则是在对应位置上存储检测框的宽和高,注意这里的宽高也是指检测框原始尺寸的1/4                    。

Offset同理                     ,只是它存储的是偏移量                             ,那么这个偏移量的定义是中心点坐标的小数部分                           。也即

o

f

f

s

e

t

x

=

i

n

t

(

x

1

+

x

2

/

2

)

(

x

1

+

x

2

/

2

)

o

f

f

s

e

t

y

=

i

n

t

(

y

1

+

y

2

/

2

)

(

y

1

+

y

2

/

2

)

offset_x = int(x1 + x2 / 2) - (x1 + x2 / 2)\\ offset_y = int(y1 + y2 / 2) - (y1 + y2 / 2)

offsetx=int(x1+x2/2)(x1+x2/2)offsety=int(y1+y2/2)(y1+y2/2) 在代码实现中还会返回一个mask在loss中进行计算     ,这个mask在物体中心为1                 ,其余地方为0                            ,目的是为了只计算物体中心的WidthHeight和Offset loss          ,其余不是物体中心的预测就不计算loss了        。 def __getitem__(self, index): batch_hm = np.zeros((self.output_shape[0], self.output_shape[1], self.num_classes), dtype=np.float32) batch_wh = np.zeros((self.output_shape[0], self.output_shape[1], 2), dtype=np.float32) batch_offset = np.zeros((self.output_shape[0], self.output_shape[1], 2), dtype=np.float32) batch_offset_mask = np.zeros((self.output_shape[0], self.output_shape[1]), dtype=np.float32) # Read image and bounding boxes image, bboxes = self.parse_annotation(index) if self.is_train: image, bboxes = self.data_augmentation(image, bboxes) # Image preprocess image, bboxes = image_resize(image, self.input_shape, bboxes) image = preprocess_input(image) # Clip bounding boxes clip_bboxes = [] labels = [] for bbox in bboxes: x1, y1, x2, y2, label = bbox if x2 <= x1 or y2 <= y1: # Dont use such boxes as this may cause nan loss. continue x1 = int(np.clip(x1, 0, self.input_shape[1])) y1 = int(np.clip(y1, 0, self.input_shape[0])) x2 = int(np.clip(x2, 0, self.input_shape[1])) y2 = int(np.clip(y2, 0, self.input_shape[0])) # Clipping coordinates between 0 to image dimensions as negative values # or values greater than image dimensions may cause nan loss. clip_bboxes.append([x1, y1, x2, y2]) labels.append(label) bboxes = np.array(clip_bboxes) labels = np.array(labels) if len(bboxes) != 0: labels = np.array(labels, dtype=np.float32) bboxes = np.array(bboxes[:, :4], dtype=np.float32) bboxes[:, [0, 2]] = np.clip(bboxes[:, [0, 2]] / self.stride, a_min=0, a_max=self.output_shape[1]) bboxes[:, [1, 3]] = np.clip(bboxes[:, [1, 3]] / self.stride, a_min=0, a_max=self.output_shape[0]) for i in range(len(labels)): x1, y1, x2, y2 = bboxes[i] cls_id = int(labels[i]) h, w = y2 - y1, x2 - x1 if h > 0 and w > 0: radius = gaussian_radius((math.ceil(h), math.ceil(w))) radius = max(0, int(radius)) # Calculates the feature points of the real box ct = np.array([(x1 + x2) / 2, (y1 + y2) / 2], dtype=np.float32) ct_int = ct.astype(np.int32) # Get gaussian heat map batch_hm[:, :, cls_id] = draw_gaussian(batch_hm[:, :, cls_id], ct_int, radius) # Assign ground truth height and width batch_wh[ct_int[1], ct_int[0]] = 1. * w, 1. * h # Assign center point offset batch_offset[ct_int[1], ct_int[0]] = ct - ct_int # Set the corresponding mask to 1 batch_offset_mask[ct_int[1], ct_int[0]] = 1 return image, batch_hm, batch_wh, batch_offset, batch_offset_mask

Loss计算

Loss由三部分组成             ,分别使用交叉熵+focal loss的HeatMap损失                           ,论文中提到

α

\alpha

α设置为2               ,

β

\beta

β
设置为4

原作者的代码是没有对pred输出做限制         ,我在实际训练中如果不加以限制                          ,则会导致pred经过log计算之后的输出为NaN或Inf                    ,所以使用torch.clamp()进行截取     ,相关代码如下:

def focal_loss(pred, target): """ classifier loss of focal loss Args: pred: heatmap of prediction target: heatmap of ground truth Returns: cls loss """ # Find every image positive points and negative points, # one bounding box corresponds to one positive point, # except positive points, other feature points are negative sample. pos_inds = target.eq(1).float() neg_inds = target.lt(1).float() # The negative samples near the positive sample feature point have smaller weights neg_weights = torch.pow(1 - target, 4) loss = 0 pred = torch.clamp(pred, 1e-6, 1 - 1e-6) # Calculate Focal Loss. # The hard to classify sample weight is large, easy to classify sample weight is small. pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_inds * neg_weights # Loss normalization is carried out num_pos = pos_inds.float().sum() pos_loss = pos_loss.sum() neg_loss = neg_loss.sum() if num_pos == 0: loss = loss - neg_loss else: loss = loss - (pos_loss + neg_loss) / num_pos return loss

WidthHeight和Offset的损失由l1 loss计算                          ,原理比较简单                         ,代码注释中有详细说明,这里就不做阐述了               。

def l1_loss(pred, target, mask): """ Calculate l1 loss Args: pred: offset detection result target: offset ground truth mask: offset mask, only center point is 1, other place is 0 Returns: l1 loss """ expand_mask = torch.unsqueeze(mask, -1).repeat(1, 1, 1, 2) # Dont calculate loss in the position without ground truth. loss = F.l1_loss(pred * expand_mask, target * expand_mask, reduction=sum) loss = loss / (mask.sum() + 1e-7) return loss

为了防止宽高部分前几个epoch的误差较大                     ,影响了总的loss                             ,所以使用0.1的系数缩放wh_loss     ,公式如下:

t

o

t

a

l

_

l

o

s

s

=

c

_

l

o

s

s

+

0.1

w

h

_

l

o

s

s

+

o

f

f

_

l

o

s

s

total\_loss = c\_loss + 0.1 * wh\_loss + off\_loss

total_loss=c_loss+0.1wh_loss+off_loss

模型预测

模型预测相对训练会多两个步骤:1            、预测结果后处理 2          、预测框转换                            。

预测结果后处理

这个很好理解                 ,模型得到的是HeatMap                             、WidthHeight和Offset                            ,我们需要用将三个结果进行运算才能得出最终的预测框            。

首先          ,我们对HeatMap的通道做Argmax和max处理             ,得出分类的index和最高得分          。根据得分置信度过滤掉低于阈值的物体中心(此时的过滤完的结果已经带有分类信息和物体中心位置的坐标了)                             。

第二步                           ,将Offset的偏移量加到HeatMap中的物体中心坐标上               ,进行修正                。

第三步         ,根据上面HeatMap的过滤结果                          ,对置信度高于阈值的WidthHeight进行转换                    ,xyhw -> x1y1x2y2     ,就得到预测框了     。

最后将预测框结果进行归一化                          ,方便后面预测框转换计算                              。

预测框转换

虽然论文作者一直强调自己这个模型是一个完全端到端的设计                         ,不需要nms等后处理操作                     。只需要一个3x3的max_pooling层就可以替代nms。但是实际使用中,无论模型的预测结果还是训练数据                     ,都在结果转换后进行nms                         。

这里简单讲一下原因                          。以这张图作为例子                             ,里面有一只狗和凳子是检测的目标    。凳子和狗的原始HeatMap是以下两张图     ,我们可以看到                 ,中心区域响应最强                            ,周围慢慢衰减至0                    。

经过3x3的max_pooling之后          ,确实消除了一些低响应区域             ,但由于3x3的核太小                           ,只进行一次池化操作               ,无法消除所有底响应区域         ,结果如下图所示                           。这样的结果是不可用的                          ,画到原图上之后                    ,物体会有多个中心     ,且框的宽高都是0        。

如果硬是要用池化层进行过滤                          ,只有两个办法1                、加大卷积核尺寸 2     、增加池化次数               。这两个方法都会增加计算量                         ,而且对于每张图来说,尺寸设置多大?池化次数增加多少次?这都不一样                     ,没办法用一个统一的值来确定                            。所以最方便的方法还是用nms进行后处理            。

在进行完nms之后                             ,我们将预测框的坐标尺度从0-1变为原图大小     ,最后将之前图像resize和pad部分给去掉就得到最后的检测框了          。

上述代码为:

def predict(image, model, dev, args): """ Predict one image Args: image: input image model: CenterNet model dev: torch device args: ArgumentParser Returns: bounding box of one image(x1, y1, x2, y2 score, label). """ input_data = image_resize(image, (args.input_height, args.input_height)) input_data = preprocess_input(input_data) input_data = np.expand_dims(input_data, 0) input_data = torch.from_numpy(input_data.copy()).float() input_data = input_data.to(dev) hms, whs, offsets = model(input_data) hms = hms.permute(0, 2, 3, 1) whs = whs.permute(0, 2, 3, 1) offsets = offsets.permute(0, 2, 3, 1) outputs = postprocess_output(hms, whs, offsets, args.confidence, dev) outputs = decode_bbox(outputs, (args.input_height, args.input_height), dev, image_shape=image.shape[:2], remove_pad=True, need_nms=True, nms_thres=0.45) return outputs[0]

训练

代码中实现了两种训练方式                 ,从头开始                            ,和迁移学习+fine tune                             。推荐使用后者          ,会有更好的效果                。github上提供了一些自动化脚本             ,方便初学者更好上手     。

tensorboard

训练过程中可以用tensorboard来观察训练情况                              。内置有训练的loss和learning rate曲线                     。

在Images里能查看到模型的实时预测情况                           ,左图为Ground Truth               ,右图为Prediction。

!

开启方法为:

tensorboard --logdir="./logs/exp/"

可能会出现的现象

随着训练的进行         ,会出现train loss持续下降                          ,val loss先下降后上升的情况                         。这是CenterNet独有的假过拟合现象                          。这是由于网络对非物体中心的HeatMap预测趋近于0                    ,从而和Ground Truth不一致    。

Epoch 5 ->val loss = 0.2789 -> peak conf = 0.4273

Epoch 20 -> loss = 8.3402 -> peak conf = 0.9791

总结

CenterNet是anchor free中的一个里程碑之作                    。CenterNet除了目标检测之外     ,还可以迁移到其他领域中                          ,如人体关键点                         ,姿态预测等                           。推荐大家先读一下原文        。

本人用torch复现的代码在这里               。

部分图引用源为:睿智的目标检测46——Pytorch搭建自己的Centernet目标检测平台                              、从零开始理解CenterNet中的Heatmap热图

声明:本站所有文章,如无特殊说明或标注                     ,均为本站原创发布                            。任何个人或组织                             ,在未征得本站同意时     ,禁止复制                     、盗用、采集                         、发布本站内容到任何网站                          、书籍等各类媒体平台            。如若本站内容侵犯了原著者的合法权益                 ,可联系我们进行处理          。

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
wordpress建站插件(优秀的WordPress自动内链插件-轻松提升网站SEO)