網站首頁 編程語言 正文
Pytorch-Geometric中Message Passing使用
圖中的卷積計算通常被稱為鄰域聚合或者消息傳遞 (neighborhood aggregation or message passing).
定義為節點i在第(k?1)層的特征,ej,i表示節點j到節點i的邊特征,在GNN中消息傳遞可以表示為
其中 □ 表示具有置換不變性并且可微的函數,例如 sum, mean, max 等, γ 和 ? 表示可微函數。
在 PyTorch Gemetric 中,所有卷積算子都是由 MessagePassing
類派生而來,理解 MessagePasing
有助于我們理解 PyG 中消息傳遞的計算方式和編寫自定義的卷積。
在自定義卷積中,用戶只需定義消息傳遞函數 ? message()
, 節點更新函數 γ update()
以及聚合方式 aggr='add', aggr='mean'
或則 aggr=max
.
具體函數說明如下
-
MessagePassing(aggr='add', flow='source_to_target', node_dim=-2)
定義聚合計算的方式 ('add', 'mean'
ormax
) 以及消息的傳遞方向 (source_to_target
ortarget_to_source
). 在 PyG 中,中心節點為目標 target,鄰域節點為源 source.node_dim
為消息聚合的維度 -
MessagePassing.propagate(edge_index, size=None, **kwargs):
該函數接受邊信息edge_index
和其他額外的數據來執行消息傳遞并更新節點嵌入 -
MessagePassing.message(...):
該函數的作用是計算節點消息,就是公式中的函數 ? \phi ? . 如果flow='source_to_target'
,那么消息將由鄰域節點 j j j 傳向中心節點 i i i ;如果flow='target_to_source'
,消息則由中心節點 i i i 傳向鄰域節點 j j j . 傳入參數的節點類型可以通過變量名后綴來確定,例如中心節點嵌入變量一般以_i
為結尾,鄰域節點嵌入變量以x_j
為結尾 -
MessagePassing.update(arr_out, ...):
該函數為節點嵌入的更新函數 γ \gamma γ , 輸入參數為聚合函數MessagePassing.aggregate
計算的結果
為了更好的理解 PyG 中 MessagePassing
的計算過程,我們來分析一下源代碼。
class MessagePassing(torch.nn.Module): special_args: Set[str] = { 'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j', 'ptr', 'index', 'dim_size' } def __init__(self, aggr: Optional[str] = "add", flow: str = "source_to_target", node_dim: int = -2): super(MessagePassing, self).__init__() self.aggr = aggr assert self.aggr in ['add', 'mean', 'max', None] self.flow = flow assert self.flow in ['source_to_target', 'target_to_source'] self.node_dim = node_dim self.inspector = Inspector(self) self.inspector.inspect(self.message) self.inspector.inspect(self.aggregate, pop_first=True) self.inspector.inspect(self.message_and_aggregate, pop_first=True) self.inspector.inspect(self.update, pop_first=True) self.__user_args__ = self.inspector.keys( ['message', 'aggregate', 'update']).difference(self.special_args) self.__fused_user_args__ = self.inspector.keys( ['message_and_aggregate', 'update']).difference(self.special_args) # Support for "fused" message passing. self.fuse = self.inspector.implements('message_and_aggregate') # Support for GNNExplainer. self.__explain__ = False self.__edge_mask__ = None
在初始化函數中,MessagePassing
定義了一個 Inspector
. Inspector 的中文意思是檢查員的意思,這個類的作用就是檢查各個函數的輸入參數,并保存到 Inspector
的參數列表字典中 Inspector.params
中。
如果 message
的輸入參數為 x_i, x_j
,那么Inspector.params['message']={'x_i': Parameter, 'x_j': Parameter}
(注:這里僅作示意,實際 Inspector.params['message']
類型為 OrderedDict
). Inspector.implements
檢查函數是否實現.
MessagePasing
中最核心的是 propgate
函數,假設鄰接矩陣 edge_index
的類型為 Torch.LongTensor
,消息由 edge_index[0]
傳向 edge_index[1]
,代碼實現如下
def propagate(self, edge_index: Adj, size: Size = None, **kwargs): # 為了簡化問題,這里不討論 edge_index 為 SparseTensor 的情況,感興趣的可閱讀 PyG 原始代碼 size = self.__check_input__(edge_index, size) coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs) msg_kwargs = self.inspector.distribute('message', coll_dict) out = self.message(**msg_kwargs) aggr_kwargs = self.inspector.distribute('aggregate', coll_dict) out = self.aggregate(out, **aggr_kwargs) update_kwargs = self.inspector.distribute('update', coll_dict) return self.update(out, **update_kwargs)
在這段代碼中,首先是檢查節點數量和用戶自定義的輸入變量,然后依次執行 message
, aggregate
和 update
函數。
如果是自定義圖卷積,一般會重寫 message
和 update
,這一點隨后再以 GCN 為例解釋,這里首先來看一下 aggregate
的實現
def aggregate(self, inputs: Tensor, index: Tensor, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None) -> Tensor: if ptr is not None: ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim()) return segment_csr(inputs, ptr, reduce=self.aggr) else: return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
ptr
變量是針對鄰接矩陣 edge_index
為 SparseTensor
的情況,此處暫且不論
inputs
為 message
計算得到的消息, index
就是待更新節點的索引,實際上就是 edge_index_i
. 聚合計算通過 scatter
函數實現。scatter
具體實現參考鏈接
下面以 GCN 為例,我們來看一下 MessagePassing
的計算過程。
GCN 的計算公式如下
實際計算工程可以分為下面幾步
- 1.在鄰接矩陣中增加自循環,即把鄰接矩陣的對角線上的元素設為1
- 2.對節點特征矩陣做線性變換
- 3.計算節點的歸一化系數,也就是節點度乘積的開方
- 4.對節點特征做歸一化處理
- 5.聚合(求和)節點特征得到新的節點嵌入
代碼如下
import torch from torch_geometric.nn import MessagePassing from torch_geometric.utils import add_self_loops, degree class GCNConv(MessagePassing): def __init__(self, in_channels, out_channels): super(GCNConv, self).__init__(aggr='add') # "Add" aggregation (Step 5). self.lin = torch.nn.Linear(in_channels, out_channels) def forward(self, x, edge_index): # x has shape [N, in_channels] # edge_index has shape [2, E] # Step 1: Add self-loops to the adjacency matrix. edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) # Step 2: Linearly transform node feature matrix. x = self.lin(x) # Step 3: Compute normalization. row, col = edge_index deg = degree(col, x.size(0), dtype=x.dtype) deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # Step 4-5: Start propagating messages. return self.propagate(edge_index, x=x, norm=norm) def message(self, x_j, norm): # x_j has shape [E, out_channels] # Step 4: Normalize node features. return norm.view(-1, 1) * x_j
在 forward
函數中,首先是給節點邊增加自循環。設輸入變量如下
edge_index = torch.tensor([[0, 0, 2], [1, 2, 3]], dtype=torch.long) x = torch.rand((4, 3)) conv = GCNConv(3, 8)
注意到默認消息傳遞方向為 source_to_target
,此時edge_index[0]=x_j
為 source, edge_index[1]=x_i
為 target.
在 GCN 中,第一步是增加節點的自循環,add_self_loops
計算前后變化如下
# before add_self_loops # edge_index= tensor([[0, 0, 2], [1, 2, 3]]) # after add_self_loops # edge_index= tensor([[0, 0, 2, 0, 1, 2, 3], [1, 2, 3, 0, 1, 2, 3]]) # norm= tensor([0.7071, 0.7071, 0.5000, 1.0000, 0.5000, 0.5000, 0.5000]
此處的 propagate
的輸出參數由 edge_index, x, norm
, edge_index
是 propagete
必須輸入的參數,x, norm
為用戶自定義參數。
在 __collect__
會根據變量名稱來收集 message
需要的輸入參數。
在 GCN 中,norm
保持不變,x
將被映射到 x_j
,并且經過 __lift__
函數,其值也會發生變化。__lift__
函數如下
def __lift__(self, src, edge_index, dim): if isinstance(edge_index, Tensor): index = edge_index[dim] return src.index_select(self.node_dim, index)
在本例中,輸入的特征 shape=[4, 8]
,經過 __lift__
后,節點特征 shape=[7, 8]
. 經過 message
計算后,就可以執行 aggregate
和 update
了。
總結
原文鏈接:https://blog.csdn.net/morgan777/article/details/121183287
相關推薦
- 2023-12-08 table 單元格垂直居中
- 2022-07-11 為Spring配置文件的配置項添加元注釋
- 2022-09-29 ASP.NET?MVC實現多選下拉框保存并顯示_實用技巧
- 2022-03-14 使用npm安裝淘寶鏡像(npm配置淘寶鏡像)
- 2022-05-17 EdgeX 設備服務與core-data、core-command的交互
- 2022-11-12 Shell實現字符串處理的方法詳解_linux shell
- 2022-02-24 Golang?strings包常用字符串操作函數_Golang
- 2022-03-07 Go中defer使用場景及注意事項_Golang
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支