網站首頁 編程語言 正文
函數作用:
該函數的作用即按字面意思理解,topk:取數組的前k個元素進行排序。
通常該函數返回2個值,第一個值為排序的數組,第二個值為該數組中獲取到的元素在原數組中的位置標號。
舉個栗子:
import numpy as np import torch import torch.utils.data.dataset as Dataset from torch.utils.data import Dataset,DataLoader ####################準備一個數組######################### tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10], [3,4,5,1,1,1,1,1,1,1,1], [7,8,9,1,1,1,1,1,1,1,1], [1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32) ####################打印這個原數組######################### print('tensor1:') print(tensor1) #################使用torch.topk()這個函數################## print('使用torch.topk()這個函數得到:') '''k=3代表從原數組中取得3個元素,dim=1表示從原數組中的第一維獲取元素 (在本例中是分別從[10,1,2,1,1,1,1,1,1,1,10]、[3,4,5,1,1,1,1,1,1,1,1]、 [7,8,9,1,1,1,1,1,1,1,1]、[1,4,7,1,1,1,1,1,1,1,1]這四個數組中獲取3個元素) 其中largest=True表示從大到小取元素''' print(torch.topk(tensor1, k=3, dim=1, largest=True)) #################打印這個函數第一個返回值#################### print('函數第一個返回值topk[0]如下') print(torch.topk(tensor1, k=3, dim=1, largest=True)[0]) #################打印這個函數第二個返回值#################### print('函數第二個返回值topk[1]如下') print(torch.topk(tensor1, k=3, dim=1, largest=True)[1]) ''' #######################運行結果########################## tensor1: tensor([[10., 1., 2., 1., 1., 1., 1., 1., 1., 1., 10.], [ 3., 4., 5., 1., 1., 1., 1., 1., 1., 1., 1.], [ 7., 8., 9., 1., 1., 1., 1., 1., 1., 1., 1.], [ 1., 4., 7., 1., 1., 1., 1., 1., 1., 1., 1.]]) 使用torch.topk()這個函數得到: '得到的values是原數組dim=1的四組從大到小的三個元素值; 得到的indices是獲取到的元素值在原數組dim=1中的位置。' torch.return_types.topk( values=tensor([[10., 10., 2.], [ 5., 4., 3.], [ 9., 8., 7.], [ 7., 4., 1.]]), indices=tensor([[ 0, 10, 2], [ 2, 1, 0], [ 2, 1, 0], [ 2, 1, 0]])) 函數第一個返回值topk[0]如下 tensor([[10., 10., 2.], [ 5., 4., 3.], [ 9., 8., 7.], [ 7., 4., 1.]]) 函數第二個返回值topk[1]如下 tensor([[ 0, 10, 2], [ 2, 1, 0], [ 2, 1, 0], [ 2, 1, 0]]) '''
該函數功能經常用來獲取張量或者數組中最大或者最小的元素以及索引位置,是一個經常用到的基本函數。
實例演示
任務一:
取top1(最大值):
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053], [ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823], [-0.4451, 0.1673, 1.2590, -2.0757, 1.7255], [ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]]) print(pred) values, indices = pred.topk(1, dim=0, largest=True, sorted=True) print(indices) print(values) # 用max得到的結果,設置keepdim為True,避免降維。因為topk函數返回的index不降維,shape和輸入一致。 _, indices_max = pred.max(dim=0, keepdim=True) print(indices_max) print(indices_max == indices) 輸出: tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053], [ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823], [-0.4451, 0.1673, 1.2590, -2.0757, 1.7255], [ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]]) tensor([[1, 1, 1, 1, 1]]) tensor([[0.7265, 1.4164, 1.3443, 1.2035, 1.8823]]) tensor([[1, 1, 1, 1, 1]]) tensor([[True, True, True, True, True]])
任務二:
按行取出topk,將小于topk的置為inf:
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053], [ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823], [-0.4451, 0.1673, 1.2590, -2.0757, 1.7255], [ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]]) print(pred) top_k = 2 # 按行求出每一行的最大的前兩個值 filter_value=-float('Inf') indices_to_remove = pred < torch.topk(pred, top_k)[0][..., -1, None] print(indices_to_remove) pred[indices_to_remove] = filter_value # 對于topk之外的其他元素的logits值設為負無窮 print(pred) 輸出: tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053], [ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823], [-0.4451, 0.1673, 1.2590, -2.0757, 1.7255], [ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]]) tensor([[4], [4], [4], [3]]) tensor([[0.4053], [1.8823], [1.7255], [0.3849]]) tensor([[ True, False, True, True, False], [ True, False, True, True, False], [ True, True, False, True, False], [ True, False, True, False, True]]) tensor([[ -inf, -0.3873, -inf, -inf, 0.4053], [ -inf, 1.4164, -inf, -inf, 1.8823], [ -inf, -inf, 1.2590, -inf, 1.7255], [ -inf, 0.3041, -inf, 0.3849, -inf]])
任務三:
import numpy as np import torch import torch.utils.data.dataset as Dataset from torch.utils.data import Dataset,DataLoader tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10], [3,4,5,1,1,1,1,1,1,1,1], [7,8,9,1,1,1,1,1,1,1,1], [1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32) # tensor2=torch.tensor([[3,2,1], # [6,5,4], # [1,4,7], # [9,8,7]],dtype=torch.float32) # print('tensor1:') print(tensor1) print('直接輸出topk,會得到兩個東西,我們需要的是第二個indices') print(torch.topk(tensor1, k=3, dim=1, largest=True)) print('topk[0]如下') print(torch.topk(tensor1, k=3, dim=1, largest=True)[0]) print('topk[1]如下') print(torch.topk(tensor1, k=3, dim=1, largest=True)[1]) ''' tensor1: tensor([[10., 1., 2., 1., 1., 1., 1., 1., 1., 1., 10.], [ 3., 4., 5., 1., 1., 1., 1., 1., 1., 1., 1.], [ 7., 8., 9., 1., 1., 1., 1., 1., 1., 1., 1.], [ 1., 4., 7., 1., 1., 1., 1., 1., 1., 1., 1.]]) 直接輸出topk,會得到兩個東西,我們需要的是第二個indices torch.return_types.topk( values=tensor([[10., 10., 2.], [ 5., 4., 3.], [ 9., 8., 7.], [ 7., 4., 1.]]), indices=tensor([[ 0, 10, 2], [ 2, 1, 0], [ 2, 1, 0], [ 2, 1, 0]])) topk[0]如下 tensor([[10., 10., 2.], [ 5., 4., 3.], [ 9., 8., 7.], [ 7., 4., 1.]]) topk[1]如下 tensor([[ 0, 10, 2], [ 2, 1, 0], [ 2, 1, 0], [ 2, 1, 0]]) '''
總結
原文鏈接:https://blog.csdn.net/qq_45193872/article/details/119878804
相關推薦
- 2022-12-19 Pytorch相關知識介紹與應用_python
- 2022-04-19 C語言內存管理及初始化細節示例詳解_C 語言
- 2022-08-21 golang字符串本質與原理詳解_Golang
- 2022-04-18 python?request?post?列表的方法詳解_python
- 2022-05-08 一篇文章詳細解釋C++的友元(friend)_C 語言
- 2022-12-09 Python中的main函數與import用法_python
- 2023-01-17 React?Hooks核心原理深入分析講解_React
- 2022-12-12 Android?Binder進程間通信工具AIDL使用示例深入分析_Android
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支