天天看點

pytorch自帶網絡_pytorch的餘弦退火學習率

點選上方“機器學習與生成對抗網絡”,關注"星标"

擷取有趣、好玩的前沿幹貨!

作者:limzero

位址:https://www.zhihu.com/people/lim0-34

編輯:人工智能前沿講習

最近深入了解了下pytorch下面餘弦退火學習率的使用.網絡上大部分教程都是翻譯的pytorch官方文檔,并未給出一個很詳細的介紹,由于官方文檔也隻是給了一個數學公式,對參數雖然有解釋,但是解釋得不夠明了,這樣一來導緻我們在調參過程中不能合理的根據自己的資料設定合适的參數.這裡作一個筆記,并且給出一些定性和定量的解釋和結論.說到pytorch自帶的餘弦學習率調整方法,通常指下面這兩個

pytorch自帶網絡_pytorch的餘弦退火學習率

CosineAnnealingLR

pytorch自帶網絡_pytorch的餘弦退火學習率

CosineAnnealingWarmRestarts

CosineAnnealingLR

這個比較簡單,隻對其中的最關鍵的Tmax參數作一個說明,這個可以了解為餘弦函數的半周期.如果max_epoch=50次,那麼設定T_max=5則會讓學習率餘弦周期性變化5次.

pytorch自帶網絡_pytorch的餘弦退火學習率

max_opoch=50, T_max=5

CosineAnnealingWarmRestarts

這個最主要的參數有兩個:

  • T_0:學習率第一次回到初始值的epoch位置
  • T_mult:這個控制了學習率變化的速度
    • 如果T_mult=1,則學習率在T_0,2T_0,3T_0,....,i*T_0,....處回到最大值(初始學習率)
      • 5,10,15,20,25,.......處回到最大值
    • 如果T_mult>1,則學習率在T_0,(1+T_mult)T_0,(1+T_mult+T_mult**2)T_0,.....,(1+T_mult+T_mult2+...+T_0i)*T0,處回到最大值
      • 5,15,35,75,155,.......處回到最大值
pytorch自帶網絡_pytorch的餘弦退火學習率

T_0=5, T_mult=1

pytorch自帶網絡_pytorch的餘弦退火學習率

T_0=5, T_mult=2

是以可以看到,在調節參數的時候,一定要根據自己總的epoch合理的設定參數,不然很可能達不到預期的效果,經過我自己的試驗發現,如果是用那種等間隔的退火政策(CosineAnnealingLR和Tmult=1的CosineAnnealingWarmRestarts),驗證準确率總是會在學習率的最低點達到一個很好的效果,而随着學習率回升,驗證精度會有所下降.是以為了能最終得到一個更好的收斂點,設定T_mult>1是很有必要的,這樣到了訓練後期,學習率不會再有一個回升的過程,而且一直下降直到訓練結束。

下面是使用示例和畫圖的代碼:

最後,對 scheduler.step(epoch + batch / iters)的一個說明,這裡的個人了解:一個epoch結束後再.step, 那麼一個epoch内所有batch使用的都是同一個學習率,為了使得不同batch也使用不同的學習率 ,則可以在這裡進行.step(将離散連續化,或者說使得采樣得更加的密集),下圖是以20個epoch,每個epoch5個batch,T0=2,Tmul=2畫的學習率變化圖

pytorch自帶網絡_pytorch的餘弦退火學習率

代碼:

import torchfrom torch.optim.lr_scheduler import CosineAnnealingLR,CosineAnnealingWarmRestarts,StepLRimport torch.nn as nnfrom torchvision.models import resnet18import matplotlib.pyplot as plt#model=resnet18(pretrained=False)optimizer = torch.optim.SGD(model.parameters(), lr=0.1)mode='cosineAnnWarm'if mode=='cosineAnn':    scheduler = CosineAnnealingLR(optimizer, T_max=5, eta_min=0)elif mode=='cosineAnnWarm':    scheduler = CosineAnnealingWarmRestarts(optimizer,T_0=2,T_mult=2)    '''    以T_0=5, T_mult=1為例:    T_0:學習率第一次回到初始值的epoch位置.    T_mult:這個控制了學習率回升的速度        - 如果T_mult=1,則學習率在T_0,2*T_0,3*T_0,....,i*T_0,....處回到最大值(初始學習率)            - 5,10,15,20,25,.......處回到最大值        - 如果T_mult>1,則學習率在T_0,(1+T_mult)*T_0,(1+T_mult+T_mult**2)*T_0,.....,(1+T_mult+T_mult**2+...+T_0**i)*T0,處回到最大值            - 5,15,35,75,155,.......處回到最大值    example:        T_0=5, T_mult=1    '''plt.figure()max_epoch=20iters=5cur_lr_list = []for epoch in range(max_epoch):    print('epoch_{}'.format(epoch))    for batch in range(iters):        scheduler.step(epoch + batch / iters)        optimizer.step()        #scheduler.step()        cur_lr=optimizer.param_groups[-1]['lr']        cur_lr_list.append(cur_lr)        print('cur_lr:',cur_lr)    print('epoch_{}_end'.format(epoch))x_list = list(range(len(cur_lr_list)))plt.plot(x_list, cur_lr_list)plt.show()
           

本文目的在于學術交流,并不代表本公衆号贊同其觀點或對其内容真實性負責,版權歸原作者所有,如有侵權請告知删除。

猜您喜歡:

附下載下傳 | 《Python進階》中文版

附下載下傳 | 經典《Think Python》中文版

附下載下傳 | 《Pytorch模型訓練實用教程》

附下載下傳 | 最新2020李沐《動手學深度學習》

附下載下傳 | 《可解釋的機器學習》中文版

附下載下傳 |《TensorFlow 2.0 深度學習算法實戰》

附下載下傳 | 超100篇!CVPR 2020最全GAN論文梳理彙總!

附下載下傳 |《計算機視覺中的數學方法》分享

pytorch自帶網絡_pytorch的餘弦退火學習率

繼續閱讀