網站首頁 編程語言 正文
PyTorch加載模型model.load_state_dict()問題
希望將訓練好的模型加載到新的網絡上。
如上面題目所描述的,PyTorch在加載之前保存的模型參數的時候,遇到了問題。
Unexpected key(s) in state_dict: "module.features. ...".,Expected ".features....". 直接原因是key值名字不對應。
表明了加載過程中,期望獲得的key值為feature...,而不是module.features....。
這是由模型保存過程中導致的,模型應該是在DataParallel模式下面,也就是采用了多GPU訓練模型,然后直接保存的。
You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it without . You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can load the weights file, create a new ordered dict without the module prefix, and load it back.
解決上面的問題有三個辦法:?
1. 對load的模型創建新的字典
去掉不需要的key值"module".
# original saved file with DataParallel
state_dict = torch.load('checkpoint.pt') # 模型可以保存為pth文件,也可以為pt文件。
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`,表面從第7個key值字符取到最后一個字符,正好去掉了module.
new_state_dict[name] = v #新字典的key值對應的value為一一對應的值。
# load params
model.load_state_dict(new_state_dict) # 從新加載這個模型。
2. 直接用空白''代替'module.'
model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('checkpoint.pt').items()})
# 相當于用''代替'module.'。
#直接使得需要的鍵名等于期望的鍵名。
3. 最簡單的方法
加載模型之后,接著將模型DataParallel,此時就可以load_state_dict。
如果有多個GPU,將模型并行化,用DataParallel來操作。
這個過程會將key值加一個"module. ***"。
model = VGGNet()
params=model.state_dict() #獲得模型的原始狀態以及參數。
for k,v in params.items():
print(k) #只打印key值,不打印具體參數。
4. 總結
從出錯顯示的問題就可以看出,key值不匹配,因此可以選擇多種方法,將模型參數加載進去。
這個方法通常會在load_state_dict過程中遇到。將訓練好的一個網絡參數,移植到另外一個網絡上面,繼續訓練。
或者將訓練好的網絡checkpoint加載進模型,再次進行訓練??梢源蛴〕鰉odel state_dict來看出兩者的差別。
model = VGGNet()
params=model.state_dict() #獲得模型的原始狀態以及參數。
for k,v in params.items():
print(k) #只打印key值,不打印具體參數。
features.0.0.weight? ?
features.0.1.weight
features.1.conv.3.weight
features.1.conv.4.num_batches_tracked
model = VGGNet()
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
# Load weights to resume from checkpoint。
# print('**************************************')
# 這個方法能夠直接打印出你保存的checkpoint的鍵和值。
for k,v in checkpoint.items():
print(k)
print("*****************************************")
輸出結果為:
module.features.0.0.weight",
"module.features.0.1.weight",
"module.features.0.1.bias
可以看出不匹配,模型的參數中,key值不同,多了module。
PS: 追加
在移植參數的過程中,對于出現?.total_ops和.total_params結尾的參數,可參考以下代碼:
from collections import OrderedDict
checkpoint = torch.load(
pretrained_model_file_path,
map_location=(None if use_cuda and not remap_to_cpu else "cpu"))
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
if not k.endswith('total_ops') and not k.endswith('total_params'):
name = k[7:]
new_state_dict[name] = v
最后
原文鏈接:https://blog.csdn.net/qq_32998593/article/details/89343507
相關推薦
- 2023-10-14 uniapp在Android 10對公共目錄的非媒體文件讀取上傳失敗問題
- 2022-04-06 用Python實現一個簡單的用戶系統_python
- 2022-09-24 關于R語言包的升級與降級問題_R語言
- 2022-10-18 AJAX請求以及解決跨域問題詳解_AJAX相關
- 2022-03-15 .Net?Core?SDK命令介紹及使用_自學過程
- 2022-12-13 sql索引失效的情況以及超詳細解決方法_MsSql
- 2023-03-28 python方法如何實現字符串反轉_python
- 2022-07-12 springboot整合jasypt加密yml配置文件
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支