天天看点

pytorch学习(一)pytorch中的断点续训

1. 设置断点续训的目的

在遇到停电宕机,设备内存不足导致实验还没有跑完的情况下,如果没有使用断点续训,就需要从头开始训练,耗时费力。

断点续训主要保存的是网络模型的参数以及优化器optimizer的状态(因为很多情况下optimizer的状态会改变,比如学习率的变化)

2. 设置断点续训的方法

  1. 参数设置

    resume: 是否进行续训

    initepoch: 进行续训时的初始epoch

  2. checkpoint载入过程(这部分操作放在epoch循环前边)
resume = True      # 设置是否需要从上次的状态继续训练
    if resume:
        if os.path.isfile("results/{}_model.pth".format(save_name_pre)):
            print("Resume from checkpoint...")
            checkpoint = torch.load("results/{}_model.pth".format(save_name_pre))
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            initepoch = checkpoint['epoch'] + 1
            print("====>loaded checkpoint (epoch{})".format(checkpoint['epoch']))
        else:
            print("====>no checkpoint found.")
            initepoch = 1   # 如果没进行训练过,初始训练epoch值为1
           
  1. 每一轮,checkpoint的存储过程,保存模型参数,优化器参数,轮数(这部分操作放在epoch循环里边)
# 保存断点
        if test_acc_1 > best_acc:
            best_acc = test_acc_1
            checkpoint = {"model_state_dict": model.state_dict(),
                          "optimizer_state_dict": optimizer.state_dict(),
                          "epoch": epoch}
            path_checkpoint = "results/{}_model.pth".format(save_name_pre)
            torch.save(checkpoint, path_checkpoint)
           

继续阅读