網站首頁 編程語言 正文
參考
TORCH.LOAD
torch.load()
函數格式為:torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
,一般我們使用的時候,基本只使用前兩個參數。
模型的保存
模型保存有兩種形式,一種是保存模型的state_dict()
,只是保存模型的參數。那么加載時需要先創建一個模型的實例model
,之后通過torch.load()
將保存的模型參數加載進來,得到dict
,再通過model.load_state_dict(dict)
將模型的參數更新。
另一種是將整個模型保存下來,之后加載的時候只需要通過torch.load()
將模型加載,即可返回一個加載好的模型。
具體可參考:PyTorch模型的保存與加載。
模型加載中的map_location參數
具體來說,map_location
參數是用于重定向,比如此前模型的參數是在cpu
中的,我們希望將其加載到cuda:0
中。或者我們有多張卡,那么我們就可以將卡1中訓練好的模型加載到卡2中,這在數據并行的分布式深度學習中可能會用到。
首先定義一個AlexNet,并使用cuda:0
將其訓練了一個貓狗分類,之后把模型存儲起來。
map_location=None
我們先把state_dict
加載進來。
model_path = "./cuda_model.pth" model = torch.load(model_path) print(next(model.parameters()).device)
結果為:
cuda:0
因為保存的時候就是模型就是cuda:0
的,所以加載進來也是。
map_location=torch.device()
model_path = "./cuda_model.pth" model = torch.load(model_path, map_location=torch.device('cpu')) print(next(model.parameters()).device)
結果為:
cpu
模型從cuda:0
變成了cpu
。
map_location={xx:xx}
model_path = "./cuda_model.pth" model = torch.load(model_path, map_location={'cuda:0':'cuda:1'}) print(next(model.parameters()).device)
結果為:
cuda:1
模型從cuda:0
變成了cuda:1
。
model_path = "./cuda_model.pth" model = torch.load(model_path, map_location={'cuda:2':'cpu'}) print(next(model.parameters()).device)
結果為:
cuda:0
模型還是cuda:0
,并沒有變成cpu
。因為這個map_location
的映射是不對的,原始的模型就是cuda:0
,而映射是cuda:2
到cpu
,是不對的。這種情況下,map_location
返回None
,也就是和不加map_location
相同。
總結
原文鏈接:https://blog.csdn.net/qq_43219379/article/details/123675375
相關推薦
- 2022-05-08 Python?matplotlib實現折線圖的繪制_python
- 2022-06-26 ASP.NET?Core構建OData查詢Restful?API_實用技巧
- 2023-03-13 Android自定義Toast樣式實現方法詳解_Android
- 2022-10-28 Python利用Rows快速操作csv文件_python
- 2022-07-10 elementUI去掉el-card內部padding
- 2022-02-27 一個多模塊的Spring Boot項目打成多個jar包在服務器上運行
- 2022-11-01 Python正則表達中re模塊的使用_python
- 2022-02-01 uniapp 開發h5 優化加載速度
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支