網站首頁 編程語言 正文
模型保存和加載
保存模型parameter、buffer
path = '/kaggle/working/state_dict_model.pt'
# 沒保存優化器的參數
# 保存所有參數和buffer量
# parameter和buffer https://blog.csdn.net/m0_61899108/article/details/124481684
torch.save(model.state_dict(), path)
n1_model = Model() # 實例化模型類
n1_model.load_state_dict(torch.load(path)) # 參數賦給新模型
n1_model.eval() # 將內部training設為False,不再記錄參數梯度值,運行效率高
'''
Model(
(fc): Linear(in_features=768, out_features=2, bias=True)
)
'''
保存整個模型
path = '/kaggle/working/entire_model.pt'
# 保存整個模型
torch.save(model, path)
n2_model = torch.load(path)
n2_model.eval()
'''
Model(
(fc): Linear(in_features=768, out_features=2, bias=True)
)
'''
checkpoint 保存和加載
epoch = 5
loss = 0.4
path = '/kaggle/working/5_0.4_checkpoint.pt'
torch.save({
'epoch': epoch
,'loss': loss
,'model_state_dict': model.state_dict()
,'optimizer_state_dict': optimizer.state_dict()
,
}, path)
# 加載
n3_model = Model()
n3_optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
checkpoint = torch.load(path)
epoch = checkpoint['epoch']
loss = checkpoint['loss']
n3_model.load_state_dict(checkpoint['model_state_dict'])
n3_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
n3_model.eval()
# - or -
n3_model.train()
import os
for dirname, _, filenames in os.walk('/kaggle/'):
for filename in filenames:
print(os.path.join(dirname, filename))
'''
/kaggle/lib/kaggle/gcp.py
/kaggle/input/chnsenticorp/ChnSentiCorp/dataset_info.json
/kaggle/input/chnsenticorp/ChnSentiCorp/ChnSentiCorp.py
/kaggle/input/chnsenticorp/ChnSentiCorp/chn_senti_corp-train.arrow
/kaggle/input/chnsenticorp/ChnSentiCorp/chn_senti_corp-test.arrow
/kaggle/input/chnsenticorp/ChnSentiCorp/chn_senti_corp-validation.arrow
/kaggle/working/state_dict_model.pt
/kaggle/working/__notebook_source__.ipynb
/kaggle/working/5_0.4_checkpoint.pt
/kaggle/working/entire_model.pt
'''
原文鏈接:https://blog.csdn.net/qq_45249685/article/details/127287361
相關推薦
- 2023-10-17 常用的utlis封裝
- 2022-07-27 PostgreSQL出現死鎖該如何解決_PostgreSQL
- 2023-01-10 golang實現簡單的tcp數據傳輸_Golang
- 2022-12-05 go?code?review?代碼調試_Golang
- 2022-07-04 python如何生成任意n階的三對角矩陣_python
- 2022-07-21 windows與Linux查看端口占用并終止端口占用
- 2022-09-22 Python 閉包與裝飾器
- 2023-01-29 Python?find()、rfind()方法及作用_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支