最近需要初始化一個模型的參數,但在pytorch中沒有類似可直接使用的類似tf.truncnormal()函數,一開始是直接嘗試用torch.nn.init.normal_() 來代替tf.truncnormal()。效果相差較遠,簡單的正态分布并不能代替截斷正态分布的作用。故考慮自己實作,借鑒了 https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15的一個實作, 實作代碼如下:
實作中借用了scipy.stats下的truncnorm函數來産生截斷的正态分布值,再自行包裝成torch可用的tensor。而這樣的實作是有問題的,原因在于truncnorm.rvs()是基于标準正态分布産生截斷的正态分布,而在原模型中是使用standard deviation為0.1的正态分布來産生的,初次實作時自己忽略了這樣的細節(且rvs函數沒有發現可以調整正态分布方差的參數,故不可以繼續使用rvs函數來生成)。畫出了一開始生成的錯誤的截斷正态分布:
問題解決:參考了上面連結中另一個實作的截斷正态分布生成函數代碼如下:
def
對應的畫出了新的截斷正态分布生成函數和原tf中的函數的結果如下:
基本可以達到tensorflow中truncnormal相同的效果。