天天看点

pytorch checkpoint_[日常] PyTorch 预训练模型,保存,读取和更新模型参数以及多 GPU 训练模型

pytorch checkpoint_[日常] PyTorch 预训练模型,保存,读取和更新模型参数以及多 GPU 训练模型
本文用于记录如何进行 PyTorch 所提供的预训练模型应如何加载,所训练模型的参数应如何保存与读取,如何冻结模型部分参数以方便进行 fine-tuning 以及如何利用多 GPU 训练模型。

(各位收藏的时候, 麻烦顺手点个赞同吧)

目录

  1. PyTorch 预训练模型
  2. 保存模型参数
  3. 读取模型参数
  4. 冻结部分模型参数,进行 fine-tuning
  5. 模型训练与测试的设置
  6. 利用 torch.nn.DataParallel 进行多 GPU 训练

1. PyTorch 预训练模型

Pytorch 提供了许多 Pre-Trained Model on ImageNet,仅需调用 torchvision.models 即可,具体细节可查看官方文档。

往往我们需要对 Pre-Trained Model 进行相应的修改,以适应我们的任务。这种情况下,我们可以先输出 Pre-Trained Model 的结构,确定好对哪些层修改,或者添加哪些层,接着,再将其修改即可。

比如,我需要将 ResNet-50 的 Layer 3 后的所有层去掉,在分别连接十个分类器,分类器由 ResNet-50.layer4 和 AvgPool Layer 和 FC Layer 构成。这里就需要用到 torch.nn.ModuleList 了,比如:

self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

代码中的 [nn.Linear(10, 10) for i in range(10)] 是一个python列表,必须要把它转换成一个Module Llist列表才可以被 PyTorch 使用,否则在运行的时候会报错:

RuntimeError: Input type (CUDAFloatTensor) and weight type (CPUFloatTensor) should be the same

2. 保存模型参数

PyTorch 中保存模型的方式有许多种:

# 保存整个网络
torch.save(model, PATH) 
# 保存网络中的参数, 速度快,占空间少
torch.save(model.state_dict(),PATH)
# 选择保存网络中的一部分参数或者额外保存其余的参数
torch.save({'state_dict': model.state_dict(), 'fc_dict':model.fc.state_dict(),
            'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma},
             PATH)
           

3. 读取模型参数

同样的,PyTorch 中读取模型参数的方式也有许多种:

# 读取整个网络
model = torch.load(PATH)

# 读取 Checkpoint 中的网络参数
model.load_state_dict(torch.load(PATH))

# 若 Checkpoint 中的网络参数与当前网络参数有部分不同,有以下两种方式进行加载:
# 1. 利用字典的 update 方法进行加载
Checkpoint = torch.load(Path)
model_dict = model.state_dict()
model_dict.update(Checkpoint)
model.load_state_dict(model_dict)
# 2. 利用 load_state_dict() 的 strict 参数进行部分加载
model.load_state_dict(torch.load(PATH), strict=False)
           

4. 冻结部分模型参数,进行 fine-tuning

加载完 Pre-Trained Model 后,我们需要对其进行 Finetune。但是在此之前,我们往往需要冻结一部分的模型参数:

# 第一种方式
for p in freeze.parameters(): # 将需要冻结的参数的 requires_grad 设置为 False
    p.requires_grad = False
for p in no_freeze.parameters(): # 将fine-tuning 的参数的 requires_grad 设置为 True
    p.requires_grad = True
# 将需要 fine-tuning 的参数放入optimizer 中
optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

# 第二种方式
optim_param = []
for p in freeze.parameters(): # 将需要冻结的参数的 requires_grad 设置为 False
    p.requires_grad = False
for p in no_freeze.parameters(): # 将fine-tuning 的参数的 requires_grad 设置为 True
    p.requires_grad = True
    optim_param.append(p)
optimizer.SGD(optim_param, lr=1e-3) # 将需要 fine-tuning 的参数放入optimizer 中
           

5. 模型训练与测试的设置

训练时,应调用 model.train() ;测试时,应调用 model.eval(),以及 with torch.no_grad():

model.train()

:使 model 变成训练模式,此时 dropout 和 batch normalization 的操作在训练起到防止网络过拟合的问题。

model.eval()

:PyTorch会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。不然的话,一旦测试集的 Batch Size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。

with torch.no_grad()

:PyTorch 将不再计算梯度,这将使得模型 forward 的时候,显存的需求大幅减少,速度大幅提高。

注意:若模型中具有 Batch Normalization 操作,想固定该操作进行训练时,需调用对应的 module 的 eval() 函数。这是因为 BN Module 除了参数以外,还会对输入的数据进行统计,若不调用 eval(),统计量将发生改变!具体代码可以这样写:

for module in model.modules():
    module.eval()
           

在其他地方看到的解释:

  • model.eval() will notify all your layers that you are in eval mode, that way, batchnorm or dropout layers will work in eval model instead of training mode.
  • torch.no_grad() impacts the autograd engine and deactivate it. It will reduce memory usage and speed up computations but you won’t be able to backprop (which you don’t want in an eval script).

6. 利用 torch.nn.DataParallel 进行多 GPU 训练

import torch
import torch.nn as nn
import torchvision.models as models
# 生成模型
# 利用 torch.nn.DataParallel 进行载入模型,默认使用所有GPU(可以用 CUDA_VISIBLE_DEVICES 设置所使用的 GPU)
model = nn.DataParallel(models.resnet18()) 

# 冻结参数
for param in model.module.layer4.parameters():
    param.requires_grad = False
param_optim = filter(lambda p:p.requires_grad, model.parameters())

# 设置测试模式
model.module.layer4.eval()

# 保存模型参数(读取所保存模型参数后,再进行并行化操作,否则无法利用之前的代码进行读取)
torch.save(model.module.state_dict(),'./CheckPoint.pkl')
           
参考资料:
  • TORCHVISION.MODELS
  • Pytorch 保存模型与加载模型
  • pytorch 模型部分参数的加载
  • pytorch 固定部分参数训练
  • pytorch 杂记(1)---Pytorch model.train 与 model.eval
  • Pytorch 容器
  • pytorch如何使用多块gpu?
如果你看到了这篇文章的最后,并且觉得有帮助的话,麻烦你花几秒钟时间点个赞,或者受累在评论中指出我的错误。谢谢! 作者信息: 知乎:没头脑 CSDN:Code_Mart Github:Tao Pu

继续阅读