網(wǎng)站首頁 編程語言 正文
楔子
在numpy中有一個tensordot方法,尤其在做機(jī)器學(xué)習(xí)的時候會很有用。估計(jì)有人看到這個名字,會想到tensorflow,沒錯tensorflow里面也有tensordot這個函數(shù)。這個函數(shù)它的作用就是,可以讓兩個不同維度的數(shù)組進(jìn)行相乘。我們來舉個例子:
import numpy as np a = np.random.randint(0, 9, (3, 4)) b = np.random.randint(0, 9, (4, 5)) try: print(a * b) except Exception as e: print(e) # operands could not be broadcast together with shapes (3,4) (4,5) # 很明顯,a和b兩個數(shù)組的維度不一樣,沒辦法相乘 # 但是 print(np.tensordot(a, b, 1)) """ [[32 32 28 28 52] [10 25 40 38 78] [56 7 28 0 42]] """ # 我們看到使用tensordot是可以的
下面我們來看看這個函數(shù)的用法
函數(shù)原型
@array_function_dispatch(_tensordot_dispatcher) def tensordot(a, b, axes=2):
我們看到這個函數(shù)接收三個參數(shù),前兩個就是numpy中數(shù)組,最后一個參數(shù)則是用于指定收縮的軸。它可以接收一個整型、列表、列表里面嵌套列表,具體代表什么含義我們下面舉例說明。
理解axes
axes為整型
如果axes接收的是一個整型:m,那么表示指定數(shù)組a的后n個軸和數(shù)組b的前n個軸分別進(jìn)行內(nèi)積,就是對應(yīng)位置元素相乘、再整體求和。
import numpy as np a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((4, 5, 8)) # 顯然這兩個數(shù)組不能直接相乘,但是a和后兩個軸和b的前兩個軸是可以直接相乘的 # 因?yàn)樗鼈兌际?4, 5), 最后結(jié)果的shape為(3, 8) print(np.tensordot(a, b, 2).shape) # (3, 8)
而且這個axes默認(rèn)為2,所以它一般都是針對三維或者三維以上的數(shù)組
但是為了具體理解,后面我們會使用一維、二維數(shù)據(jù)具體舉例說明。現(xiàn)在先看axes取不同的值,會得到什么結(jié)果,先理解一下axes的含義。
import numpy as np a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((4, 5, 8)) try: print(np.tensordot(a, b, 1).shape) except Exception as e: print(e) # shape-mismatch for sum # 結(jié)果報(bào)錯了,很好理解,就是形狀不匹配嘛 # axes指定為1,表示a的后一個軸和b的前一個軸進(jìn)行內(nèi)積 # 但是一個是5一個是4,元素?zé)o法一一對應(yīng),所以報(bào)錯,提示shape-mismatch,形狀不匹配 # 這里我們把數(shù)組b的shape改一下,這樣a的后一個軸和b的前一個軸就匹配了,都是5 a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((5, 4, 8)) print(np.tensordot(a, b, 1).shape) # (3, 4, 4, 8) """ 這樣就能夠運(yùn)算了,我們說指定收縮的軸,進(jìn)行內(nèi)積運(yùn)算得到的是一個值 所以這里的(3, 4, 5)和(5, 4, 8)變成了(3, 4, 4, 8) 而上一個例子是(3, 4, 5)和(4, 5, 8),然后axes=2 因?yàn)閍的后兩個軸和b的前兩個軸進(jìn)行內(nèi)積變成了一個具體的值,所以最終的維度就是(3, 8) """
如果axes為0的話,會有什么結(jié)果
import numpy as np a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((4, 5, 8)) print(np.tensordot(a, b, 0).shape) # (3, 4, 5, 4, 5, 8) print(np.tensordot(b, a, 0).shape) # (4, 5, 8, 3, 4, 5) """ np.tensordot(a, b, 0)等價(jià)于將a中的每一個元素都和b相乘 然后再將原來a中的對應(yīng)元素替換掉 """
上面的操作也可以使用愛因斯坦求和來實(shí)現(xiàn)
axes=0
import numpy as np a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((4, 5, 8)) c1 = np.tensordot(a, b, 0) c2 = np.einsum("ijk,xyz->ijkxyz", a, b) print(c1.shape, c2.shape) # (3, 4, 5, 4, 5, 8) (3, 4, 5, 4, 5, 8) print(np.all(c1 == c2)) # True """ 生成的c1和c2是一樣的 """ c3 = np.tensordot(b, a, 0) c4 = np.einsum("ijk,xyz->xyzijk", a, b) print(c3.shape, c4.shape) # (4, 5, 8, 3, 4, 5) (4, 5, 8, 3, 4, 5) print(np.all(c3 == c4)) # True """ 生成的c3和c4是一樣的 """
那么它們的效率之間孰優(yōu)孰劣呢?我們在jupyter上測試一下
>>> %timeit c1 = np.tensordot(a, b, 0) 50.5 μs ± 206 ns per loop >>> %timeit c2 = np.einsum("ijk,xyz->ijkxyz", a, b) 7.29 μs ± 242 ns per loop
可以看到愛因斯坦求和快了不少
axes=1
import numpy as np a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((5, 4, 8)) c1 = np.tensordot(a, b, 1) c2 = np.einsum("ijk,kyz->ijyz", a, b) print(c1.shape, c2.shape) # (3, 4, 4, 8) (3, 4, 4, 8) print(np.all(c1 == c2)) # True
axes=2
import numpy as np a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((4, 5, 8)) c1 = np.tensordot(a, b, 2) c2 = np.einsum("ijk,jkz->iz", a, b) print(c1.shape, c2.shape) # (3, 8) (3, 8) print(np.all(c1 == c2)) # True
axes為列表
如果axes接收的是一個列表:[m, n],那么表示讓a的第m+1個(索引為m)
軸和b的第n+1(索引為n)
個軸進(jìn)行內(nèi)積。使用列表的方法最大的好處就是,可以指定任意位置的軸。
import numpy as np a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((4, 5, 8)) # 我們看到a的第二個維度(或者說軸)和b的第一個維度都是4,所以它們是可以進(jìn)行內(nèi)積的 c1 = np.tensordot(a, b, [1, 0]) # 由于內(nèi)積的結(jié)果是一個標(biāo)量,所以(3, 4, 5)和(4, 5, 8)在tensordot之后的shape是(3, 5, 5, 8) # 相當(dāng)于把各自的4給扔掉了(因?yàn)樽兂闪藰?biāo)量),然后組合在一起 print(c1.shape) # (3, 5, 5, 8) # 同理a的最后一個維度和b的第二個維度也是可以內(nèi)積的 # 最后一個維度也可以使用-1,等于按照列表的索引來取對應(yīng)的維度 c2 = np.tensordot(a, b, [-1, 1]) print(c2.shape) # (3, 4, 4, 8)
上面的操作也可以使用愛因斯坦求和來實(shí)現(xiàn)
import numpy as np a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((4, 5, 8)) c1 = np.tensordot(a, b, [1, 0]) c2 = np.einsum("ijk,jyz->ikyz", a, b) print(c1.shape, c2.shape) # (3, 5, 5, 8) (3, 5, 5, 8) print(np.all(c1 == c2)) # True c3 = np.tensordot(a, b, [-1, 1]) c4 = np.einsum("ijk,akz->ijaz", a, b) print(c3.shape, c4.shape) # (3, 4, 4, 8) (3, 4, 4, 8) print(np.all(c3 == c4)) # True
axes為列表嵌套列表
如果axes接收的是一個嵌套列表的列表:[[m], [n]],等于說可以選多個軸
import numpy as np a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((4, 5, 8)) # 我們想讓a的后兩個軸和b的前兩個軸內(nèi)積 c1 = np.tensordot(a, b, axes=2) c2 = np.tensordot(a, b, [[1,2], [0,1]]) print(c1.shape, c2.shape) # (3, 8) (3, 8) print(np.all(c1 == c2)) # True
但是使用列表進(jìn)行篩選還有一個好處,就是可以忽略順序
import numpy as np a = np.arange(60).reshape((4, 3, 5)) b = np.arange(160).reshape((4, 5, 8)) # 這個時候就無法給axes傳遞整型了 c3 = np.tensordot(a, b, [[0, 2], [0, 1]]) print(c3.shape) # (3, 8)
此外,使用列表篩選還有一個強(qiáng)大的功能,就是可以倒著取值
import numpy as np a = np.arange(60).reshape((4, 5, 3)) b = np.arange(160).reshape((5, 4, 8)) # 這個時候我們選擇前兩個軸,但是一個是(4, 5)一個是(5, 4),所以無法相乘 # 因此在選擇的時候需要倒著篩選: # [[0, 1], [1, 0]]-> (4, 5)和(4, 5) 或者 [[1, 0], [0, 1]] -> (5, 4)和(5, 4) c3 = np.tensordot(a, b, [[0, 1], [1, 0]]) print(c3.shape) # (3, 8)
最后同樣看看如何愛因斯坦求和來實(shí)現(xiàn)
import numpy as np a = np.arange(60).reshape((4, 5, 3)) b = np.arange(160).reshape((4, 5, 8)) c1 = np.tensordot(a, b, [[0, 1], [0, 1]]) c2 = np.einsum("ijk,ijz->kz", a, b) print(c1.shape, c2.shape) # (3, 8) (3, 8) print(np.all(c1 == c2)) # True a = np.arange(60).reshape((4, 5, 3)) b = np.arange(160).reshape((5, 4, 8)) c1 = np.tensordot(a, b, [[0, 1], [1, 0]]) c2 = np.einsum("ijk,jiz->kz", a, b) print(c1.shape, c2.shape) # (3, 8) (3, 8) print(np.all(c1 == c2)) # True a = np.arange(60).reshape((4, 3, 5)) b = np.arange(160).reshape((5, 4, 8)) c1 = np.tensordot(a, b, [[0, 2], [1, 0]]) c2 = np.einsum("ijk,kiz->jz", a, b) print(c1.shape, c2.shape) # (3, 8) (3, 8) print(np.all(c1 == c2)) # True
以兩個一維數(shù)組為例
我們來通過打印具體的數(shù)組來看一下tensordot
import numpy as np a = np.array([1, 2, 3]) b = np.array([2, 3, 4]) print(np.tensordot(a, b, axes=0)) """ [[ 2 3 4] [ 4 6 8] [ 6 9 12]] """ print(np.einsum("i,j->ij", a, b)) """ [[ 2 3 4] [ 4 6 8] [ 6 9 12]] """ # 我們axes=0,等于是a的每一個元素和相乘,然后再把原來a對應(yīng)的元素替換掉 # 所以是a中的1 2 3分別和b相乘,得到[2 3 4] [4 6 8] [6 9 12]、再替換掉1 2 3 # 所以結(jié)果是[[2 3 4] [4 6 8] [6 9 12]]
如果axes=1呢?
import numpy as np a = np.array([1, 2, 3]) b = np.array([2, 3, 4]) print(np.tensordot(a, b, axes=1)) # 20 """ 選取a的前一個軸和b的后一個軸進(jìn)行內(nèi)積 而a和b只有一個軸,所以結(jié)果是一個標(biāo)量 """ print(np.einsum("i,i->", a, b)) # 20
如果axes=2呢?首先我們說axes等于一個整型,表示選取a的后n個軸,b的前n個軸,而一維數(shù)組它們只有一個軸
import numpy as np a = np.array([1, 2, 3]) b = np.array([2, 3, 4]) try: print(np.tensordot(a, b, axes=2)) # 20 except Exception as e: print(e) # tuple index out of range
顯然索引越界了。
以一個一維數(shù)組和一個二維數(shù)組為例
我們通過一維數(shù)組和二維數(shù)組進(jìn)行tensordot來感受一下
axes=0
import numpy as np a = np.array([1, 2, 3]) b = np.array([[2, 3, 4]]) print(np.tensordot(a, b, 0)) """ [[[ 2 3 4]] [[ 4 6 8]] [[ 6 9 12]]] """ print(np.einsum("i,jk->ijk", a, b)) """ [[[ 2 3 4]] [[ 4 6 8]] [[ 6 9 12]]] """ # 很好理解,就是1 2 3分別和[[2, 3, 4]]相乘再替換掉 1 2 3 print(np.tensordot(a, b, 0).shape) # (3, 1, 3) ########################## print(np.tensordot(b, a, 0)) """ [[[ 2 4 6] [ 3 6 9] [ 4 8 12]]] """ print(np.einsum("i,jk->jki", a, b)) """ [[[ 2 4 6] [ 3 6 9] [ 4 8 12]]] """ # 很好理解,就是2 3 4分別和[1 2 3]相乘再替換掉 2 3 4 print(np.tensordot(b, a, 0).shape) # (1, 3, 3)
axes=1的話呢?
import numpy as np a = np.array([1, 2, 3]) b = np.array([[2, 3, 4], [4, 5, 6]]) try: print(np.tensordot(a, b, 1)) except Exception as e: print(e) # shape-mismatch for sum # 我們注意到報(bào)錯了,因?yàn)閍xes=1,表示取a的后一個軸和b的前1個軸 # a的shape是(3, 0),所以它的后一個軸和前一個軸對應(yīng)的數(shù)組長度都是3 # 但是b的前一個軸對應(yīng)的數(shù)組長度是2,不匹配所以報(bào)錯 print(np.tensordot(b, a, 1)) # [20 32] # 我們看到這個是可以的,因?yàn)檫@表示b的后一個軸,數(shù)組長度為3,是匹配的 # 讓后一個軸的[2 3 4]、[4 5 6]分別和[1 2 3]進(jìn)行內(nèi)積,最終得到兩個標(biāo)量 try: print(np.einsum("i,ij->ij", a, b)) except Exception as e: print(e) # operands could not be broadcast together with remapped shapes [original->remapped]: (3,)->(3,newaxis) (2,3)->(2,3) # 同樣對于愛因斯坦求和也是無法這么做的,我們需要換個順序 print(np.einsum("i,ji->j", a, b)) # [20 32] # 或者 print(np.einsum("j,ij->i", a, b)) # [20 32]
axes=2的話呢?
import numpy as np a = np.array([1, 2, 3]) b = np.array([[2, 3, 4], [4, 5, 6]]) try: print(np.tensordot(a, b, 2)) except Exception as e: print(e) # tuple index out of range # 我們注意到報(bào)錯了,因?yàn)閍xes=2,表示取a的后兩個軸和b的前兩個軸 # 而a總共才1個軸,所以報(bào)錯了 try: print(np.tensordot(b, a, 2)) except Exception as e: print(e) # shape-mismatch for sum # 我們看到雖然也報(bào)錯了,但是不是報(bào)索引越界。 # 因?yàn)樯厦姹硎救的前兩個軸,雖然a只有一個,但是此時不會索引越界,只是就取一個。如果是取后兩個就會越界了 # 此時b是(2, 3),而a是(3,) 不匹配,可能有人覺得會發(fā)生廣播,但在這里不會
以兩個二維數(shù)組為例
我們再通過兩個二維數(shù)組進(jìn)行tensordot來感受一下
axes=0
import numpy as np a = np.array([[1, 2, 3]]) b = np.array([[2, 3, 4], [4, 5, 6]]) # a_shape: (1, 3) b_shape(3, 3) print(np.tensordot(a, b, 0)) """ [[[[ 2 3 4] [ 4 5 6]] [[ 4 6 8] [ 8 10 12]] [[ 6 9 12] [12 15 18]]]] """ print(np.einsum("ij,xy->ijxy", a, b)) """ [[[[ 2 3 4] [ 4 5 6]] [[ 4 6 8] [ 8 10 12]] [[ 6 9 12] [12 15 18]]]] """ print(np.tensordot(a, b, 0).shape) # (1, 3, 2, 3) ############# print(np.tensordot(b, a, 0)) """ [[[[ 2 4 6]] [[ 3 6 9]] [[ 4 8 12]]] [[[ 4 8 12]] [[ 5 10 15]] [[ 6 12 18]]]] """ print(np.einsum("ij,xy->xyij", a, b)) """ [[[[ 2 4 6]] [[ 3 6 9]] [[ 4 8 12]]] [[[ 4 8 12]] [[ 5 10 15]] [[ 6 12 18]]]] """ print(np.tensordot(b, a, 0).shape) # (2, 3, 1, 3)
axes=1
import numpy as np a = np.array([[1, 2], [3, 4]]) b = np.array([[2, 3, 4], [4, 5, 6]]) # a_shape: (2, 2) b_shape(2, 3) print(np.tensordot(a, b, 1)) """ [[10 13 16] [22 29 36]] """ print(np.einsum("ij,jk->ik", a, b)) """ [[10 13 16] [22 29 36]] """ # 仔細(xì)的你肯定發(fā)現(xiàn)了,此時就相當(dāng)于矩陣的點(diǎn)乘 print(a @ b) """ [[10 13 16] [22 29 36]] """
axes=2
import numpy as np a = np.array([[1, 2], [3, 4]]) b = np.array([[2, 3, 4], [4, 5, 6]]) # a_shape: (2, 2) b_shape(2, 3) # 取后兩個軸顯然不行,因?yàn)?2, 2)和(2, 3)不匹配 try: print(np.tensordot(a, b, 2)) except Exception as e: print(e) # shape-mismatch for sum a = np.array([[1, 2, 3], [2, 2, 2]]) b = np.array([[2, 3, 4], [4, 5, 6]]) print(np.tensordot(a, b, 2)) # 50 print(np.einsum("ij,ij->", a, b)) # 50
最后看即個愛因斯坦求和的例子,感受它和主角tensordot的區(qū)別,當(dāng)然如果不熟悉的愛因斯坦求和的話可以不用看
import numpy as np a = np.random.randint(1, 9, (5, 3, 2, 3)) b = np.random.randint(1, 9, (3, 3, 2)) c1 = a @ b # 多維數(shù)組,默認(rèn)是對最后兩位進(jìn)行點(diǎn)乘 c2 = np.einsum("ijkm,jmn->ijkn", a, b) print(np.all(c1 == c2)) # True print(c2.shape) # (5, 3, 2, 2) print(np.einsum("...km,...mn->...kn", a, b).shape) # (5, 3, 2, 2) # 但如果是 c3 = np.einsum("ijkm,amn->ijkn", a, b) print(c3.shape) # (5, 3, 2, 2) # 由于符號不一樣,所以即使shape一致,但是兩個數(shù)組不一樣 print(np.all(c3 == c1)) # False a = np.random.randint(1, 9, (5, 3, 3, 2)) b = np.random.randint(1, 9, (1, 3, 2)) print(np.einsum("ijmk,jmn->ijkn", a, b).shape) # (5, 3, 2, 2) print(np.einsum("ijkm,jnm->ijkn", a, b).shape) # (5, 3, 3, 3)
原文鏈接:https://www.cnblogs.com/traditional/p/12639487.html
- 上一篇:沒有了
- 下一篇:沒有了
相關(guān)推薦
- 2022-07-07 python?NetworkX庫生成并繪制帶權(quán)無向圖_python
- 2022-07-06 c#?模擬串口通信?SerialPort的實(shí)現(xiàn)示例_C#教程
- 2022-11-23 Golang?Defer基礎(chǔ)操作詳解_Golang
- 2022-08-20 React的生命周期詳解_React
- 2022-05-24 python中的元組與列表及元組的更改_python
- 2022-10-05 Iptables防火墻tcp-flags模塊擴(kuò)展匹配規(guī)則詳解_安全相關(guān)
- 2022-11-09 PostgreSQL索引失效會發(fā)生什么_PostgreSQL
- 2022-01-07 gulp構(gòu)建時報(bào)錯 ReferenceError: primordials is not defin
- 欄目分類
-
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支