網站首頁 編程語言 正文
文章目錄
- pytorch花式索引提取topk的張量
- 問題設定
- 代碼實現
- 索引方法
- gather方法
- 驗證
- 補充知識
- expand方法
- gather方法
- randint
pytorch花式索引提取topk的張量
問題設定
或者說,有一個(bs, dim, L)的大張量,索引的index形狀為(bs, X),想得到一個(bs, dim, X)的reduced向量。我們在進行topk操作(以減少計算量)的時候經常碰到這種情況。
給出如下兩種實現方法,分別使用花式索引(參考informer的代碼)以及pytorch的gather方法
代碼實現
索引方法
參考https://blog.csdn.net/qq_36560894/article/details/122005808
feature = torch.rand(2,16,4*4)
indices = torch.randint(0,16, (2, 3))
indices
indices_expand = indices.unsqueeze(1).expand(-1, dim, -1).to(torch.long) # (bs, dim, H*W)
indices_expand.shape
indices_expand[:,1,:] # 結果和indices一致,說明在第二個channel上,每個樣本的索引是一樣的
bs,dim=feature.shape[:2]
bs,dim
feature_reduce = feature.view(bs, dim, -1)[torch.arange(bs)[:, None, None], torch.arange(dim)[None,:,None], indices_expand]
feature_reduce.shape
gather方法
reduce_feature = torch.gather(feature, 2, indices_expand)
驗證
兩種方法得到的結果完全相同
補充知識
expand方法
在 PyTorch 中,expand()
方法用于擴展張量的大小。它會在不實際復制數據的情況下,重復張量的元素以填充新的形狀。這個方法可以用于廣播操作,以便在執行一些需要相同形狀的張量之間的數學運算時,使它們具有相同的形狀。
下面是使用 expand()
方法的基本用法:
import torch
# 創建一個原始張量
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 使用 expand 擴展張量的大小
expanded_x = x.expand(2, 3, 4) # 擴展成維度為(2, 3, 4)的張量
print(expanded_x)
在上面的例子中,我們首先創建了一個形狀為 (2, 3)
的原始張量 x
。然后,我們使用 expand()
方法將其擴展成一個維度為 (2, 3, 4)
的新張量 expanded_x
,該張量的形狀是在原始張量形狀的基礎上每個維度都擴展了一倍。
需要注意的是,expand()
方法只能用于增加張量的大小,不能減小。另外,擴展后的張量與原始張量共享底層數據,因此在原始張量上進行的任何修改都會反映在擴展后的張量上,反之亦然。
gather方法
在 PyTorch 中,gather()
方法用于從輸入張量中按照指定索引提取元素。這個方法通常用于根據索引收集特定的元素,例如根據類別索引從分類得分張量中獲取對應類別的得分。
下面是使用 gather()
方法的基本用法:
import torch
# 創建一個輸入張量
input_tensor = torch.tensor([[1, 2],
[3, 4],
[5, 6]])
# 創建一個索引張量
indices = torch.tensor([[0, 0],
[1, 0]])
# 使用 gather 方法根據索引收集元素
output_tensor = torch.gather(input_tensor, dim=1, index=indices)
print(output_tensor)
在上面的例子中,我們首先創建了一個形狀為 (3, 2)
的輸入張量 input_tensor
,以及一個形狀為 (2, 2)
的索引張量 indices
。然后,我們使用 gather()
方法從輸入張量 input_tensor
中按照索引張量 indices
收集元素。
在 gather()
方法中,參數 dim
指定了在哪個維度上進行收集操作,而 index
參數指定了收集元素所使用的索引張量。
需要注意的是,索引張量 indices
的形狀必須與輸出張量的形狀一致,或者是可以廣播成與輸出張量形狀一致的形狀。
randint
torch.randint()
是 PyTorch 中用于生成隨機整數張量的函數。它可以生成一個張量,其中的元素是在指定范圍內隨機抽樣的整數。
下面是 torch.randint()
的基本用法示例:
import torch
# 生成一個形狀為 (3, 3) 的隨機整數張量,范圍是 [0, 10)
random_integers = torch.randint(low=0, high=10, size=(3, 3))
print(random_integers)
在上面的示例中,我們使用了 torch.randint()
函數來生成一個形狀為 (3, 3)
的隨機整數張量,其中的元素取值范圍在閉區間 [low, high)
內,即從 0 到 9。
torch.randint()
函數的主要參數包括:
-
low
:生成的隨機整數的最小值(包含)。 -
high
:生成的隨機整數的最大值(不包含)。 -
size
:生成的張量的形狀。
你也可以不指定 low
參數,默認情況下它為 0。此外,還可以使用其他參數來控制生成的隨機整數張量的設備類型、數據類型等。
原文鏈接:https://blog.csdn.net/bj_zhb/article/details/136110041
- 上一篇:沒有了
- 下一篇:沒有了
相關推薦
- 2024-01-09 JPA查詢——setResultTransformer過期替換
- 2022-09-25 Linux中安裝和配置Redis
- 2022-03-15 pycharm安裝opencv出現錯誤:Could not find a version that
- 2022-04-24 C語言時間函數之strftime()詳解_C 語言
- 2022-08-14 hyper-v如何配置NAT網絡的實現_Hyper-V
- 2022-05-15 C++單例類宏定義,方便快速實現單例類
- 2022-01-12 nvm安裝步驟及各種避坑指南&nvm安裝node-v沒反應&npm,yarn用不了
- 2023-07-06 Mac安裝python3并配置環境變量
- 欄目分類
-
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支