首页IT科技yolov3目标检测步骤流程图(【目标检测】YOLOv5遇上知识蒸馏)

yolov3目标检测步骤流程图(【目标检测】YOLOv5遇上知识蒸馏)

时间2025-08-23 18:18:04分类IT科技浏览5333
导读:前言 模型压缩方法主要4种: 网络剪枝(Network pruning 稀疏表示(Sparse representation 模型量化(Model quantification 知识蒸馏(Konwledge disti...

前言

模型压缩方法主要4种:

网络剪枝(Network pruning) 稀疏表示(Sparse representation) 模型量化(Model quantification) 知识蒸馏(Konwledge distillation)

本文主要来研究知识蒸馏的相关知识               ,并尝试用知识蒸馏的方法对YOLOv5进行改进               。

知识蒸馏理论简介

概述

知识蒸馏(Knowledge Distillation)由深度学习三巨头Hinton在2015年提出                      。

论文标题:Distilling the knowledge in a neural network

论文地址:https://arxiv.org/pdf/1503.02531.pdf

“蒸馏               ”是个化工学科中的术语                      ,本身指的是将液体混合物加热沸腾       ,使其中沸点较低的组分首先变成蒸气               ,再冷凝成液体                      ,用来分离混合物       。而知识蒸馏的含义和蒸馏本身相似但并不完全相同       ,知识蒸馏指的是同时训练两个网络        ,一个较复杂的网络作为教师网络                      ,另一个较简单的网络作为学生网络               ,将教师网络训练得到的结果提炼出来        ,用来引导学生网络的结果                      ,从而让学生网络学习得更好               。

一个公认前提是小模型相比于大模型更容易陷入局部最优               ,下图[1]中,中间绿色的椭圆表示小网络模型的收敛空间                      ,红色的椭圆表示大网络模型的收敛空间;如果不用知识蒸馏                      ,直接训练小网络,它只会在绿色椭圆区域收敛               ,而使用知识蒸馏之后                      ,小网络可以收敛到橙色椭圆区域       ,收敛到更小的最优点                      。

软标签

有了上面的概念               ,自然而然想到的一个问题就是                      ,教师模型如何引导学生模型进行学习       。这就涉及到论文中提及的一个概念——软标签(Soft target)

如上图[1]所示       ,以手写数字识别为例        ,这是一个10分类任务                      ,左边这幅图是采用硬标签(Hard target)               ,输出独热向量        ,概率最高的类别为1                      ,其它类别为0;右边这幅图采用的是软标签(Soft target)               ,通过softmax层输出的各类别概率,这样的输出具有更高的信息熵                      ,即包含更多信息量        。

教师模型输出软标签                      ,从而指导学生模型学习                      。

softmax的原始公式是这样:

q

i

=

exp

(

z

i

)

j

exp

(

z

j

)

q_{i}=\frac{\exp \left(z_{i}\right)}{\sum_{j} \exp \left(z_{j}\right)}

qi=jexp(zj)exp(zi)

在论文中,作者对这个公式又加以改进               ,引入了一个新的温度变量T                      ,公式如下:

q

i

=

exp

(

z

i

/

T

)

j

exp

(

z

j

/

T

)

q_{i}=\frac{\exp \left(z_{i} / T\right)}{\sum_{j} \exp \left(z_{j} / T\right)}

qi=jexp(zj/T)exp(zi/T)

加入这个变量       ,能使各类别之间的输出更均衡               ,如下图[2]所示                      ,T=1为softmax       ,但是当T过大时        ,会发现输出向量会趋于一条直线                      ,因此               ,T通常取中间较小值               。

蒸馏温度

上面引入了一个新的变量温度T        ,这个T也可以称为蒸馏温度                      ,原论文中给出了关于T的进一步讨论               ,随着T的增加,信息熵会越来越大                      ,如下图[1]所示:

实际上                      ,温度的高低改变的是Student模型训练过程中对负标签的关注程度        。当温度较低时,对负标签的关注               ,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时                      ,负标签相关的值会相对增大       ,Student模型会相对更多地关注到负标签[1]                      。

因此               ,T的取值可以遵循如下策略:

当想从负标签中学到一些信息量的时候                      ,温度T应调高一些 当想减少负标签的干扰的时候       ,温度T应调低一些

需要注意的是        ,这个T只作用于教师网络和学生网络的蒸馏过程                      ,学生网络正常输出仍使用softmax               ,即T取值为1        ,就像蒸馏过程一样                      ,需要先进行升温               ,将知识蒸馏出来,然后输出的时候要冷却降温(T=1)

知识蒸馏过程

从原理上来讲                      ,知识蒸馏没有想象中那么复杂                      ,其流程如下图[1]所示:

在T下,训练教师网络得到 soft targets1 在T下               ,训练学生网络得到 soft targets2 通过 soft targets1 和 soft targets2 得到 distillation loss 在温度1下                      ,训练学生网络得到 soft targets3 通过 soft targets3 和 ground truth 得到 student loss

通过这五个步骤       ,就得到了两个损失值 distillation loss 和 student loss               ,那么训练的整体损失                      ,就是这两个损失值的加权和       ,公式[2]如下:

注:

这里的蒸馏损失系数乘了一个

T

2

T^2

T2
这是由于soft targets产生的梯度大小按照

1

/

T

2

1/T^2

1/T2
进行了缩放        ,这里需要补充回来

α

\alpha

α
应远小于

β

\beta

β
即需要让知识蒸馏损失权重大一些                      ,否则没有蒸馏效果

后面               ,论文作者分别做了手写数字识别和声音识别实验        ,这里主要来看作者在MNIST数据集上的实验结果                      ,结果如下表所示:

10xEnsemble是10个教师模型的平均值               ,Distilled Single model是Baseline模型经过蒸馏之后的结果,可以看到蒸馏出来的准确率提升了1.9%.

YOLOv5加上知识蒸馏

下面就将知识蒸馏融入到YOLOv5目标检测任务中                      ,使用的是YOLOv5-6.0版本               。

相关代码参考自:https://github.com/Adlik/yolov5

代码修改

其实知识蒸馏的想法很简单                      ,在仓库作者的代码版本中,修改的内容也并不多               ,主要是模型加载和损失计算部分。

下面按照顺序来解读一下修改内容                      。

首先是train_distillation.py这个文件                      ,通过修改train.py得到                      。

新增四个参数:

parser.add_argument(--t_weights, type=str, default=./weights/yolov5s.pt, help=initial teacher model weights path) parser.add_argument(--t_cfg, type=str, default=models/yolov5s.yaml, help=teacher model.yaml path) parser.add_argument(--d_output, action=store_true, default=False, help=if true, only distill outputs) parser.add_argument(--d_feature, action=store_true, default=False, help=if true, distill both feature and output layers)

t_weights

教师模型权重       ,和学生模型加载类似

t_cfg

教师模型配置               ,和学生模型配置类似

d_output

这个参数写在这里但不起作用                      ,应该是作者调试时用到的参数       ,默认是只蒸馏结果

d_feature

这个参数默认是关闭        ,如果开启                      ,蒸馏损失计算将不仅仅是计算两个模型输出的结果               ,并且中间特征层也会参与计算(不过这个作者没写完整        ,可能写到一半弃坑了)

模型加载:

这部分需要多加载一个教师模型                      ,相关代码如下: # Model check_suffix(weights, .pt) # check weights pretrained = weights.endswith(.pt) if pretrained: with torch_distributed_zero_first(LOCAL_RANK): weights = attempt_download(weights) # download if not found locally ckpt = torch.load(weights, map_location=device) # load checkpoint model = Model(cfg or ckpt[model].yaml, ch=3, nc=nc, anchors=hyp.get(anchors)).to(device) # create exclude = [anchor] if (cfg or hyp.get(anchors)) and not resume else [] # exclude keys csd = ckpt[model].float().state_dict() # checkpoint state_dict as FP32 csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect model.load_state_dict(csd, strict=False) # load LOGGER.info(fTransferred {len(csd)}/{len(model.state_dict())} items from {weights}) # report # 这里添加加载教师模型 # Teacher model LOGGER.info(fLoaded teacher model {t_cfg}) # report t_ckpt = torch.load(t_weights, map_location=device) # load checkpoint t_model = Model(t_cfg or t_ckpt[model].yaml, ch=3, nc=nc, anchors=hyp.get(anchors)).to(device) exclude = [anchor] if (t_cfg or hyp.get(anchors)) and not resume else [] # exclude keys csd = t_ckpt[model].float().state_dict() # checkpoint state_dict as FP32 csd = intersect_dicts(csd, t_model.state_dict(), exclude=exclude) # intersect t_model.load_state_dict(csd, strict=False) # load

损失计算:

这里多了一个d_outputs_loss               ,也就是计算蒸馏损失 s_loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size d_outputs_loss = compute_distillation_output_loss(pred, t_pred, model, d_weight=10) loss = d_outputs_loss + s_loss

蒸馏损失在loss.py中进行定义:

def compute_distillation_output_loss(p, t_p, model, d_weight=1): t_ft = torch.cuda.FloatTensor if t_p[0].is_cuda else torch.Tensor t_lcls, t_lbox, t_lobj = t_ft([0]), t_ft([0]), t_ft([0]) h = model.hyp # hyperparameters red = mean # Loss reduction (sum or mean) if red != "mean": raise NotImplementedError("reduction must be mean in distillation mode!") DboxLoss = nn.MSELoss(reduction="none") DclsLoss = nn.MSELoss(reduction="none") DobjLoss = nn.MSELoss(reduction="none") # per output for i, pi in enumerate(p): # layer index, layer predictions t_pi = t_p[i] t_obj_scale = t_pi[..., 4].sigmoid() # BBox b_obj_scale = t_obj_scale.unsqueeze(-1).repeat(1, 1, 1, 1, 4) t_lbox += torch.mean(DboxLoss(pi[..., :4], t_pi[..., :4]) * b_obj_scale) # Class if model.nc > 1: # cls loss (only if multiple classes) c_obj_scale = t_obj_scale.unsqueeze(-1).repeat(1, 1, 1, 1, model.nc) # t_lcls += torch.mean(c_obj_scale * (pi[..., 5:] - t_pi[..., 5:]) ** 2) t_lcls += torch.mean(DclsLoss(pi[..., 5:], t_pi[..., 5:]) * c_obj_scale) # t_lobj += torch.mean(t_obj_scale * (pi[..., 4] - t_pi[..., 4]) ** 2) t_lobj += torch.mean(DobjLoss(pi[..., 4], t_pi[..., 4]) * t_obj_scale) t_lbox *= h[box] t_lobj *= h[obj] t_lcls *= h[cls] # bs = p[0].shape[0] # batch size loss = (t_lobj + t_lbox + t_lcls) * d_weight return loss

因为目标检测和原论文中的分类问题有所区别,并不能直接简单套用原论文提出的soft-target                      ,那么这里的处理方式就是将三个损失(位置损失               、目标损失                      、类别损失)简单粗暴地用MSELoss进行计算                      ,然后蒸馏损失就是这三部分之和。

值得注意的是,理论部分我们提到过               ,蒸馏损失需要比学生损失的权重更大                      ,因此       ,这里在计算蒸馏损失中               ,加入了一个权重d_weight                      ,权重计算时取10.

下面是代码作者给出的一个实验结果:

Model Compression

strategy Input size

[h, w] mAPval

0.5:0.95 Pretrain weight yolov5s baseline [640, 640] 37.2 pth | onnx yolov5s distillation [640, 640] 39.3 pth | onnx yolov5s quantization [640, 640] 36.5 xml | bin yolov5s distillation + quantization [640, 640] 38.6 xml | bin

他采用的是coco数据集       ,用yolov5m作为教师模型        ,yolov5s作为学生模型                      ,表格第二行展示了蒸馏之后的效果               ,mAP提升了2.1.

实验验证

为了验证蒸馏是否有效        ,我在VisDrone数据集上进行了实验                      ,训练了100epoch               ,实验结果如下表所示:

Student Model Teacher Model Input size

[h, w] mAPtest

0.5 mAPtest

0.5:0.95 yolov5m - [640, 640] 0.32 0.181 yolov5m yolov5m [640, 640] 0.305 0.163 yolov5m yolov5x [640, 640] 0.302 0.161 yolov5m - [1280, 1280] 0.448 0.261 yolov5m yolov5x [1280, 1280] 0.401 0.23

结果挺意外的,使用蒸馏训练之后                      ,mAP反而下降了                      ,严重怀疑蒸馏出来的是糟粕😵

结论

知识蒸馏理论上并不复杂,但经过实验               ,基本判断这玩意理论价值大于应用价值                      ,用来讲故事可以       ,实际上提升效果非常有限               。当然这是我做了有限实验得出的初步结论               ,如果读者有更好的思路                      ,可以在评论区留言和我讨论                      。

参考

[1]【论文泛读】 知识蒸馏:Distilling the knowledge in a neural network:https://www.bilibili.com/read/cv16841475

[2]【论文精讲|无废话版】知识蒸馏:https://www.bilibili.com/video/BV1h8411t7SA

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
快速启动栏显示桌面不见了怎么办(在快速启动中不再显示最近使用的文件) python web开发框架(Python中使用K-means算法)