torch中的各种批归一的注意事项,不间断更新20190122
含有batchnorm的网络其train和eval时效果差距大
亦可参考笔者的另一篇博文:Pytorch 深度学习 模型训练 断点继续训练时损失函数恶化或与断点差异较大
- 和是否zero_grad及其位置关系不大,因为这个错了,train是多半不收敛的。
- 主要是因为BN的输入随着训练的进行是时变的,非稳态的,除非训练完全收敛,且学习率很小,并进行了多个batch的训练,此时的running mean 和running var才会收敛到正确的值。
- 如果BN的动量为0.1, 那么需要多训练的batch数我认为至少是20,即0.9**20=0.1214,也就是说20个batch前的训练数据在running mean和var中所占比重约十分之一。
- 建议:当需要用eval运作网络时,最好先以train模式进行多个batch的前向传播,用于稳定running mean和var。
torch.nn.BatchNorm2d
- 输入4D的矩阵,NxCxHxW
- C维度取Ci时可计算得到MEANi和VERi,分别是改通道对应的均值和方差
- 可见该批归一化过程是通道间独立的。
- 所以,如果batch中N=1也是可以正常运作的,这点区别于最早的批归一文章。