日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網站首頁 編程語言 正文

PyTorch中torch.tensor()和torch.to_tensor()的區別_python

作者:Enzo?想砸電腦 ? 更新時間: 2023-03-22 編程語言

前言

在跑模型的時候,遇到如下報錯

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

網上查了一下,發現將 torch.tensor() 改寫成 torch.as_tensor() 就可以避免報錯了。

# 如下寫法報錯
 feature = torch.tensor(image, dtype=torch.float32)
 
# 改為
feature = torch.as_tensor(image, dtype=torch.float32)

然后就又仔細研究了下 torch.as_tensor()torch.tensor() 的區別,在此記錄。

1、torch.as_tensor()

new_data = torch.as_tensor(data, dtype=None,device=None)->Tensor

作用:生成一個新的 tensor, 這個新生成的tensor 會根據原數據的實際情況,來決定是進行淺拷貝,還是深拷貝。當然,會優先淺拷貝,淺拷貝會共享內存,并共享 autograd 歷史記錄。

情況一:數據類型相同 且 device相同,會進行淺拷貝,共享內存

import numpy
import torch

a = numpy.array([1, 2, 3])
t = torch.as_tensor(a)
t[0] = -1

print(a)   # [-1  2  3]
print(a.dtype)   # int64
print(t)   # tensor([-1,  2,  3])
print(t.dtype)   # torch.int64
import numpy
import torch

a = torch.tensor([1, 2, 3], device=torch.device('cuda'))
t = torch.as_tensor(a)
t[0] = -1

print(a)   # tensor([-1,  2,  3], device='cuda:0')
print(t)   # tensor([-1,  2,  3], device='cuda:0')

情況二: 數據類型相同,但是device不同,深拷貝,不再共享內存

import numpy
import torch

import numpy
a = numpy.array([1, 2, 3])
t = torch.as_tensor(a, device=torch.device('cuda'))
t[0] = -1

print(a)   # [1 2 3]
print(a.dtype)   # int64
print(t)   # tensor([-1,  2,  3], device='cuda:0')
print(t.dtype)   # torch.int64

情況三:device相同,但數據類型不同,深拷貝,不再共享內存

import numpy
import torch

a = numpy.array([1, 2, 3])
t = torch.as_tensor(a, dtype=torch.float32)
t[0] = -1

print(a)   # [1 2 3]
print(a.dtype)   # int64
print(t)   # tensor([-1.,  2.,  3.])
print(t.dtype)   # torch.float32

2、torch.tensor()

torch.tensor() 是深拷貝方式。

torch.tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False)

深拷貝:會拷貝 數據類型 和 device,不會記錄 autograd 歷史 (also known as a “leaf tensor” 葉子tensor)

重點是:

  • 如果原數據的數據類型是:list, tuple, NumPy ndarray, scalar, and other types,不會 waring
  • 如果原數據的數據類型是:tensor,使用 torch.tensor(data) 就會報waring
# 原數據類型是:tensor 會發出警告
import numpy
import torch

a = torch.tensor([1, 2, 3], device=torch.device('cuda'))
t = torch.tensor(a)
t[0] = -1

print(a)
print(t)

# 輸出:
# tensor([1, 2, 3], device='cuda:0')
# tensor([-1,  2,  3], device='cuda:0')
# /opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
# 原數據類型是:list, tuple, NumPy ndarray, scalar, and other types, 沒警告
import torch
import numpy

a =  numpy.array([1, 2, 3])
t = torch.tensor(a) 

b = [1,2,3]
t= torch.tensor(b)

c = (1,2,3)
t= torch.tensor(c)

結論就是:以后盡量用 torch.as_tensor()

總結

原文鏈接:https://blog.csdn.net/weixin_37804469/article/details/128767214

欄目分類
最近更新