日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網站首頁 編程語言 正文

Pytorch中torch.repeat_interleave()函數使用及說明_python

作者:cv_lhp ? 更新時間: 2023-02-08 編程語言

torch.repeat_interleave()函數解析

1.函數說明

官網:torch.repeat_interleave(),函數說明如下圖所示:

函數說明

2. 函數原型

torch.repeat_interleave(input, repeats, dim=None) → Tensor

3. 函數功能

沿著指定的維度重復張量的元素

4. 輸入參數

1)input (類型:torch.Tensor):輸入張量

2)repeats(類型:int或torch.Tensor):每個元素的重復次數

3)dim(類型:int)需要重復的維度。默認情況下dim=None,表示將把給定的輸入張量展平(flatten)為向量,然后將每個元素重復repeats次,并返回重復后的張量。

5. 注意

1) 如果不指定dim,則默認將輸入張量扁平化(維數是1,因此這時repeats必須是一個數,不能是數組),并且返回一個扁平化的輸出數組。

2) 返回的數組與輸入數組維數相同,并且除了給定的維度dim,其他維度大小與輸入數組相應維度大小相同

3) repeats:如果傳入數組,則必須是tensor格式。并且只能是一維數組,數組長度與輸入數組input的dim維度大小相同

6. 代碼例子

6.1 輸入一維張量,不指定dim,重復次數為2次,表示將把給定的輸入張量展平(flatten)為向量,然后將每個元素重復2次,并返回重復后的張量。

a = torch.randn(5)
a,torch.repeat_interleave(a,2)

輸出結果如下所示:

(tensor([ 0.4030, -1.1536, -2.4513, ?1.1454, -0.8818]),
?tensor([ 0.4030, ?0.4030, -1.1536, -1.1536, -2.4513, -2.4513, ?1.1454, ?1.1454,
? ? ? ? ?-0.8818, -0.8818]))

6.2 輸入二維張量,不指定dim,重復次數為2次,表示將把給定的輸入張量展平(flatten)為向量,然后將每個元素重復2次,并返回重復后的張量。

a = torch.randn(3,2)
a,a.repeat_interleave(2)

輸出結果如下:

(tensor([[-1.03, -0.32],
? ? ? ? ?[ 0.43, ?0.78],
? ? ? ? ?[ 0.91, -0.11]]),
?tensor([-1.03, -1.03, -0.32, -0.32, ?0.43, ?0.43, ?0.78, ?0.78, ?0.91, ?0.91,
? ? ? ? ?-0.11, -0.11]))

6.3 輸入二維張量,指定dim=0,重復次數為3次,表示把輸入張量每行元素重復3次

a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=0)

輸出結果如下:

(tensor([[ 0.14, ?1.47],
? ? ? ? ?[-1.52, -0.62],
? ? ? ? ?[-0.24, -0.27]]),
?tensor([[ 0.14, ?1.47],
? ? ? ? ?[ 0.14, ?1.47],
? ? ? ? ?[ 0.14, ?1.47],
? ? ? ? ?[-1.52, -0.62],
? ? ? ? ?[-1.52, -0.62],
? ? ? ? ?[-1.52, -0.62],
? ? ? ? ?[-0.24, -0.27],
? ? ? ? ?[-0.24, -0.27],
? ? ? ? ?[-0.24, -0.27]]))

6.4 輸入二維張量,指定dim=1,重復次數為3次,表示把輸入張量每列元素重復3次

a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=1)

輸出結果如下:

(tensor([[-0.81, ?0.56],
? ? ? ? ?[-2.41, -0.56],
? ? ? ? ?[ 0.38, -0.90]]),
?tensor([[-0.81, -0.81, -0.81, ?0.56, ?0.56, ?0.56],
? ? ? ? ?[-2.41, -2.41, -2.41, -0.56, -0.56, -0.56],
? ? ? ? ?[ 0.38, ?0.38, ?0.38, -0.90, -0.90, -0.90]]))

6.5 輸入二維張量,指定dim=0,重復次數為一個張量列表[n1,n2,n3],表示在(dim=0)對應行上面重復n1,n2,n3遍,張量列表的長度必須與dim=0的維度的長度一樣,否則會報錯

a = torch.randn(3,2)
a,torch.repeat_interleave(a,torch.tensor([2,3,4]),dim=0)#表示第一行重復2遍,第二行重復3遍,第三行重復4遍

輸出結果如下:

(tensor([[-0.79, ?0.54],
? ? ? ? ?[-0.47, -0.25],
? ? ? ? ?[-0.13, ?1.03]]),
?tensor([[-0.79, ?0.54],
? ? ? ? ?[-0.79, ?0.54],
? ? ? ? ?[-0.47, -0.25],
? ? ? ? ?[-0.47, -0.25],
? ? ? ? ?[-0.47, -0.25],
? ? ? ? ?[-0.13, ?1.03],
? ? ? ? ?[-0.13, ?1.03],
? ? ? ? ?[-0.13, ?1.03],
? ? ? ? ?[-0.13, ?1.03]]))

7. 與torch.repeat()函數區別

兩個函數方法最大的區別就是repeat_interleave是一個元素一個元素地重復,而repeat是一組元素一組元素地重復.

總結

原文鏈接:https://blog.csdn.net/flyingluohaipeng/article/details/125039411

欄目分類
最近更新