天天看点

pytorch 从头开始faster-rcnn(三):vgg16 (带有网络冻结的写法)

由于pytorch自带了一些网络结构可以直接使用,所以直接调用有的模型,详细的可以查看官方文档:https://pytorch.org/docs/stable/torchvision/models.html

def decom_vgg16():
    if opt.caffe_pretrain:
        model = vgg16(pretrained=False)
        if not opt.load_path:
            model.load_state_dict(t.load(otp.caffe_pretrain_path))
            print('load caffe_pretrain')

    else:
        model = vgg16(not opt.load_path)

    feature = list(model.features)[:30]
    classify = model.classifier

# 保留分类器的几层
    classifier = list(classifier)
    del classfier[6]
    if not opt.use_drop:
        del classifier[5]
        del classifier[2]

    classifier = nn.Sequential(*classfier)

    # 冻结卷积网络前4层
    for layer in feature[:10]:
        for p in layer.parameters():
            p.requires_grad = False

    return nn.Sequential(*features), classifier
           

vgg16分为特征提取的卷积层和进行分类回归的分类层。将vgg16的前4层冻结,不进行反向传播更新权重。然后将分类层提取出来,如果不进行dropout的话就删除dropout层,并且删除了最后一层。

分类层的情况是如下:

[Linear(in_features=25088, out_features=4096, bias=True), ReLU(inplace), Dropout(p=0.5), Linear(in_features=4096, out_features=4096, bias=True), ReLU(inplace), Dropout(p=0.5), Linear(in_features=4096, out_features=1000, bias=True)]

vgg16的层次结构图像来自https://www.cs.toronto.edu/~frossard/post/vgg16/

pytorch 从头开始faster-rcnn(三):vgg16 (带有网络冻结的写法)

 最后,特征提取层作为faster-rcnn的起始特征提取,分类层作为RPN的分类层。RPN将在下一章节讲述。

继续阅读