日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網站首頁 編程語言 正文

PyTorch中關于tensor.repeat()的使用_python

作者:tomeasure ? 更新時間: 2022-12-09 編程語言

關于tensor.repeat()的使用

考慮到很多人在學習這個函數,我想在這里提 一個建議:

強烈推薦 使用 einops 模塊中的 repeat() 函數 替代 tensor.repeat()!

它可以擺脫 tensor.repeat() 參數的神秘主義。

einops 模塊文檔地址:https://nbviewer.jupyter.org/github/arogozhnikov/einops/blob/master/docs/1-einops-basics.ipynb

學習 tensor.repeat() 這個函數的功能的時候,最好還是要觀察所得到的 結果的維度。

不多說,看代碼:

>>> import torch
>>> 
>>> # 定義一個 33x55 張量
>>> a = torch.randn(33, 55)
>>> a.size()
torch.Size([33, 55])
>>> 
>>> # 下面開始嘗試 repeat 函數在不同參數情況下的效果
>>> a.repeat(1,1).size()     # 原始值:torch.Size([33, 55])
torch.Size([33, 55])
>>> 
>>> a.repeat(2,1).size()     # 原始值:torch.Size([33, 55])
torch.Size([66, 55])
>>> 
>>> a.repeat(1,2).size()     # 原始值:torch.Size([33, 55])
torch.Size([33, 110])
>>>
>>> a.repeat(1,1,1).size()   # 原始值:torch.Size([33, 55])
torch.Size([1, 33, 55])
>>>
>>> a.repeat(2,1,1).size()   # 原始值:torch.Size([33, 55])
torch.Size([2, 33, 55])
>>>
>>> a.repeat(1,2,1).size()   # 原始值:torch.Size([33, 55])
torch.Size([1, 66, 55])
>>>
>>> a.repeat(1,1,2).size()   # 原始值:torch.Size([33, 55])
torch.Size([1, 33, 110])
>>>
>>> a.repeat(1,1,1,1).size() # 原始值:torch.Size([33, 55])
torch.Size([1, 1, 33, 55])
>>> 
>>> # ------------------ 割割 ------------------
>>> # repeat()的參數的個數,不能少于被操作的張量的維度的個數,
>>> # 下面是一些錯誤示例
>>> a.repeat(2).size()  # 1D < 2D, error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> # 定義一個3維的張量,然后展示前面提到的那個錯誤
>>> b = torch.randn(5,6,7)
>>> b.size() # 3D
torch.Size([5, 6, 7])
>>> 
>>> b.repeat(2).size() # 1D < 3D, error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1).size() # 2D < 3D, error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1,1).size() # 3D = 3D, okay
torch.Size([10, 6, 7])
>>>

Tensor.repeat()的簡單用法

相當于手動實現廣播機制,即沿著給定的維度對tensor進行重復:

比如說對下面x的第1個通道復制三次,其余通道保持不變:

import torch

x = torch.randn(1, 3, 224, 224)
y = x.repeat(3, 1, 1, 1)
print(x.shape)
print(y.shape)

結果為:

torch.Size([1, 3, 224, 224])
torch.Size([3, 3, 224, 224])

這個在復制batch的時候用的比較多,上面的情況就相當于batch為1的3×224×224特征圖復制成了batch為3

原文鏈接:https://blog.csdn.net/qq_29695701/article/details/89763168

欄目分類
最近更新