日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網(wǎng)站首頁 編程語言 正文

Pytorch中torch.argmax()函數(shù)使用及說明_python

作者:cv_lhp ? 更新時間: 2023-02-07 編程語言

torch.argmax()函數(shù)解析

1. 官網(wǎng)鏈接

torch.argmax(),如下圖所示:

torch.argmax()

torch.argmax()

2. torch.argmax(input)函數(shù)解析

torch.argmax(input) → LongTensor

將輸入input張量,無論有幾維,首先將其reshape排列成一個一維向量,然后找出這個一維向量里面最大值的索引。

3. 代碼舉例

import torch
x = torch.randn(3,4)
y = torch.argmax(x)#對應(yīng)于x中最大元素的索引值
x,y

輸出結(jié)果如下:

import torch
x = torch.randn(3,4)
y = torch.argmax(x)#對應(yīng)于x中最大元素的索引值
x,y

4. torch.argmax(input,dim) 函數(shù)解析

torch.argmax(input, dim, keepdim=False) → LongTensor

函數(shù)返回其他所有維在這個維度上面張量最大值的索引。

torch.argmax()函數(shù)中dim表示該維度會消失,可以理解為最終結(jié)果該維度大小是1,表示將該維度壓縮成維度大小為1。

舉例理解:

對于一個維度為(d0,d1) 的矩陣來說,dim=1表示求每一行中最大數(shù)的在該行中的列號,最后得到的就是一個維度為(d0,1) 的二維矩陣,最終列這一維度大小為1就要消失了,最終結(jié)果變成一維張量(d0);
dim=0表示求每一列中最大數(shù)的在該列中的行號,最后我們得到的就是一個維度為(1,d1) 的二維矩陣,結(jié)果行這一維度大小為1就要消失了,最終結(jié)果變成一維張量(d1)。

因此,我們想要求每一行最大的列標號,我們就要指定dim=1,表示我們不要列了,保留行的size就可以了。

假如我們想求每一列的最大行標,就可以指定dim=0,表示我們不要行了,求出每一列的最大值的下標,最后得到(1,d1)的一維矩陣。

5. 代碼舉例

5.1 輸入二維張量torch.Size([3, 4]),dim=0表示將dim=0這個維度大小由3壓縮成1,然后找到dim=0這三個值中最大值的索引,這個索引表示dim=0行索引標號,結(jié)果張量維度變?yōu)閠orch.Size([4])。

import torch
x = torch.randn(3,4)
y = torch.argmax(x,dim=0)#dim=0表示將dim=0這個維度大小由3壓縮成1,然后找到dim=0這三個值中最大值的索引,這個索引表示dim=0行索引標號
x,x.shape,y,y.shape

輸出結(jié)果如下:

(tensor([[ 2.6347, ?0.6456, -1.0461, -1.5154],
? ? ? ? ?[-1.3955, -1.2618, -0.5886, -0.5947],
? ? ? ? ?[-1.5272, -2.0960, ?0.9428, -0.9532]]),
?torch.Size([3, 4]),
?tensor([0, 0, 2, 1]),
?torch.Size([4]))

5.2 輸入二維張量torch.Size([3, 4]),dim=1表示將dim=1這個維度大小由4壓縮成1,然后找到dim=1這四個值中最大值的索引,這個索引表示dim=1列索引標號,結(jié)果張量維度變?yōu)閠orch.Size([3])。

import torch
x = torch.randn(3,4)
y = torch.argmax(x,dim=1)#dim=1表示將dim=1這個維度大小由4壓縮成1,然后找到dim=1這四個值中最大值的索引,這個索引表示dim=1列索引標號
x,x.shape,y,y.shape

輸出結(jié)果如下:

(tensor([[ 0.1549, ?0.4331, ?0.3575, ?1.1077],
? ? ? ? ?[ 2.0233, ?2.0085, -0.6101, -1.8547],
? ? ? ? ?[-0.5101, -0.4052, ?0.3458, -0.7802]]),
?torch.Size([3, 4]),
?tensor([3, 0, 2]),
?torch.Size([3]))

5.3 輸入三維張量torch.Size([2, 3, 4]),dim=0表示將dim=0這個維度大小由2壓縮成1,然后找到dim=0這兩個值中最大值的索引,這個索引表示dim=0維索引標號。

