天天看點

pytorch使用記錄(二) 參數初始化

本文主要記錄如何在pytorch中對卷積層和批歸一層權重進行初始化,也就是weight和bias。

主要會用到torch的apply()函數。【apply】

apply(fn):将fn函數遞歸地應用到網絡模型的每個子模型中,主要用在參數的初始化。

使用apply()時,需要先定義一個參數初始化的函數。

def weight_init(m):
    classname = m.__class__.__name__ # 得到網絡層的名字,如ConvTranspose2d
    if classname.find('Conv') != -1:  # 使用了find函數,如果不存在傳回值為-1,是以讓其不等于-1
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
           

之後,定義自己的網絡,得到網絡模型,使用apply()函數,就可以分别對conv層和bn層進行參數初始化。

model = net()
model.apply(weight_init)
           

以上就可以對各層進行參數初始化。

如何檢視效果呢?也就是如何檢視各層的weight和bias。

需要用到上一篇文章【點此到達】中的state_dict()函數,傳回網絡的所有參數。具體如下:

params = model.state_dict()
    for k, v in params.items():
        print k  # 列印網絡中的變量名,找出自己想檢視的參數名字,這個與自己定義網絡時起的名字有關。
    print params['net.convt1.weight']   # 列印convt1的weight
    print params['net.convt1.bias'] # 列印convt1的bias
           

繼續閱讀