網站首頁 編程語言 正文
如何在pytorch中指定CPU和GPU進行訓練,以及cpu和gpu之間切換
由CPU切換到GPU,要修改的幾個地方:
網絡模型、損失函數、數據(輸入,標注)
# 創建網絡模型
tudui = Tudui()
if torch.cuda.is_available():
tudui = tudui.cuda()
# 損失函數
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():
loss_fn = loss_fn.cuda()
# 數據輸入 包括訓練和測試的代碼,二者都需要添加此代碼
if torch.cuda.is_available():
imgs = imgs.cuda()
targets = targets.cuda()
方法一:.to(device)
1.不知道電腦GPU可不可用時:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu' )
a.to(device)
第一行代碼的意思是判斷電腦GPU可不可用,如果可用的話device就采用cuda()即調用GPU,不可用的話就采用cpu()即調用CPU。
第二行代碼的意思就是把變量放到對應的device上(當然如果你用的是CPU的話就不用這一步了,因為變量默認是存在CPU上的,調用GPU的話要先把變量放到GPU上跑,跑完之后再調回CPU上)
2.指定GPU時
# 定義訓練的設備
device = torch.device("cuda:0")
# 網絡模型創建
tudui = Tudui()
tudui = tudui.to(device)
# 損失函數
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
# 訓練步驟開始
tudui.train()
for data in train_dataloader:
imgs, targets=data
imgs = imgs.to(device)
targets = targets.to(device)
outputs = tudui(imgs)
loss = loss_fn(outputs, targets)
# 測試步驟開始
tudui.eval()
total_test_loss = 0
total_accuracy = 0
with torch.no_grad():
for data in test_dataloader:
imgs, targets=data
imgs = imgs.to(device)
targets = targets.to(device)
outputs = tudui(imgs)
loss = loss_fn(outputs, targets)
total_test_loss = total_test_loss + loss.item()
accuracy = (outputs.argmax(1)==targets).sum()
total_accuracy = total_accuracy + accuracy
3.指定cpu時:
device = torch.device('cpu')
方法二:
1、需要修改的
# 三種常見的寫法
device = torch.device('cuda')
device = torch.device('cuda: 0')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
2、代碼
# 創建模型
tudui = Tudui()
if torch.cuda.is_available():
tudui = tudui.cuda()
# 損失函數
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():
loss_fn = loss_fn.cuda()
# 訓練步驟開始
tudui.train()
for data in train_dataloader:
imgs, targets=data
if torch.cuda.is_available():
imgs = imgs.cuda()
targets = targets.cuda()
outputs = tudui(imgs)
loss = loss_fn(outputs, targets)
# 測試步驟開始
tudui.eval()
total_test_loss = 0
total_accuracy = 0
with torch.no_grad():
for data in test_dataloader:
imgs, targets=data
if torch.cuda.is_available():
imgs = imgs.cuda()
targets = targets.cuda()
outputs = tudui(imgs)
loss = loss_fn(outputs, targets)
total_test_loss = total_test_loss + loss.item()
accuracy = (outputs.argmax(1)==targets).sum()
total_accuracy = total_accuracy + accuracy
總結:
推薦方法一,如果自己電腦是只有CPU,可以推薦使用云端服務器,比如PaddlePaddle,Google colab,這些服務器由每周免費八個小時的使用時間,可供我們基本的需求。
原文鏈接:https://blog.csdn.net/mxh3600/article/details/124460988
相關推薦
- 2023-04-03 c++?lambda捕獲this?導致多線程下類釋放后還在使用的錯誤問題_C 語言
- 2023-05-05 Golang實現簡易的命令行功能_Golang
- 2022-08-08 pd.to_datetime中時間object轉換datetime實例_python
- 2023-03-28 python?label與one-hot之間的互相轉換方式_python
- 2022-09-06 Python中閉包與lambda的作用域解析_python
- 2022-10-01 Go語言異步API設計的扇入扇出模式詳解_Golang
- 2023-05-20 openGauss數據庫共享存儲特性概述_數據庫其它
- 2022-05-22 使用Supervisor守護ASP.NET?Core應用程序進程_實用技巧
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支