網(wǎng)站首頁 編程語言 正文
文章目錄
- pytorch花式索引提取topk的張量
- 問題設(shè)定
- 代碼實(shí)現(xiàn)
- 索引方法
- gather方法
- 驗(yàn)證
- 補(bǔ)充知識
- expand方法
- gather方法
- randint
pytorch花式索引提取topk的張量
問題設(shè)定
或者說,有一個(bs, dim, L)的大張量,索引的index形狀為(bs, X),想得到一個(bs, dim, X)的reduced向量。我們在進(jìn)行topk操作(以減少計算量)的時候經(jīng)常碰到這種情況。
給出如下兩種實(shí)現(xiàn)方法,分別使用花式索引(參考informer的代碼)以及pytorch的gather方法
代碼實(shí)現(xiàn)
索引方法
參考https://blog.csdn.net/qq_36560894/article/details/122005808
feature = torch.rand(2,16,4*4)
indices = torch.randint(0,16, (2, 3))
indices
indices_expand = indices.unsqueeze(1).expand(-1, dim, -1).to(torch.long) # (bs, dim, H*W)
indices_expand.shape
indices_expand[:,1,:] # 結(jié)果和indices一致,說明在第二個channel上,每個樣本的索引是一樣的
bs,dim=feature.shape[:2]
bs,dim
feature_reduce = feature.view(bs, dim, -1)[torch.arange(bs)[:, None, None], torch.arange(dim)[None,:,None], indices_expand]
feature_reduce.shape
gather方法
reduce_feature = torch.gather(feature, 2, indices_expand)
驗(yàn)證
兩種方法得到的結(jié)果完全相同
補(bǔ)充知識
expand方法
在 PyTorch 中,expand()
方法用于擴(kuò)展張量的大小。它會在不實(shí)際復(fù)制數(shù)據(jù)的情況下,重復(fù)張量的元素以填充新的形狀。這個方法可以用于廣播操作,以便在執(zhí)行一些需要相同形狀的張量之間的數(shù)學(xué)運(yùn)算時,使它們具有相同的形狀。
下面是使用 expand()
方法的基本用法:
import torch
# 創(chuàng)建一個原始張量
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 使用 expand 擴(kuò)展張量的大小
expanded_x = x.expand(2, 3, 4) # 擴(kuò)展成維度為(2, 3, 4)的張量
print(expanded_x)
在上面的例子中,我們首先創(chuàng)建了一個形狀為 (2, 3)
的原始張量 x
。然后,我們使用 expand()
方法將其擴(kuò)展成一個維度為 (2, 3, 4)
的新張量 expanded_x
,該張量的形狀是在原始張量形狀的基礎(chǔ)上每個維度都擴(kuò)展了一倍。
需要注意的是,expand()
方法只能用于增加張量的大小,不能減小。另外,擴(kuò)展后的張量與原始張量共享底層數(shù)據(jù),因此在原始張量上進(jìn)行的任何修改都會反映在擴(kuò)展后的張量上,反之亦然。
gather方法
在 PyTorch 中,gather()
方法用于從輸入張量中按照指定索引提取元素。這個方法通常用于根據(jù)索引收集特定的元素,例如根據(jù)類別索引從分類得分張量中獲取對應(yīng)類別的得分。
下面是使用 gather()
方法的基本用法:
import torch
# 創(chuàng)建一個輸入張量
input_tensor = torch.tensor([[1, 2],
[3, 4],
[5, 6]])
# 創(chuàng)建一個索引張量
indices = torch.tensor([[0, 0],
[1, 0]])
# 使用 gather 方法根據(jù)索引收集元素
output_tensor = torch.gather(input_tensor, dim=1, index=indices)
print(output_tensor)
在上面的例子中,我們首先創(chuàng)建了一個形狀為 (3, 2)
的輸入張量 input_tensor
,以及一個形狀為 (2, 2)
的索引張量 indices
。然后,我們使用 gather()
方法從輸入張量 input_tensor
中按照索引張量 indices
收集元素。
在 gather()
方法中,參數(shù) dim
指定了在哪個維度上進(jìn)行收集操作,而 index
參數(shù)指定了收集元素所使用的索引張量。
需要注意的是,索引張量 indices
的形狀必須與輸出張量的形狀一致,或者是可以廣播成與輸出張量形狀一致的形狀。
randint
torch.randint()
是 PyTorch 中用于生成隨機(jī)整數(shù)張量的函數(shù)。它可以生成一個張量,其中的元素是在指定范圍內(nèi)隨機(jī)抽樣的整數(shù)。
下面是 torch.randint()
的基本用法示例:
import torch
# 生成一個形狀為 (3, 3) 的隨機(jī)整數(shù)張量,范圍是 [0, 10)
random_integers = torch.randint(low=0, high=10, size=(3, 3))
print(random_integers)
在上面的示例中,我們使用了 torch.randint()
函數(shù)來生成一個形狀為 (3, 3)
的隨機(jī)整數(shù)張量,其中的元素取值范圍在閉區(qū)間 [low, high)
內(nèi),即從 0 到 9。
torch.randint()
函數(shù)的主要參數(shù)包括:
-
low
:生成的隨機(jī)整數(shù)的最小值(包含)。 -
high
:生成的隨機(jī)整數(shù)的最大值(不包含)。 -
size
:生成的張量的形狀。
你也可以不指定 low
參數(shù),默認(rèn)情況下它為 0。此外,還可以使用其他參數(shù)來控制生成的隨機(jī)整數(shù)張量的設(shè)備類型、數(shù)據(jù)類型等。
原文鏈接:https://blog.csdn.net/bj_zhb/article/details/136110041
- 上一篇:沒有了
- 下一篇:沒有了
相關(guān)推薦
- 2022-09-06 詳解pygame中Rect對象_python
- 2021-11-09 C++11?thread多線程編程創(chuàng)建方式_C 語言
- 2022-02-03 yii joinwith查數(shù)據(jù)的問題
- 2022-06-14 全面了解C語言?static?關(guān)鍵字_C 語言
- 2023-01-30 delphi?判斷字符串是否包含漢字,正則版與非正則版實(shí)現(xiàn)_Delphi
- 2022-10-16 基于epoll的多線程網(wǎng)絡(luò)服務(wù)程序設(shè)計_C 語言
- 2022-10-28 React中使用react-file-viewer問題_React
- 2022-04-12 C#?實(shí)例解釋面向?qū)ο缶幊讨械膯我还δ茉瓌t(示例代碼)_C#教程
- 欄目分類
-
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支