網站首頁 編程語言 正文
pytorch中.numpy()、.item()、.cpu()、.detach()以及.data的使用方法_python
作者:Unstoppable~~~ ? 更新時間: 2022-10-16 編程語言.numpy()
Tensor.numpy()
將Tensor轉化為ndarray,這里的Tensor可以是標量或者向量(與item()不同)轉換前后的dtype不會改變
a = torch.tensor([[1.,2.]]) a_numpy = a.numpy() #[[1., 2.]]
.item()
將一個Tensor變量轉換為python標量(int float等)常用于用于深度學習訓練時,將loss值轉換為標量并加,以及進行分類任務,計算準確值值時需要
optimizer.zero_grad() outputs = model(data) loss = F.cross_entropy(outputs, label) #計算這一個batch的準確率 acc = (outputs.argmax(dim=1) == label).sum().cpu().item() / len(labels) #這里也用到了.item() loss.backward() optimizer.step() train_loss += loss.item() #這里用到了.item() train_acc += acc
.cpu()
將數據的處理設備從其他設備(如.cuda()拿到cpu上),不會改變變量類型,轉換后仍然是Tensor變量。
.detach()和.data(重點)
.detach()就是返回一個新的tensor,并且這個tensor是從當前的計算圖中分離出來的。但是返回的tensor和原來的tensor是共享內存空間的。
舉個例子來說明一下detach有什么用。 如果A網絡的輸出被喂給B網絡作為輸入, 如果我們希望在梯度反傳的時候只更新B中參數的值,而不更新A中的參數值,這時候就可以使用detach()
a = A(input) a = a.deatch() # 或者a.detach_()進行in_place操作 out = B(a) loss = criterion(out, labels) loss.backward()
Tensor.data和Tensor.detach()一樣, 都會返回一個新的Tensor, 這個Tensor和原來的Tensor共享內存空間,一個改變,另一個也會隨著改變,且都會設置新的Tensor的requires_grad屬性為False。這兩個方法只取出原來Tensor的tensor數據, 丟棄了grad、grad_fn等額外的信息。
tensor.data是不安全的, 因為 x.data 不能被 autograd 追蹤求微分
這是為什么呢?我們對.data進行進一步探究
import torch a = torch.tensor([4., 5., 6.], requires_grad=True) print("a", a) out = a.sigmoid() print("out", out) print(out.requires_grad) #在進行.data前仍為true result = out.data #共享變量,同時將requires_grad設置為false result.zero_() # 改變c的值,原來的out也會改變 print("result", result) print("out", out) out.sum().backward() # 對原來的out求導, print(a.grad) # 不會報錯,但是結果卻并不正確 '''運行結果為: a tensor([4., 5., 6.], requires_grad=True) out tensor([0.9820, 0.9933, 0.9975], grad_fn=<SigmoidBackward0>) True result tensor([0., 0., 0.]) out tensor([0., 0., 0.], grad_fn=<SigmoidBackward0>) tensor([0., 0., 0.]) '''
由于更改分離之后的變量值result,導致原來的張量out的值也跟著改變了,但是這種改變對于autograd是沒有察覺的,它依然按照求導規則來求導,導致得出完全錯誤的導數值卻渾然不知。
那么我們繼續看看.detach()
可以看到將.data改為.detach()后程序立馬報錯,阻止了非法的修改,安全性很高
我們需要記住的就是:
- .data 是一個屬性,二.detach()是一個方法;
- .data 是不安全的,.detach()是安全的。
補充:關于.data和.cpu().data的各種操作
先上圖
仔細分析:
1.首先a是一個放在GPU上的Variable,a.data是把Variable里的tensor取出來,
? 可以看出與a的差別是:缺少了第一行(Variable containing)
2.a.cpu()和a.data.cpu()是分別把a和a.data放在cpu上,其他的沒區別,另外:a.data.cpu()和a.cpu().data一樣
3.a.data[0] | ?a.cpu().data[0] ?| a.data.cpu()[0]是一樣的,都是把第一個值取出來,類型均為float
4.a.data.cpu().numpy()把tensor轉換成numpy的格式
總結
原文鏈接:https://blog.csdn.net/gary101818/article/details/124658826
相關推薦
- 2022-09-18 ASP.NET?Core實現文件上傳和下載_實用技巧
- 2022-10-23 python如何在一個py文件中獲取另一個py文件中的值(一個或多個)_python
- 2023-02-05 Redis處理高并發之布隆過濾器詳解_Redis
- 2022-04-11 python寫入Excel表格的方法詳解_python
- 2022-04-02 Redis快速實現分布式session的方法詳解_Redis
- 2022-07-10 如何替換重構依賴里面的Service
- 2022-10-02 pandas數據類型之Series的具體使用_python
- 2022-04-28 Go語言錯誤處理異常捕獲+異常拋出_Golang
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支