網站首頁 編程語言 正文
torch.repeat_interleave()函數解析
1.函數說明
官網:torch.repeat_interleave(),函數說明如下圖所示:
2. 函數原型
torch.repeat_interleave(input, repeats, dim=None) → Tensor
3. 函數功能
沿著指定的維度重復張量的元素
4. 輸入參數
1)input (類型:torch.Tensor):輸入張量
2)repeats(類型:int或torch.Tensor):每個元素的重復次數
3)dim(類型:int)需要重復的維度。默認情況下dim=None,表示將把給定的輸入張量展平(flatten)為向量,然后將每個元素重復repeats次,并返回重復后的張量。
5. 注意
1) 如果不指定dim,則默認將輸入張量扁平化(維數是1,因此這時repeats必須是一個數,不能是數組),并且返回一個扁平化的輸出數組。
2) 返回的數組與輸入數組維數相同,并且除了給定的維度dim,其他維度大小與輸入數組相應維度大小相同
3) repeats:如果傳入數組,則必須是tensor格式。并且只能是一維數組,數組長度與輸入數組input的dim維度大小相同
6. 代碼例子
6.1 輸入一維張量,不指定dim,重復次數為2次,表示將把給定的輸入張量展平(flatten)為向量,然后將每個元素重復2次,并返回重復后的張量。
a = torch.randn(5)
a,torch.repeat_interleave(a,2)
輸出結果如下所示:
(tensor([ 0.4030, -1.1536, -2.4513, ?1.1454, -0.8818]),
?tensor([ 0.4030, ?0.4030, -1.1536, -1.1536, -2.4513, -2.4513, ?1.1454, ?1.1454,
? ? ? ? ?-0.8818, -0.8818]))
6.2 輸入二維張量,不指定dim,重復次數為2次,表示將把給定的輸入張量展平(flatten)為向量,然后將每個元素重復2次,并返回重復后的張量。
a = torch.randn(3,2)
a,a.repeat_interleave(2)
輸出結果如下:
(tensor([[-1.03, -0.32],
? ? ? ? ?[ 0.43, ?0.78],
? ? ? ? ?[ 0.91, -0.11]]),
?tensor([-1.03, -1.03, -0.32, -0.32, ?0.43, ?0.43, ?0.78, ?0.78, ?0.91, ?0.91,
? ? ? ? ?-0.11, -0.11]))
6.3 輸入二維張量,指定dim=0,重復次數為3次,表示把輸入張量每行元素重復3次
a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=0)
輸出結果如下:
(tensor([[ 0.14, ?1.47],
? ? ? ? ?[-1.52, -0.62],
? ? ? ? ?[-0.24, -0.27]]),
?tensor([[ 0.14, ?1.47],
? ? ? ? ?[ 0.14, ?1.47],
? ? ? ? ?[ 0.14, ?1.47],
? ? ? ? ?[-1.52, -0.62],
? ? ? ? ?[-1.52, -0.62],
? ? ? ? ?[-1.52, -0.62],
? ? ? ? ?[-0.24, -0.27],
? ? ? ? ?[-0.24, -0.27],
? ? ? ? ?[-0.24, -0.27]]))
6.4 輸入二維張量,指定dim=1,重復次數為3次,表示把輸入張量每列元素重復3次
a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=1)
輸出結果如下:
(tensor([[-0.81, ?0.56],
? ? ? ? ?[-2.41, -0.56],
? ? ? ? ?[ 0.38, -0.90]]),
?tensor([[-0.81, -0.81, -0.81, ?0.56, ?0.56, ?0.56],
? ? ? ? ?[-2.41, -2.41, -2.41, -0.56, -0.56, -0.56],
? ? ? ? ?[ 0.38, ?0.38, ?0.38, -0.90, -0.90, -0.90]]))
6.5 輸入二維張量,指定dim=0,重復次數為一個張量列表[n1,n2,n3],表示在(dim=0)對應行上面重復n1,n2,n3遍,張量列表的長度必須與dim=0的維度的長度一樣,否則會報錯
a = torch.randn(3,2)
a,torch.repeat_interleave(a,torch.tensor([2,3,4]),dim=0)#表示第一行重復2遍,第二行重復3遍,第三行重復4遍
輸出結果如下:
(tensor([[-0.79, ?0.54],
? ? ? ? ?[-0.47, -0.25],
? ? ? ? ?[-0.13, ?1.03]]),
?tensor([[-0.79, ?0.54],
? ? ? ? ?[-0.79, ?0.54],
? ? ? ? ?[-0.47, -0.25],
? ? ? ? ?[-0.47, -0.25],
? ? ? ? ?[-0.47, -0.25],
? ? ? ? ?[-0.13, ?1.03],
? ? ? ? ?[-0.13, ?1.03],
? ? ? ? ?[-0.13, ?1.03],
? ? ? ? ?[-0.13, ?1.03]]))
7. 與torch.repeat()函數區別
兩個函數方法最大的區別就是repeat_interleave是一個元素一個元素地重復,而repeat是一組元素一組元素地重復.
總結
原文鏈接:https://blog.csdn.net/flyingluohaipeng/article/details/125039411
相關推薦
- 2022-12-06 C++類成員函數后面加const問題_C 語言
- 2023-03-20 解讀C#中ReadString的一些小疑惑_C#教程
- 2022-01-10 修改代碼后,刷新頁面沒有更新的解決辦法。Disable cache禁止
- 2022-03-15 GO + React + Axios Response to preflight request
- 2022-07-19 Python?assert斷言聲明,遇到錯誤則立即返回問題_python
- 2022-07-12 Linux配置nginx開機自啟
- 2022-03-30 .Net?Core以windows服務方式部署_C#教程
- 2022-05-08 總結Python函數參數的六種類型_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支