網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
PyTorch中torch.nn.functional.cosine_similarity使用詳解_python
作者:JasonLiu1919 ? 更新時(shí)間: 2022-05-27 編程語(yǔ)言概述
根據(jù)官網(wǎng)文檔的描述,其中 dim表示沿著對(duì)應(yīng)的維度計(jì)算余弦相似。那么怎么理解呢?
首先,先介紹下所謂的dim:
a = torch.tensor([[ [1, 2], [3, 4] ], [ [5, 6], [7, 8] ] ], dtype=torch.float) print(a.shape) """ [ [ [1, 2], [3, 4] ], [ [5, 6], [7, 8] ] ] """
假設(shè)有2個(gè)矩陣:[[1, 2], [3, 4]] 和 [[5, 6], [7, 8]]
, 求2者的余弦相似。
按照dim=0求余弦相似:
import torch.nn.functional as F input1 = torch.tensor([[1, 2], [3, 4]], dtype=torch.float) input2 = torch.tensor([[5, 6], [7, 8]], dtype=torch.float) output = F.cosine_similarity(input1, input2, dim=0) print(output)
結(jié)果如下:
tensor([0.9558, 0.9839])
那么,這個(gè)數(shù)值是怎么得來(lái)的?是按照
具體求解如下:
print(F.cosine_similarity(torch.tensor([1,3], dtype=torch.float) , torch.tensor([5,7], dtype=torch.float), dim=0)) print(F.cosine_similarity(torch.tensor([2,4], dtype=torch.float) , torch.tensor([6,8], dtype=torch.float), dim=0))
運(yùn)行結(jié)果如下:
tensor(0.9558)tensor(0.9839)
可以用scipy.spatial
進(jìn)一步佐證:
from scipy import spatial dataSetI = [1,3] dataSetII = [5,7] result = 1 - spatial.distance.cosine(dataSetI, dataSetII) print(result)
運(yùn)行結(jié)果如下:
0.95577900872195
同理:
dataSetI = [2,4] dataSetII = [6,8] result = 1 - spatial.distance.cosine(dataSetI, dataSetII) print(result)
運(yùn)行結(jié)果如下:
0.9838699100999074
按照dim=1求余弦相似:
output = F.cosine_similarity(input1, input2, dim=1) print(output)
運(yùn)行結(jié)果如下:
tensor([0.9734, 0.9972])
同理,用用scipy.spatial
進(jìn)一步佐證:
dataSetI = [1,2] dataSetII = [5,6] result = 1 - spatial.distance.cosine(dataSetI, dataSetII) print(result)
運(yùn)行結(jié)果:0.973417168333576
dataSetI = [3,4] dataSetII = [7,8] result = 1 - spatial.distance.cosine(dataSetI, dataSetII) print(result)
運(yùn)行結(jié)果:
0.9971641204866132
結(jié)果與F.cosine_similarity
相符合。
補(bǔ)充:給定一個(gè)張量,計(jì)算多個(gè)張量與它的余弦相似度,并將計(jì)算得到的余弦相似度標(biāo)準(zhǔn)化。
import torch def get_att_dis(target, behaviored): attention_distribution = [] for i in range(behaviored.size(0)): attention_score = torch.cosine_similarity(target, behaviored[i].view(1, -1)) # 計(jì)算每一個(gè)元素與給定元素的余弦相似度 attention_distribution.append(attention_score) attention_distribution = torch.Tensor(attention_distribution) return attention_distribution / torch.sum(attention_distribution, 0) # 標(biāo)準(zhǔn)化 a = torch.FloatTensor(torch.rand(1, 10)) print('a', a) b = torch.FloatTensor(torch.rand(3, 10)) print('b', b) similarity = get_att_dis(target=a, behaviored=b) print('similarity', similarity)
a tensor([[0.9255, 0.2194, 0.8370, 0.5346, 0.5152, 0.4645, 0.4926, 0.9882, 0.2783,
? ? ? ? ?0.9258]])
b tensor([[0.6874, 0.4054, 0.5739, 0.8017, 0.9861, 0.0154, 0.8513, 0.8427, 0.6669,
? ? ? ? ?0.0694],
? ? ? ? [0.1720, 0.6793, 0.7764, 0.4583, 0.8167, 0.2718, 0.9686, 0.9301, 0.2421,
? ? ? ? ?0.0811],
? ? ? ? [0.2336, 0.4783, 0.5576, 0.6518, 0.9943, 0.6766, 0.0044, 0.7935, 0.2098,
? ? ? ? ?0.0719]])
similarity tensor([0.3448, 0.3318, 0.3234])
總結(jié)
原文鏈接:https://blog.csdn.net/ljp1919/article/details/120643732
相關(guān)推薦
- 2022-11-18 詳解C語(yǔ)言內(nèi)核字符串轉(zhuǎn)換方法_C 語(yǔ)言
- 2022-11-06 Matplotlib學(xué)習(xí)筆記之plt.xticks()用法_python
- 2022-09-15 Python淺析匿名函數(shù)lambda的用法_python
- 2023-06-17 go開源Hugo站點(diǎn)渲染之模板詞法解析_Golang
- 2022-11-02 Android?Studio模擬器運(yùn)行apk文件_Android
- 2022-10-07 ASP.NET?MVC使用Knockout獲取數(shù)組元素索引的2種方法_實(shí)用技巧
- 2022-08-05 Activity supporting ACTION_VIEW is not exported
- 2022-07-21 微信小程序使用vant weapp報(bào)錯(cuò)
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過(guò)濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支