日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學(xué)無先后,達(dá)者為師

網(wǎng)站首頁 編程語言 正文

pytorch花式索引提取topk的張量

作者:bj_zhb 更新時間: 2024-02-17 編程語言

文章目錄

  • pytorch花式索引提取topk的張量
    • 問題設(shè)定
    • 代碼實(shí)現(xiàn)
      • 索引方法
      • gather方法
      • 驗(yàn)證
    • 補(bǔ)充知識
      • expand方法
      • gather方法
      • randint

pytorch花式索引提取topk的張量

問題設(shè)定

在這里插入圖片描述
或者說,有一個(bs, dim, L)的大張量,索引的index形狀為(bs, X),想得到一個(bs, dim, X)的reduced向量。我們在進(jìn)行topk操作(以減少計算量)的時候經(jīng)常碰到這種情況。
給出如下兩種實(shí)現(xiàn)方法,分別使用花式索引(參考informer的代碼)以及pytorch的gather方法

代碼實(shí)現(xiàn)

索引方法

參考https://blog.csdn.net/qq_36560894/article/details/122005808

feature = torch.rand(2,16,4*4)
indices = torch.randint(0,16, (2, 3))
indices
indices_expand = indices.unsqueeze(1).expand(-1, dim, -1).to(torch.long) # (bs, dim, H*W)
indices_expand.shape
indices_expand[:,1,:] # 結(jié)果和indices一致,說明在第二個channel上,每個樣本的索引是一樣的
bs,dim=feature.shape[:2]
bs,dim 
feature_reduce = feature.view(bs, dim, -1)[torch.arange(bs)[:, None, None], torch.arange(dim)[None,:,None], indices_expand]
feature_reduce.shape

在這里插入圖片描述
在這里插入圖片描述

gather方法

reduce_feature = torch.gather(feature, 2, indices_expand)

驗(yàn)證

兩種方法得到的結(jié)果完全相同
在這里插入圖片描述

補(bǔ)充知識

expand方法

在 PyTorch 中,expand() 方法用于擴(kuò)展張量的大小。它會在不實(shí)際復(fù)制數(shù)據(jù)的情況下,重復(fù)張量的元素以填充新的形狀。這個方法可以用于廣播操作,以便在執(zhí)行一些需要相同形狀的張量之間的數(shù)學(xué)運(yùn)算時,使它們具有相同的形狀。

下面是使用 expand() 方法的基本用法:

import torch

# 創(chuàng)建一個原始張量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# 使用 expand 擴(kuò)展張量的大小
expanded_x = x.expand(2, 3, 4)  # 擴(kuò)展成維度為(2, 3, 4)的張量

print(expanded_x)

在上面的例子中,我們首先創(chuàng)建了一個形狀為 (2, 3) 的原始張量 x。然后,我們使用 expand() 方法將其擴(kuò)展成一個維度為 (2, 3, 4) 的新張量 expanded_x,該張量的形狀是在原始張量形狀的基礎(chǔ)上每個維度都擴(kuò)展了一倍。

需要注意的是,expand() 方法只能用于增加張量的大小,不能減小。另外,擴(kuò)展后的張量與原始張量共享底層數(shù)據(jù),因此在原始張量上進(jìn)行的任何修改都會反映在擴(kuò)展后的張量上,反之亦然。

gather方法

在 PyTorch 中,gather() 方法用于從輸入張量中按照指定索引提取元素。這個方法通常用于根據(jù)索引收集特定的元素,例如根據(jù)類別索引從分類得分張量中獲取對應(yīng)類別的得分。

下面是使用 gather() 方法的基本用法:

import torch

# 創(chuàng)建一個輸入張量
input_tensor = torch.tensor([[1, 2],
                             [3, 4],
                             [5, 6]])

# 創(chuàng)建一個索引張量
indices = torch.tensor([[0, 0],
                        [1, 0]])

# 使用 gather 方法根據(jù)索引收集元素
output_tensor = torch.gather(input_tensor, dim=1, index=indices)

print(output_tensor)

在上面的例子中,我們首先創(chuàng)建了一個形狀為 (3, 2) 的輸入張量 input_tensor,以及一個形狀為 (2, 2) 的索引張量 indices。然后,我們使用 gather() 方法從輸入張量 input_tensor 中按照索引張量 indices 收集元素。

gather() 方法中,參數(shù) dim 指定了在哪個維度上進(jìn)行收集操作,而 index 參數(shù)指定了收集元素所使用的索引張量。

需要注意的是,索引張量 indices 的形狀必須與輸出張量的形狀一致,或者是可以廣播成與輸出張量形狀一致的形狀。

randint

torch.randint() 是 PyTorch 中用于生成隨機(jī)整數(shù)張量的函數(shù)。它可以生成一個張量,其中的元素是在指定范圍內(nèi)隨機(jī)抽樣的整數(shù)。

下面是 torch.randint() 的基本用法示例:

import torch

# 生成一個形狀為 (3, 3) 的隨機(jī)整數(shù)張量,范圍是 [0, 10)
random_integers = torch.randint(low=0, high=10, size=(3, 3))

print(random_integers)

在上面的示例中,我們使用了 torch.randint() 函數(shù)來生成一個形狀為 (3, 3) 的隨機(jī)整數(shù)張量,其中的元素取值范圍在閉區(qū)間 [low, high) 內(nèi),即從 0 到 9。

torch.randint() 函數(shù)的主要參數(shù)包括:

  • low:生成的隨機(jī)整數(shù)的最小值(包含)。
  • high:生成的隨機(jī)整數(shù)的最大值(不包含)。
  • size:生成的張量的形狀。

你也可以不指定 low 參數(shù),默認(rèn)情況下它為 0。此外,還可以使用其他參數(shù)來控制生成的隨機(jī)整數(shù)張量的設(shè)備類型、數(shù)據(jù)類型等。

原文鏈接:https://blog.csdn.net/bj_zhb/article/details/136110041

  • 上一篇:沒有了
  • 下一篇:沒有了
欄目分類
最近更新