網站首頁 編程語言 正文
今天跑一個模型的時候,需要加載部分預訓練模型的參數,這期間遇到使用torch.load 忽略了 map_location參數 默認gpu,這導致這個變量分配的顯存 不釋放 然后占用大量資源 gpu資源不能很好的利用。
問題講解:
比如我們一般我們會使用下面方式進行加載預訓練參數 到 自身寫的模型中:
from transformers import RobertaForMultipleChoice
import torch
model = RobertaForMultipleChoice.from_pretrained("roberta-large")
pretrained_model = torch.load("./checkpoints/txt_matching_e1.pth").roberta
pretrained_dict = pretrained_model.state_dict()
model_dict = model.roberta.state_dict()
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} #去除一些不需要的參數
model_dict.update(pretrained_dict)
model.roberta.load_state_dict(model_dict)
1. 當我們沒有使用參數時候 load 默認使用了一塊顯卡然后報錯
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 3; 10.76 GiB total capacity; 350.54 MiB already allocated; 21.81 MiB free; 356.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
torch load 之前gpu使用
torch load 之后 outof memory 了 并且也不釋放
2. 當我們沒有使用參數時候 load 默認使用了一塊顯卡然后報錯
當我試試指定顯卡 gpu會使用2841
pretrained_model = torch.load(“./checkpoints/txt_matching_e1.pth”,map_location=‘cuda:0’).roberta!
(model 直接cuda 的gpu 占用情況)然后把這里面參數給model,并且model也是用cuda0 然后gpu使用4193
你可能會想model是不是model load預訓練參數之后 就這么大了 那么load 和 load 參數后model 用不同gpu看看。
原理:cuda的內存管理機制
參考解釋博客:Pytorch訓練模型時如何釋放GPU顯存
解決方案:
1. 不占用顯存的使用方法,使用cpu 然后在del 用gc釋放內存
model = RobertaForMultipleChoice.from_pretrained("roberta-large")
pretrained_model = torch.load("./checkpoints/txt_matching_e1.pth",map_location='cpu').roberta
pretrained_dict = pretrained_model.state_dict()
model_dict = model.roberta.state_dict()
model_dict.update(pretrained_dict)
model.roberta.load_state_dict(model_dict)
del pretrained_model
import gc
gc.collect()
2. 合理使用, torch.cuda.empty_cache()
這個需要了解一下python的內存管理,引用機制。
比如我pretrain_model 給model直接加載參數,model和pretrain_model 都在cuda:0上,使用torch.cuda.empty_cache() 不能釋放pretrain_model 的顯存。
當 我把model 放到 cuda:1上(本來在cuda:0),這時候用torch.cuda.empty_cache() 可以釋放。
原文鏈接:https://blog.csdn.net/Miranda_ymz/article/details/127577639
相關推薦
- 2022-05-02 Pyinstaller+Pipenv打包Python文件的實現示例_python
- 2022-10-05 Python實現打印彩色字符串的方法詳解_python
- 2022-12-05 Python?如何截取字符函數_python
- 2022-07-10 vb腳本實現電腦定時關機操作
- 2022-08-15 常見哈希算法、Hmac算法和BouncyCastle
- 2023-02-18 Flow轉LiveData數據丟失原理詳解_Android
- 2022-10-29 微服務啟動報錯:No Feign Client for loadBalancing defined.
- 2022-12-10 C語言中如何實現桶排序_C 語言
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支