dim=0,即將第一個維度消除,也就是將兩個[34]矩陣只保留一個,因此要在兩組中作比較,即將上下兩個[34]的矩陣分別在對應(yīng)的位置上比較大小,結(jié)果矩陣張量維度變?yōu)閠orch.Size([3, 4])。

import torch
x = torch.randn(2,3,4)
y = torch.argmax(x,dim=0)#dim=0表示將dim=0這個維度大小由2壓縮成1,然后找到dim=0這兩個值中最大值的索引,這個索引表示dim=0維索引標號
x,x.shape,y,y.shape

輸出結(jié)果如下:

(tensor([[[-1.4430, ?0.0306, -1.0396, ?0.1219],
? ? ? ? ? [ 0.1016, ?0.0889, ?0.8005, ?0.3320],
? ? ? ? ? [-1.0518, -1.4526, -0.4586, -0.1474]],
?
? ? ? ? ?[[ 1.2274, ?1.5806, ?0.5444, -0.3088],
? ? ? ? ? [-0.8672, ?0.3843, ?1.2377, ?2.1596],
? ? ? ? ? [ 0.0671, ?0.0847, ?0.5607, -0.7492]]]),
?torch.Size([2, 3, 4]),
?tensor([[1, 1, 1, 0],
? ? ? ? ?[0, 1, 1, 1],
? ? ? ? ?[1, 1, 1, 0]]),
?torch.Size([3, 4]))

5.4 輸入三維張量torch.Size([2, 3, 4]),dim=1表示將dim=1這個維度大小由3壓縮成1,然后找到dim=1這三個值中最大值的索引,這個索引表示dim=1維索引標號。

dim=1,即將第二個維度消除(縱向壓縮成一維),結(jié)果矩陣張量維度變?yōu)閠orch.Size([2, 4])。

import torch
x = torch.randn(2,3,4)
y = torch.argmax(x,dim=1)#dim=1表示將dim=1這個維度大小由3壓縮成1,然后找到dim=1這三個值中最大值的索引,這個索引表示dim=1維索引標號
x,x.shape,y,y.shape

輸出結(jié)果如下:

(tensor([[[-1.7136, ?0.5528, ?0.5171, ?1.2978],
? ? ? ? ? [ 1.0250, -0.2687, ?0.6727, -0.2013],
? ? ? ? ? [ 0.1366, -1.0563, ?0.1965, ?1.5303]],
?
? ? ? ? ?[[-0.0048, ?1.6265, -1.0341, -0.3994],
? ? ? ? ? [ 1.5536, ?0.9739, -0.0913, ?0.0889],
? ? ? ? ? [-0.6703, -0.9099, -0.6400, -0.1807]]]),
?torch.Size([2, 3, 4]),
?tensor([[1, 0, 1, 2],
? ? ? ? ?[1, 0, 1, 1]]),
?torch.Size([2, 4]))

5.5 輸入三維張量torch.Size([2, 3, 4]),dim=2表示將dim=2這個維度大小由4壓縮成1,然后找到dim=2這四個值中最大值的索引,這個索引表示dim=2維索引標號。dim=2,即將第三個維度消除(橫向壓縮成一維),結(jié)果矩陣張量維度變?yōu)閠orch.Size([2, 3])。

import torch
x = torch.randn(2,3,4)
y = torch.argmax(x,dim=2)#dim=2表示將dim=2這個維度大小由4壓縮成1,然后找到dim=2這四個值中最大值的索引,這個索引表示dim=2維索引標號
x,x.shape,y,y.shape

輸出結(jié)果如下:

(tensor([[[-0.3493, ?0.8838, ?0.5876, -0.3967],
? ? ? ? ? [-1.5795, ?2.6964, ?0.7266, ?0.3517],
? ? ? ? ? [-0.6949, -1.4385, -0.0993, ?0.1679]],
?
? ? ? ? ?[[-0.4924, -0.8955, ?0.5511, ?0.6287],
? ? ? ? ? [ 0.2338, -0.5787, -0.2081, -1.3032],
? ? ? ? ? [ 0.6429, ?0.0949, ?0.3319, -0.8551]]]),
?torch.Size([2, 3, 4]),
?tensor([[1, 1, 3],
? ? ? ? ?[3, 0, 0]]),
?torch.Size([2, 3]))

總結(jié)

原文鏈接:https://blog.csdn.net/flyingluohaipeng/article/details/125099214

欄目分類
最近更新