網站首頁 編程語言 正文
Numpy、Pytorch中的broadcasting
寫在前面
自己一直都不清楚numpy、pytorch里面不同維數的向量之間的element wise的計算究竟是按照什么規則來確認維數匹配和不匹配的情況的,比如
>>> b = np.ones((4,5))
>>> a = np.arange(5)
>>> c = a + b
>>> c.shape
(4, 5)
>>> c
array([[1., 2., 3., 4., 5.],
? ? ? ?[1., 2., 3., 4., 5.],
? ? ? ?[1., 2., 3., 4., 5.],
? ? ? ?[1., 2., 3., 4., 5.]])
上面這種情況就會自動讓a和b的維數匹配,a加到了b的每一行上
>>> b = np.ones((5,4))
>>> a = np.arange(5)
>>> c = a + b
Traceback (most recent call last):
? File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (5,) (5,4)
這種情況就無法匹配,此時我們希望的是a能自動加到b的每一列上,但結果看來好像不行
雖然一直存在這種疑惑,但因為平時遇到的各種運算都比較簡單,遇到這種不是直接匹配的array的加法第一直覺就是去console里面試一試,報錯就換個姿勢再試一試,總歸問題可以快速地解決,但是最近在寫模型的時候,遇到了繞不過去的問題,所以去查了文檔,本文就以解決那個問題為目標,來解釋清楚pytorch(numpy也是一樣)中的broadcasting semantics的問題
問題描述
我有一個數據Tensor,維數是64 × 2048 64\times204864×2048,現在我想通過對這64 6464個2048 20482048維的向量做attention(也就是做一個加權和)來得到一個2048 20482048維的向量,因為模型的需要,我需要用五組不同的權值向量來計算出五個不同的加權結果,也就是我的計算結果應該是一個5 × 2048 5\times 20485×2048維的向量,因為在64 6464個向量上加權,所以一組權值向量是64 6464維,五組就是5 × 64 5\times 645×64維
嘗試解決
現在我手頭上有兩個Tensor,一個是數據Tensor(64 × 2048 64\times 204864×2048)另一個是權值Tensor(5 × 64 5\times 645×64),我GAN!直到我寫到了這里,我才發現這不是一個矩陣乘法就能解決的問題嘛+_+,當然,我想給自己正名,這里我簡化了一下問題所以才發現原來這么容易就解決了,而原來我在寫代碼的時候因為還要考慮batch_size等問題才云里霧里不知道咋辦,還好當時沒想出來,所以去查了文檔發現了新的東西,然后寫文章的時候想到也算是完滿了(不然也不會發現自己好澇)
以上都是題外話,現在,我們還是考慮用愚蠢的element wise的方法來解決,好在現在有兩種方法可以解決問題,所以我們可以用來相互檢驗一下,element wise的解決方法就是,我希望這5個64維的權值向量分別和這64個2048維的向量進行element wise的乘法,也就是第一個64維權值向量先對64個2048維向量加權得到一個2048維的向量,然后第二個64維權值向量先對64個2048維向量加權得到一個2048維的向量…,以此類推總共五個,最終得到五個64 × 2048 64×204864×2048維的向量,然后求和得到最后的5 × 2048 5×20485×2048維的向量
那么按照平常的習慣,我就去先試試pytorch能不能直接地理解我的想法
>>> import torch
>>> bs = 10 # batch_size
>>> x = torch.randn(bs,64,2048)
>>> att = torch.randn(5,64)
>>> out = att * x
Traceback (most recent call last):
? File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (64) must match the size of tensor b (2048) at non-singleton dimension 2
直接乘不行,因為維數是不匹配的,那怎樣的維數才算匹配呢?
BROADCASTING SEMANTICS
以下內容主要來源于自官方文檔
很多pytorch的運算是支持broadcasting semantics的,而簡單來說,如果運算支持broadcast,則參與運算的Tensor會自動進行擴展來使得運算符左右的Tensor維數匹配,而無需人手動地去拷貝其中的某個Tensor,這就類似于我們開頭的那個例子
>>> b = np.ones((4,5))
>>> a = np.arange(5)
>>> c = a + b
>>> c.shape
(4, 5)
>>> c
array([[1., 2., 3., 4., 5.],
? ? ? ?[1., 2., 3., 4., 5.],
? ? ? ?[1., 2., 3., 4., 5.],
? ? ? ?[1., 2., 3., 4., 5.]])
我們無需讓a的維數和b一樣,因為numpy自動幫我們做了
這里的另一個重要的概念是broadcastable,如果兩個Tensor是broadcastable的,那么就可以對他倆使用支持broadcast的運算,比如直接加減乘除
而兩個向量要是broadcast的話,必須滿足以下兩個條件
- 每個tensor至少是一維的
- 兩個tensor的維數從后往前,對應的位置要么是相等的,要么其中一個是1,或者不存在
這是官方的例子解釋
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# 相同維數的tensor一定是broadcastable的
>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# 不是broadcastable的,因為每個tensor維數至少要是1
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty( ?3,1,1)
# 是broadcastable的,因為從后往前看,一定要注意是從后往前看!
# 第一個維度都是1,相等,滿足第二個條件
# 第二個維度其中有一個是1,滿足第二個條件
# 第三個維度都是3,相等,滿足第二個條件
# 第四個維度其中有一個不存在,滿足第二個條件
# 但是
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty( ?3,1,1)
# 不是broadcastable的,因為從后往前看第三個維度是不match的 2!=3,且都不是1
如果x和y是broadcastable的,那么結果的tensor的size按照如下的規則計算
- 如果兩者的維度不一樣,那么就自動增加1維(也就是unsqueeze)
- 對于結果的每個維度,它取x和y在那一維上的最大值
官方的例子
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty( ?3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
此外,關于broadcast導致的就地(in-place)操作和梯度運算的兼容性等問題,可以自行參考官方文檔
解決問題
上面我們看到,要想兩個Tensor支持element wise的運算,需要它們是broadcastable的,而要想它們是broadcastable的,就需要它們的維度自后向前逐一匹配,回到我們原來的問題中,我們有兩個Tensor x(64 × 2048) att(5 × 64),為了讓它們broadcastable,我們只需要
>>> import torch
>>> bs = 10 # batch_size
>>> x = torch.randn(bs,64,2048)
>>> att = torch.randn(5,64)
>>> x = x.unsqueeze(1)
>>> att = att.view(1,*att.shape,1)
>>> x.shape
torch.Size([10, 1, 64, 2048])
>>> att.shape
torch.Size([1, 5, 64, 1])
>>> out = x * att
>>> out.shape
torch.Size([10, 5, 64, 2048])
最后我們來驗證兩種方法是否結果相同
>>> import torch
>>> bs = 10?
>>> x = torch.randn(bs,64,2048)
>>> att = torch.randn(5,64)
>>> out1 = torch.matmul(att,x) ?# 直接矩陣相乘
>>> out.shape
torch.Size([10, 5, 2048])
>>> x = x.unsqueeze(1)
>>> att = att.view(1,*att.shape,1)
>>> out2 = x * att ?# element wise的方法
>>> out2 = out2.sum(dim=2)
>>> test = torch.sum((out1-out2)<0.00001) ?# 浮點數有微小的誤差
>>> test
tensor(102400)
>>> out1.numel() ?# 最后表明兩個out向量是相等的
102400
Reference
[1] https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#module-numpy.doc.broadcasting
[2] https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics
總結
原文鏈接:https://blog.csdn.net/luo3300612/article/details/100100291
相關推薦
- 2022-10-25 Python繪制loss曲線和準確率曲線實例代碼_python
- 2022-10-19 Go?熱加載之fresh詳解_Golang
- 2022-06-12 Android開發之保存圖片到相冊的三種方法詳解_Android
- 2022-08-01 C++簡單又輕松建立鏈式二叉樹流程_C 語言
- 2021-12-24 OpenCV?reshape函數實現矩陣元素序列化_C 語言
- 2022-12-22 Nginx配置之main?events塊使用示例詳解_nginx
- 2022-11-16 python?sklearn與pandas實現缺失值數據預處理流程詳解_python
- 2023-02-03 C++中的HTTP協議問題_C 語言
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支