
Error(s) loading a PyTorch checkpoint state_dict: Missing key(s) && Unexpected key(s) in state_dict

ERROR: 

Traceback (most recent call last):
  File "test_0.py", line 130, in <module>
    model = load_model()
  File "test_0.py", line 104, in load_model
    model.load_state_dict(checkpoint['state_dict'])
  File "/home/cosmo/anaconda3/envs/tf8/lib/python3.6/site-packages/torch/nn/modules/module.py", line 719, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
	Missing key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.3.weight", "module.features.3.bias", "module.features.6.weight", "module.features.6.bias", "module.features.8.weight", "module.features.8.bias", "module.features.10.weight", "module.features.10.bias", "module.classifier.weight", "module.classifier.bias".
	Unexpected key(s) in state_dict: "features.module.0.weight", "features.module.0.bias", "features.module.3.weight", "features.module.3.bias", "features.module.6.weight", "features.module.6.bias", "features.module.8.weight", "features.module.8.bias", "features.module.10.weight", "features.module.10.bias", "classifier.weight", "classifier.bias".

REASON:


The problem is that the model was saved with DataParallel activated, and you are trying to load it without DataParallel. That's why there's an extra "module." at the beginning of each key!

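In other words, at the moment net.load_state_dict is called, net is not in GPU-parallel (DataParallel) mode, but the checkpoint was saved from a net that was. More precisely, nn.DataParallel stores the wrapped network in an attribute named module, so every key under it gains a "module." segment, and the position of that segment shows how each side was wrapped. In the traceback above, the model being loaded was wrapped as a whole (module.features.0.weight), while the checkpoint was saved with only features wrapped (features.module.0.weight).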
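A minimal sketch of where each key layout in the error comes from. It assumes only a CPU install of PyTorch (nn.DataParallel can be constructed without GPUs). TinyNet is a hypothetical stand-in for the model in test_0.py, which the post does not show; only its features/classifier naming is taken from the traceback.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in: a `features` stack plus a `classifier`, named as in the traceback."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
        self.classifier = nn.Linear(8, 10)


# 1. No DataParallel: keys are plain attribute paths.
plain = TinyNet()
plain_keys = list(plain.state_dict())

# 2. Whole model wrapped: DataParallel keeps the net in `.module`,
#    so every key gets a leading "module." (the "Missing key(s)" layout).
whole = nn.DataParallel(TinyNet())
whole_keys = list(whole.state_dict())

# 3. Only `features` wrapped: "module." appears after "features."
#    (the "Unexpected key(s)" layout stored in the checkpoint).
partial = TinyNet()
partial.features = nn.DataParallel(partial.features)
partial_keys = list(partial.state_dict())

for name, keys in [("plain", plain_keys), ("whole", whole_keys), ("partial", partial_keys)]:
    print(name, keys[0], keys[-1])
# plain features.0.weight classifier.bias
# whole module.features.0.weight module.classifier.bias
# partial features.module.0.weight classifier.bias
```

The weights are identical in all three cases; only the key names differ, which is exactly what load_state_dict reports as missing and unexpected keys.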

SOLVE:

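Before calling net.load_state_dict, put net into GPU-parallel mode by wrapping it in torch.nn.DataParallel. More generally, wrap it exactly the way it was wrapped when the checkpoint was saved: for the traceback above, that means wrapping only net.features rather than the whole net.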
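A minimal sketch of the fix, again using a hypothetical TinyNet (the real model's architecture is assumed, not known). The checkpoint is simulated in memory; in the post it would come from torch.load as checkpoint['state_dict']. try_load is a hypothetical helper. Case 1 is the situation the fix describes (saved under DataParallel, loaded into a plain model); case 2 mirrors the traceback, where only features was wrapped at save time.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the post's model (a `features` stack plus a `classifier`)."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
        self.classifier = nn.Linear(8, 10)


def try_load(model, state_dict):
    """Hypothetical helper: None on success, else the first line of the RuntimeError."""
    try:
        model.load_state_dict(state_dict)  # strict=True by default
    except RuntimeError as err:
        return str(err).splitlines()[0]
    return None


# Case 1: checkpoint saved from a fully wrapped model, so keys start with "module.".
saved = nn.DataParallel(TinyNet())
checkpoint = {"state_dict": saved.state_dict()}  # in practice: torch.load(path)

net = TinyNet()
error_plain = try_load(net, checkpoint["state_dict"])  # fails: net is not wrapped

net = nn.DataParallel(TinyNet())  # the fix: wrap BEFORE load_state_dict
error_wrapped = try_load(net, checkpoint["state_dict"])

# Case 2 (the traceback): only `features` was wrapped at save time.
saved = TinyNet()
saved.features = nn.DataParallel(saved.features)
checkpoint = {"state_dict": saved.state_dict()}

net = nn.DataParallel(TinyNet())  # wrapping the whole model reproduces the traceback
error_whole = try_load(net, checkpoint["state_dict"])

net = TinyNet()
net.features = nn.DataParallel(net.features)  # same layout as at save time
error_same = try_load(net, checkpoint["state_dict"])

print(error_plain)    # Error(s) in loading state_dict for TinyNet:
print(error_wrapped)  # None
print(error_whole)    # Error(s) in loading state_dict for DataParallel:
print(error_same)     # None
```

The point the sketch illustrates is that the wrapping, not just the architecture, has to match: the same weights load cleanly once the "module." segment sits in the same place on both sides.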

