網站首頁 編程語言 正文
前言
最近使用pytorch訓練模型,保存模型后再次加載使用出現了一些問題。記錄一下解決方案!
一、torch中模型保存和加載的方式
1、模型參數和模型結構保存和加載
torch.save(model,path)
torch.load(path)
2、只保存模型的參數和加載——這種方式比較安全,但是比較稍微麻煩一點點
torch.save(model.state_dict(),path)
model_state_dic = torch.load(path)
model.load_state_dic(model_state_dic)
二、torch中模型保存和加載出現的問題
1、單卡模型下保存模型結構和參數后加載出現的問題
模型保存的時候會把模型結構定義文件路徑記錄下來,加載的時候就會根據路徑解析它然后裝載參數;當把模型定義文件路徑修改以后,使用torch.load(path)就會報錯。
把model文件夾修改為models后,再加載就會報錯。
import torch
from model.TextRNN import TextRNN
load_model = torch.load('experiment_model_save/textRNN.bin')
print('load_model',load_model)
這種保存完整模型結構和參數的方式,一定不要改動模型定義文件路徑。
2、多卡機器單卡訓練模型保存后在單卡機器上加載會報錯
在多卡機器上有多張顯卡0號開始,現在模型在n>=1上的顯卡訓練保存后,拷貝在單卡機器上加載
import torch
from model.TextRNN import TextRNN
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin')
print('load_model',load_model)
會出現cuda device不匹配的問題——你保存的模代碼段 小部件型是使用的cuda1,那么采用torch.load()打開的時候,會默認的去尋找cuda1,然后把模型加載到該設備上。這個時候可以直接使用map_location來解決,把模型加載到CPU上即可。
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))
3、多卡訓練模型保存模型結構和參數后加載出現的問題
當用多GPU同時訓練模型之后,不管是采用模型結構和參數一起保存還是單獨保存模型參數,然后在單卡下加載都會出現問題
a、模型結構和參數一起保然后在加載
torch.distributed.init_process_group(backend='nccl')
模型訓練的時候采用上述多進程的方式,所以你在加載的時候也要聲明,不然就會報錯。
b、單獨保存模型參數
model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load('train_model/clip/experiment.pt')
model.load_state_dict(state_dict)
同樣會出現問題,不過這里出現的問題是參數字典的key和模型定義的key不一樣
原因是多GPU訓練下,使用分布式訓練的時候會給模型進行一個包裝,代碼如下:
model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin')
print(model)
model.cuda(args.local_rank)
。。。。。。
model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True)
print('model',model)
包裝前的模型結構:
包裝后的模型
在外層多了DistributedDataParallel以及module,所以才會導致在單卡環境下加載模型權重的時候出現權重的keys不一致。
三、正確的保存模型和加載的方法
if gpu_count > 1:
torch.save(model.module.state_dict(),save_path)
else:
torch.save(model.state_dict(),save_path)
model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load(save_path)
model.load_state_dict(state_dict)
這樣就是比較好的范式,加載不會出錯。
總結
原文鏈接:https://blog.csdn.net/HUSTHY/article/details/115199280
相關推薦
- 2022-11-02 Android三方依賴沖突Gradle中exclude的使用_Android
- 2022-04-11 Python - logging.Formatter 的常用格式字符串
- 2022-05-11 React中的Refs屬性你來了解嗎_React
- 2022-12-26 C語言逆向分析語法超詳細分析_C 語言
- 2022-05-01 oracle刪除超過N天數據腳本的方法_oracle
- 2021-12-13 linux系統AutoFs自動掛載服務安裝配置_Linux
- 2022-08-20 C++詳細講解內存管理工具primitives_C 語言
- 2022-11-13 C++實現RSA加密解密算法是示例代碼_C 語言
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支