網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
NCL:Improving Graph Collaborative Filtering with Neighborhood-enriched Contrastive Learning,代碼解讀
作者:只想做個(gè)咸魚 更新時(shí)間: 2022-09-22 編程語(yǔ)言一、前言
1、背景
(1)用戶-項(xiàng)目交互數(shù)據(jù)通常是稀疏或嘈雜的,并且它可能無(wú)法學(xué)習(xí)可靠的表示,因?yàn)榛趫D的方法可能更容易受到數(shù)據(jù)稀疏性的影響
(2)現(xiàn)有的基于 GNN 的 CF 方法依賴于顯式交互鏈接來(lái)學(xué)習(xí)節(jié)點(diǎn)表示,而不能顯式利用高階關(guān)系或約束(例如,用戶或項(xiàng)目相似性)來(lái)豐富圖信息,盡管最近的幾項(xiàng)研究利用對(duì)比學(xué)習(xí)來(lái)緩解交互數(shù)據(jù)的稀疏性,但它們通過(guò)隨機(jī)抽樣節(jié)點(diǎn)或損壞子圖來(lái)構(gòu)建對(duì)比對(duì),缺乏構(gòu)建針對(duì)推薦任務(wù)更有意義的對(duì)比學(xué)習(xí)任務(wù)的思考。
2、做出的貢獻(xiàn)
提出NCL方法,主要從兩方面考慮對(duì)比關(guān)系,
(1)結(jié)構(gòu)鄰居?: 通過(guò)高階路徑在結(jié)構(gòu)上連接的節(jié)點(diǎn)
考慮圖結(jié)構(gòu)上的用戶-用戶鄰居,商品-商品鄰居的對(duì)比關(guān)系
(2)語(yǔ)義鄰居?: 語(yǔ)義上相似的鄰居,在圖上可能不直接相鄰。
從節(jié)點(diǎn)表征出發(fā),聚類后,節(jié)點(diǎn)與聚類中心構(gòu)成對(duì)比關(guān)系
二、模型構(gòu)建
1、圖協(xié)同過(guò)濾
這里其實(shí)就是lightGCN的傳播機(jī)制,簡(jiǎn)單過(guò)一下:
GCN的消息傳遞
將每層的輸出組合起來(lái),形成結(jié)點(diǎn)的最終表示
?然后就是預(yù)測(cè),和BPR的損失函數(shù)
?這一部分是基礎(chǔ),如果不熟悉的話可以回看往期的lightGCN介紹
2、結(jié)構(gòu)鄰居的對(duì)比學(xué)習(xí)
提出將每個(gè)用戶(或項(xiàng)目)與他/她的結(jié)構(gòu)鄰居進(jìn)行對(duì)比,這些鄰居的表示通過(guò)GNN的層傳播進(jìn)行聚合。
?交互圖 G 是一個(gè)二分圖,基于 GNN 的模型在圖上的偶數(shù)次信息傳播自然地聚合了同構(gòu)結(jié)構(gòu)鄰居的信息,就可以從GNN模型的偶數(shù)層(如2,4,6)輸出中得到同類鄰居的表示,我們將用戶自身的嵌入和偶數(shù)層GNN的相應(yīng)輸出的嵌入視為正對(duì)。基于InfoNCE[20],我們提出了結(jié)構(gòu)對(duì)比學(xué)習(xí)目標(biāo)來(lái)最小化它們之間的距離:
?其中?為GNN中??層的歸一化輸出,??為偶數(shù)。??是softmax的溫度超參數(shù),同理。item的一樣
?完整的結(jié)構(gòu)對(duì)比目標(biāo)函數(shù)是上述兩個(gè)損失的加權(quán)之和:
?其中??是一個(gè)超參數(shù),以平衡結(jié)構(gòu)對(duì)比學(xué)習(xí)中兩個(gè)損失的權(quán)重。
?3、語(yǔ)義鄰居的對(duì)比學(xué)習(xí)
語(yǔ)義鄰居是指圖上無(wú)法到達(dá)但具有相似特征(商品節(jié)點(diǎn))或偏好(用戶節(jié)點(diǎn))的節(jié)點(diǎn)。這部分通過(guò)聚類,將相似embedding對(duì)應(yīng)的節(jié)點(diǎn)劃分的相同的簇,用簇中心代表這個(gè)簇,這個(gè)中心稱為原型。由于該過(guò)程無(wú)法進(jìn)行端到端優(yōu)化,使用 EM 算法學(xué)習(xí)提出的原型對(duì)比目標(biāo)。聚類中GNN模型的目標(biāo)是最大化下式(用戶相關(guān)),簡(jiǎn)單理解就是讓用戶embedding劃分到某個(gè)簇,其中θ為可學(xué)習(xí)參數(shù),R為交互矩陣,c是用戶u的潛在原型。同理也可以得到商品相關(guān)的目標(biāo)式。
?提出的原型對(duì)比學(xué)習(xí)目標(biāo)基于InfoNCE來(lái)最小化以下函數(shù):
?最終的原型對(duì)比目標(biāo)是用戶目標(biāo)和項(xiàng)目目標(biāo)的加權(quán)和:
?4、優(yōu)化器
將提出的兩個(gè)對(duì)比學(xué)習(xí)損失作為補(bǔ)充,并利用多任務(wù)學(xué)習(xí)策略來(lái)聯(lián)合訓(xùn)練傳統(tǒng)的排序損失和提出的對(duì)比損失,公式如下,
?實(shí)驗(yàn)效果:
?三、pytoch代碼實(shí)現(xiàn)
1、GNN傳播部分
本質(zhì)就是lightGCN
def forward(self):
ego_embeddings = torch.cat([self.embedding_dict['user_emb'], self.embedding_dict['item_emb']], 0)
all_embeddings = [ego_embeddings]
for k in range(self.layers):
ego_embeddings = torch.sparse.mm(self.sparse_norm_adj, ego_embeddings)
all_embeddings += [ego_embeddings]
lgcn_all_embeddings = torch.stack(all_embeddings, dim=1)
lgcn_all_embeddings = torch.mean(lgcn_all_embeddings, dim=1)
user_all_embeddings = lgcn_all_embeddings[:self.data.user_num]
item_all_embeddings = lgcn_all_embeddings[self.data.user_num:]
return user_all_embeddings, item_all_embeddings, all_embeddings
輸出user_embedding、item_embedding、all_embedding (這個(gè)是存儲(chǔ)每層聚合的嵌入)
所對(duì)應(yīng)的是BPR_loss,如下:
rec_loss = bpr_loss(user_emb, pos_item_emb, neg_item_emb)
2、結(jié)構(gòu)鄰居的對(duì)比學(xué)習(xí)
initial_emb = emb_list[0] #初始embedding
context_emb = emb_list[self.hyper_layers*2] #對(duì)比偶數(shù)層
ssl_loss = self.ssl_layer_loss(context_emb,initial_emb,user_idx,pos_idx) #loss
看一下loss
def ssl_layer_loss(self, context_emb, initial_emb, user, item):
context_user_emb_all, context_item_emb_all = torch.split(context_emb, [self.data.user_num, self.data.item_num]) #拆分偶數(shù)層的嵌入 U+I
initial_user_emb_all, initial_item_emb_all = torch.split(initial_emb, [self.data.user_num, self.data.item_num]) #拆分初始的嵌入 U+I
context_user_emb = context_user_emb_all[user] #獲取當(dāng)前批次的嵌入
initial_user_emb = initial_user_emb_all[user]
# 對(duì)輸入數(shù)據(jù)進(jìn)行標(biāo)準(zhǔn)化使得輸入數(shù)據(jù)滿足正態(tài)分布
norm_user_emb1 = F.normalize(context_user_emb) #當(dāng)前偶數(shù)層批次
norm_user_emb2 = F.normalize(initial_user_emb) #當(dāng)前初始化批次
norm_all_user_emb = F.normalize(initial_user_emb_all)# 全部用戶
pos_score_user = torch.mul(norm_user_emb1, norm_user_emb2).sum(dim=1) # Zk * z0
ttl_score_user = torch.matmul(norm_user_emb1, norm_all_user_emb.transpose(0, 1))
pos_score_user = torch.exp(pos_score_user / self.ssl_temp) #分子
ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)#分母
ssl_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()
#item同理
context_item_emb = context_item_emb_all[item]
initial_item_emb = initial_item_emb_all[item]
norm_item_emb1 = F.normalize(context_item_emb)
norm_item_emb2 = F.normalize(initial_item_emb)
norm_all_item_emb = F.normalize(initial_item_emb_all)
pos_score_item = torch.mul(norm_item_emb1, norm_item_emb2).sum(dim=1)
ttl_score_item = torch.matmul(norm_item_emb1, norm_all_item_emb.transpose(0, 1))
pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)
ssl_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()
ssl_loss = self.ssl_reg * (ssl_loss_user + self.alpha * ssl_loss_item)
return ssl_loss
3、語(yǔ)義鄰居的對(duì)比學(xué)習(xí)
proto_loss = self.ProtoNCE_loss(initial_emb, user_idx, pos_idx)
def ProtoNCE_loss(self, initial_emb, user_idx, item_idx):
user_emb, item_emb = torch.split(initial_emb, [self.data.user_num, self.data.item_num])#拆分初始的嵌入 U+I
user2cluster = self.user_2cluster[user_idx]
user2centroids = self.user_centroids[user2cluster]
proto_nce_loss_user = InfoNCE(user_emb[user_idx],user2centroids,self.ssl_temp) * self.batch_size
item2cluster = self.item_2cluster[item_idx]
item2centroids = self.item_centroids[item2cluster]
proto_nce_loss_item = InfoNCE(item_emb[item_idx],item2centroids,self.ssl_temp) * self.batch_size
proto_nce_loss = self.proto_reg * (proto_nce_loss_user + proto_nce_loss_item)
return proto_nce_loss
總結(jié):
在這項(xiàng)工作中,提出了一種新的對(duì)比學(xué)習(xí)范式,稱為鄰域豐富的對(duì)比學(xué)習(xí)(NCL),以明確地將潛在的節(jié)點(diǎn)相關(guān)性捕獲到對(duì)比學(xué)習(xí)中,用于圖形協(xié)同過(guò)濾。分別從圖結(jié)構(gòu)和語(yǔ)義空間兩個(gè)方面考慮用戶(或項(xiàng)目)的鄰居。
首先,為了利用交互圖上的結(jié)構(gòu)鄰居,開(kāi)發(fā)了一個(gè)新的結(jié)構(gòu)對(duì)比目標(biāo),該目標(biāo)可以與基于GNN的協(xié)同過(guò)濾方法相結(jié)合。
其次,為了利用語(yǔ)義鄰域,通過(guò)對(duì)嵌入內(nèi)容進(jìn)行聚類,并將語(yǔ)義鄰域合并到原型對(duì)比目標(biāo)中,從而獲得用戶/項(xiàng)目的原型。對(duì)五個(gè)公共數(shù)據(jù)集的大量實(shí)驗(yàn)證明了所提出的NCL的有效性。
作為未來(lái)的工作,將把我們的框架擴(kuò)展到其他推薦任務(wù),例如順序推薦。此外,我們還將考慮制定一個(gè)更統(tǒng)一的方案,以利用和利用不同種類的鄰居。
原文鏈接:https://blog.csdn.net/zhao254014/article/details/126966039
相關(guān)推薦
- 2023-02-09 利用C++開(kāi)發(fā)一個(gè)protobuf動(dòng)態(tài)解析工具_(dá)C 語(yǔ)言
- 2022-03-14 Token跨域問(wèn)題Response to preflight request doesn‘t pas
- 2022-07-09 基于fluttertoast實(shí)現(xiàn)封裝彈框提示工具類_Android
- 2022-06-19 LINQ基礎(chǔ)之Join和UNION子句_C#教程
- 2023-02-06 C#實(shí)現(xiàn)將聊天數(shù)據(jù)發(fā)送加密_C#教程
- 2023-01-07 python導(dǎo)入其他目錄下模塊的四種情況_python
- 2022-11-07 Android?Framework如何實(shí)現(xiàn)Binder_Android
- 2023-01-01 C++日期和時(shí)間編程小結(jié)_C 語(yǔ)言
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過(guò)濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支