日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網站首頁 編程語言 正文

PyTorch中Torch.arange函數詳解_python

作者:_湘江夜話_ ? 更新時間: 2023-04-03 編程語言

torch.arange函數詳解

官方文檔:torch.arange

函數原型

arange(start=0, end, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor

用法

返回大小為一維張量,其值介于區間 為步長等間隔取值

參數說明

參數 類型 說明
start Number 起始值,默認值:0
end Number 結束值
step Number 步長,默認值:1

關鍵字參數

關鍵字參數 類型 說明
out Tensor 輸出張量
dtype torch.dtype 期望的返回張量的數據類型。默認值:如果是None,則使用全局默認值。如果未給出 dtype,則從其他輸入參數推斷數據類型。如果 start、end 或 stop 中的任何一個是浮點數,則 dtype被推斷為默認值,參見 get_default_dtype()。否則,dtype 被推斷為 torch.int64
layout torch.layout 返回張量的期望 layout。默認值:torch.strided
device torch.device 返回張量的期望設備。默認值:如果是None,則使用當前設備作為默認張量類型,參見torch.set_default_tensor_type()。對于 CPU 類型的張量,則 device 是 CPU ,若是 CUDA 類型的張量,則 device 是當前的 CUDA 設備
requires_grad bool autograd 是否記錄返回張量上所作的操作。默認值:False

代碼示例

    >>> torch.arange(5)  # 默認以 0 為起點
    tensor([ 0,  1,  2,  3,  4])
    >>> torch.arange(1, 4)  # 默認間隔為 1
    tensor([ 1,  2,  3])
    >>> torch.arange(1, 2.5, 0.5)  # 指定間隔 0.5
    tensor([ 1.0000,  1.5000,  2.0000])

pyTorch中torch.range()和torch.arange()的區別

torch.range()和torch.arange()的區別

x = torch.range(-8, 8)
y = torch.arange(-8, 8)
print(x, x.dtype)
print(y, y.dtype)

output:

?? tensor([-8., -7., -6., -5., -4., -3., -2., -1., 0., 1., 2., 3., 4., 5.,6., 7., 8.]) torch.float32
?? tensor([-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7]) torch.int64

可以看到,torch.range()的范圍是[-8, 8],類型為torch.float32

torch.arange()的范圍是[-8, 8),類型為torch.int64

在梯度設置時會出現錯誤:

x = torch.range(-8, 8, 1, requires_grad=True)
y = torch.arange(-8, 8, 1, requires_grad=True)
print(x, x.dtype)
print(y, y.dtype)

即只有當類型為float時才可設置requires_grad=True,故可將

y = torch.arange(-8, 8, 1, requires_grad=True)

改為以下,即手動改變數據類型即可。

y = torch.arange(-8.0, 8.0, 1.0, requires_grad=True)

output:
?? tensor([-8., -7., -6., -5., -4., -3., -2., -1., 0., 1., 2., 3., 4., 5.,6., 7., 8.], requires_grad=True)
?? torch.float32
?? tensor([-8., -7., -6., -5., -4., -3., -2., -1., 0., 1., 2., 3., 4., 5.,6., 7.], requires_grad=True)
?? torch.float32

總結

原文鏈接:https://blog.csdn.net/weixin_44504393/article/details/127092330

欄目分類
最近更新