论文: Cascade R-CNN Delving into High Quality Object Detection (CVPR 2018)
代码:cascade-rcnn_Pytorch
文章目录
-
- 为什么级联
-
-
- 总结
-
- 代码梳理
- 实验
- 参考文献
为什么级联
双阶段网络的典型代表就是Faster RCNN了,先通过RPN网络产生Proposals,然后挑选出正负样本,并分配标签进行训练。训练时选择哪些proposals作为正负样本,一般是根据IoU阈值来界定(一般取0.5),和实际目标框高于该阈值就作为正样本。
- 图(a):阈值设置得低,则正样本中含较多背景,使得误检较多
- 图(b):阈值设置得过高,虽然能减少误检,正样本的数量变少,容易过拟合,会漏检
- 图(c):当输入的IoU分布与阈值较为接近的时候,其输出IoU也相对较高,说明将阈值调节到输入IoU附近时,得到的输出一般有更好的表现
- 图(d):随着阈值增加,检测器的性能会大致在各自对应IoU区间有一个更好的表现,总体上阈值为0.6时表现更好
从上图(a)还有可以看出很重要的一点就是:大部分时候,曲线都在灰色曲线上方,说明输出IoU一般比输入IoU要高,将输出继续作为输入,相当于调高了Proposals的IoU(下一轮的输入IoU更高),同时适当调高IoU阈值,以得到更高IoU的输出,其结构示意图如下图(d)所示,这便是CascadeRCNN的级联结构
从下图可以看出,每一个阶段的输出的IoU分布明显不同,越到后面输出的质量越高,相当于一个进化的过程,越好的Proposals训练效果越好
总结
- 输入IoU的分布在阈值附近时,训练效果相对较好(可以根据输入IoU分布调整IoU阈值)
- 输出IoU一般比输入IoU高(可以级联)
代码梳理
相比FasterRCNN,其实就是将后面的FastRCNN部分重复
- stage1的RoIs1由RPN网络产生,stage1输出的预测框继续在前面的feature map上提取相应的RoIs2,作为stage2的输入,stage3以此类推
##################stage1##################
self.RCNN_top = nn.Sequential(
nn.Conv2d(256, 1024, kernel_size=cfg.POOLING_SIZE, stride=cfg.POOLING_SIZE, padding=0),
nn.ReLU(True),
nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
nn.ReLU(True)
)
self.RCNN_cls_score = nn.Linear(1024, self.n_classes)
self.RCNN_bbox_pred = nn.Linear(1024, 4 * self.n_classes)
##################stage2##################
self.RCNN_top_2nd = nn.Sequential(
nn.Conv2d(256, 1024, kernel_size=cfg.POOLING_SIZE, stride=cfg.POOLING_SIZE, padding=0),
nn.ReLU(True),
nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
nn.ReLU(True)
)
self.RCNN_cls_score_2nd = nn.Linear(1024, self.n_classes)
self.RCNN_bbox_pred_2nd = nn.Linear(1024, 4 * self.n_classes)
##################stage3##################
self.RCNN_top_3rd = nn.Sequential(
nn.Conv2d(256, 1024, kernel_size=cfg.POOLING_SIZE, stride=cfg.POOLING_SIZE, padding=0),
nn.ReLU(True),
nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
nn.ReLU(True)
)
self.RCNN_cls_score_3rd = nn.Linear(1024, self.n_classes)
self.RCNN_bbox_pred_3rd = nn.Linear(1024, 4 * self.n_classes)
实验
参考文献
【1】Cascade RCNN算法笔记
【2】Cascade R-CNN 详细解读
【3】cascade-rcnn_Pytorch