1. 设置断点续训的目的
在遇到停电宕机,设备内存不足导致实验还没有跑完的情况下,如果没有使用断点续训,就需要从头开始训练,耗时费力。
断点续训主要保存的是网络模型的参数以及优化器optimizer的状态(因为很多情况下optimizer的状态会改变,比如学习率的变化)
2. 设置断点续训的方法
-
参数设置
resume: 是否进行续训
initepoch: 进行续训时的初始epoch
- checkpoint载入过程(这部分操作放在epoch循环前边)
resume = True # 设置是否需要从上次的状态继续训练
if resume:
if os.path.isfile("results/{}_model.pth".format(save_name_pre)):
print("Resume from checkpoint...")
checkpoint = torch.load("results/{}_model.pth".format(save_name_pre))
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
initepoch = checkpoint['epoch'] + 1
print("====>loaded checkpoint (epoch{})".format(checkpoint['epoch']))
else:
print("====>no checkpoint found.")
initepoch = 1 # 如果没进行训练过,初始训练epoch值为1
- 每一轮,checkpoint的存储过程,保存模型参数,优化器参数,轮数(这部分操作放在epoch循环里边)
# 保存断点
if test_acc_1 > best_acc:
best_acc = test_acc_1
checkpoint = {"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"epoch": epoch}
path_checkpoint = "results/{}_model.pth".format(save_name_pre)
torch.save(checkpoint, path_checkpoint)