網(wǎng)站首頁 編程語言 正文
torch.argmax()函數(shù)解析
1. 官網(wǎng)鏈接
torch.argmax(),如下圖所示:
2. torch.argmax(input)函數(shù)解析
torch.argmax(input) → LongTensor
將輸入input張量,無論有幾維,首先將其reshape排列成一個一維向量,然后找出這個一維向量里面最大值的索引。
3. 代碼舉例
import torch
x = torch.randn(3,4)
y = torch.argmax(x)#對應(yīng)于x中最大元素的索引值
x,y
輸出結(jié)果如下:
import torch
x = torch.randn(3,4)
y = torch.argmax(x)#對應(yīng)于x中最大元素的索引值
x,y
4. torch.argmax(input,dim) 函數(shù)解析
torch.argmax(input, dim, keepdim=False) → LongTensor
函數(shù)返回其他所有維在這個維度上面張量最大值的索引。
torch.argmax()函數(shù)中dim表示該維度會消失,可以理解為最終結(jié)果該維度大小是1,表示將該維度壓縮成維度大小為1。
舉例理解:
對于一個維度為(d0,d1) 的矩陣來說,dim=1表示求每一行中最大數(shù)的在該行中的列號,最后得到的就是一個維度為(d0,1) 的二維矩陣,最終列這一維度大小為1就要消失了,最終結(jié)果變成一維張量(d0);
dim=0表示求每一列中最大數(shù)的在該列中的行號,最后我們得到的就是一個維度為(1,d1) 的二維矩陣,結(jié)果行這一維度大小為1就要消失了,最終結(jié)果變成一維張量(d1)。
因此,我們想要求每一行最大的列標號,我們就要指定dim=1,表示我們不要列了,保留行的size就可以了。
假如我們想求每一列的最大行標,就可以指定dim=0,表示我們不要行了,求出每一列的最大值的下標,最后得到(1,d1)的一維矩陣。
5. 代碼舉例
5.1 輸入二維張量torch.Size([3, 4]),dim=0表示將dim=0這個維度大小由3壓縮成1,然后找到dim=0這三個值中最大值的索引,這個索引表示dim=0行索引標號,結(jié)果張量維度變?yōu)閠orch.Size([4])。
import torch
x = torch.randn(3,4)
y = torch.argmax(x,dim=0)#dim=0表示將dim=0這個維度大小由3壓縮成1,然后找到dim=0這三個值中最大值的索引,這個索引表示dim=0行索引標號
x,x.shape,y,y.shape
輸出結(jié)果如下:
(tensor([[ 2.6347, ?0.6456, -1.0461, -1.5154],
? ? ? ? ?[-1.3955, -1.2618, -0.5886, -0.5947],
? ? ? ? ?[-1.5272, -2.0960, ?0.9428, -0.9532]]),
?torch.Size([3, 4]),
?tensor([0, 0, 2, 1]),
?torch.Size([4]))
5.2 輸入二維張量torch.Size([3, 4]),dim=1表示將dim=1這個維度大小由4壓縮成1,然后找到dim=1這四個值中最大值的索引,這個索引表示dim=1列索引標號,結(jié)果張量維度變?yōu)閠orch.Size([3])。
import torch
x = torch.randn(3,4)
y = torch.argmax(x,dim=1)#dim=1表示將dim=1這個維度大小由4壓縮成1,然后找到dim=1這四個值中最大值的索引,這個索引表示dim=1列索引標號
x,x.shape,y,y.shape
輸出結(jié)果如下:
(tensor([[ 0.1549, ?0.4331, ?0.3575, ?1.1077],
? ? ? ? ?[ 2.0233, ?2.0085, -0.6101, -1.8547],
? ? ? ? ?[-0.5101, -0.4052, ?0.3458, -0.7802]]),
?torch.Size([3, 4]),
?tensor([3, 0, 2]),
?torch.Size([3]))
5.3 輸入三維張量torch.Size([2, 3, 4]),dim=0表示將dim=0這個維度大小由2壓縮成1,然后找到dim=0這兩個值中最大值的索引,這個索引表示dim=0維索引標號。
dim=0,即將第一個維度消除,也就是將兩個[34]矩陣只保留一個,因此要在兩組中作比較,即將上下兩個[34]的矩陣分別在對應(yīng)的位置上比較大小,結(jié)果矩陣張量維度變?yōu)閠orch.Size([3, 4])。
import torch
x = torch.randn(2,3,4)
y = torch.argmax(x,dim=0)#dim=0表示將dim=0這個維度大小由2壓縮成1,然后找到dim=0這兩個值中最大值的索引,這個索引表示dim=0維索引標號
x,x.shape,y,y.shape
輸出結(jié)果如下:
(tensor([[[-1.4430, ?0.0306, -1.0396, ?0.1219],
? ? ? ? ? [ 0.1016, ?0.0889, ?0.8005, ?0.3320],
? ? ? ? ? [-1.0518, -1.4526, -0.4586, -0.1474]],
?
? ? ? ? ?[[ 1.2274, ?1.5806, ?0.5444, -0.3088],
? ? ? ? ? [-0.8672, ?0.3843, ?1.2377, ?2.1596],
? ? ? ? ? [ 0.0671, ?0.0847, ?0.5607, -0.7492]]]),
?torch.Size([2, 3, 4]),
?tensor([[1, 1, 1, 0],
? ? ? ? ?[0, 1, 1, 1],
? ? ? ? ?[1, 1, 1, 0]]),
?torch.Size([3, 4]))
5.4 輸入三維張量torch.Size([2, 3, 4]),dim=1表示將dim=1這個維度大小由3壓縮成1,然后找到dim=1這三個值中最大值的索引,這個索引表示dim=1維索引標號。
dim=1,即將第二個維度消除(縱向壓縮成一維),結(jié)果矩陣張量維度變?yōu)閠orch.Size([2, 4])。
import torch
x = torch.randn(2,3,4)
y = torch.argmax(x,dim=1)#dim=1表示將dim=1這個維度大小由3壓縮成1,然后找到dim=1這三個值中最大值的索引,這個索引表示dim=1維索引標號
x,x.shape,y,y.shape
輸出結(jié)果如下:
(tensor([[[-1.7136, ?0.5528, ?0.5171, ?1.2978],
? ? ? ? ? [ 1.0250, -0.2687, ?0.6727, -0.2013],
? ? ? ? ? [ 0.1366, -1.0563, ?0.1965, ?1.5303]],
?
? ? ? ? ?[[-0.0048, ?1.6265, -1.0341, -0.3994],
? ? ? ? ? [ 1.5536, ?0.9739, -0.0913, ?0.0889],
? ? ? ? ? [-0.6703, -0.9099, -0.6400, -0.1807]]]),
?torch.Size([2, 3, 4]),
?tensor([[1, 0, 1, 2],
? ? ? ? ?[1, 0, 1, 1]]),
?torch.Size([2, 4]))
5.5 輸入三維張量torch.Size([2, 3, 4]),dim=2表示將dim=2這個維度大小由4壓縮成1,然后找到dim=2這四個值中最大值的索引,這個索引表示dim=2維索引標號。dim=2,即將第三個維度消除(橫向壓縮成一維),結(jié)果矩陣張量維度變?yōu)閠orch.Size([2, 3])。
import torch
x = torch.randn(2,3,4)
y = torch.argmax(x,dim=2)#dim=2表示將dim=2這個維度大小由4壓縮成1,然后找到dim=2這四個值中最大值的索引,這個索引表示dim=2維索引標號
x,x.shape,y,y.shape
輸出結(jié)果如下:
(tensor([[[-0.3493, ?0.8838, ?0.5876, -0.3967],
? ? ? ? ? [-1.5795, ?2.6964, ?0.7266, ?0.3517],
? ? ? ? ? [-0.6949, -1.4385, -0.0993, ?0.1679]],
?
? ? ? ? ?[[-0.4924, -0.8955, ?0.5511, ?0.6287],
? ? ? ? ? [ 0.2338, -0.5787, -0.2081, -1.3032],
? ? ? ? ? [ 0.6429, ?0.0949, ?0.3319, -0.8551]]]),
?torch.Size([2, 3, 4]),
?tensor([[1, 1, 3],
? ? ? ? ?[3, 0, 0]]),
?torch.Size([2, 3]))
總結(jié)
原文鏈接:https://blog.csdn.net/flyingluohaipeng/article/details/125099214
相關(guān)推薦
- 2022-07-22 Mybatis為實體類自定義別名的兩種方式
- 2022-07-27 Python中的?enumerate和zip詳情_python
- 2022-06-08 兩步配置解決 IDEA新項目maven依賴問題
- 2022-04-08 深入理解Golang的反射reflect示例_Golang
- 2022-06-29 Oracle中執(zhí)行動態(tài)SQL_oracle
- 2022-03-14 window環(huán)境編譯在linux環(huán)境運行的golang程序
- 2022-11-05 關(guān)于Python?Tkinter?復(fù)選框?->Checkbutton_python
- 2022-12-03 內(nèi)網(wǎng)環(huán)境下registry搭建步驟詳解_docker
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細win安裝深度學習環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支