日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網站首頁 編程語言 正文

Pytorch-Geometric中的Message?Passing使用及說明_python

作者:泊柴 ? 更新時間: 2023-01-17 編程語言

Pytorch-Geometric中Message Passing使用

圖中的卷積計算通常被稱為鄰域聚合或者消息傳遞 (neighborhood aggregation or message passing).

定義為節點i在第(k?1)層的特征,ej,i表示節點j到節點i的邊特征,在GNN中消息傳遞可以表示為

其中 □ 表示具有置換不變性并且可微的函數,例如 sum, mean, max 等, γ 和 ? 表示可微函數。

在 PyTorch Gemetric 中,所有卷積算子都是由 MessagePassing 類派生而來,理解 MessagePasing 有助于我們理解 PyG 中消息傳遞的計算方式和編寫自定義的卷積。

在自定義卷積中,用戶只需定義消息傳遞函數 ? message(), 節點更新函數 γ update() 以及聚合方式 aggr='add', aggr='mean' 或則 aggr=max.

具體函數說明如下

  • MessagePassing(aggr='add', flow='source_to_target', node_dim=-2) 定義聚合計算的方式 ('add', 'mean' or max ) 以及消息的傳遞方向 (source_to_target or target_to_source ). 在 PyG 中,中心節點為目標 target,鄰域節點為源 source. node_dim 為消息聚合的維度
  • MessagePassing.propagate(edge_index, size=None, **kwargs): 該函數接受邊信息 edge_index 和其他額外的數據來執行消息傳遞并更新節點嵌入
  • MessagePassing.message(...): 該函數的作用是計算節點消息,就是公式中的函數 ? \phi ? . 如果 flow='source_to_target' ,那么消息將由鄰域節點 j j j 傳向中心節點 i i i ;如果 flow='target_to_source',消息則由中心節點 i i i 傳向鄰域節點 j j j . 傳入參數的節點類型可以通過變量名后綴來確定,例如中心節點嵌入變量一般以 _i 為結尾,鄰域節點嵌入變量以 x_j 為結尾
  • MessagePassing.update(arr_out, ...): 該函數為節點嵌入的更新函數 γ \gamma γ , 輸入參數為聚合函數 MessagePassing.aggregate 計算的結果

為了更好的理解 PyG 中 MessagePassing 的計算過程,我們來分析一下源代碼。

class MessagePassing(torch.nn.Module):

    special_args: Set[str] = {
        'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
        'size_i', 'size_j', 'ptr', 'index', 'dim_size'
    }

    def __init__(self, aggr: Optional[str] = "add",
                 flow: str = "source_to_target", node_dim: int = -2):

        super(MessagePassing, self).__init__()

        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max', None]

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim

        self.inspector = Inspector(self)
        self.inspector.inspect(self.message)
        self.inspector.inspect(self.aggregate, pop_first=True)
        self.inspector.inspect(self.message_and_aggregate, pop_first=True)
        self.inspector.inspect(self.update, pop_first=True)

        self.__user_args__ = self.inspector.keys(
            ['message', 'aggregate', 'update']).difference(self.special_args)
        self.__fused_user_args__ = self.inspector.keys(
            ['message_and_aggregate', 'update']).difference(self.special_args)

        # Support for "fused" message passing.
        self.fuse = self.inspector.implements('message_and_aggregate')

        # Support for GNNExplainer.
        self.__explain__ = False
        self.__edge_mask__ = None

在初始化函數中,MessagePassing 定義了一個 Inspector . Inspector 的中文意思是檢查員的意思,這個類的作用就是檢查各個函數的輸入參數,并保存到 Inspector的參數列表字典中 Inspector.params中。

如果 message的輸入參數為 x_i, x_j,那么Inspector.params['message']={'x_i': Parameter, 'x_j': Parameter} (注:這里僅作示意,實際 Inspector.params['message'] 類型為 OrderedDict). Inspector.implements 檢查函數是否實現.

MessagePasing 中最核心的是 propgate 函數,假設鄰接矩陣 edge_index 的類型為 Torch.LongTensor,消息由 edge_index[0] 傳向 edge_index[1] ,代碼實現如下

def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
    # 為了簡化問題,這里不討論 edge_index 為 SparseTensor 的情況,感興趣的可閱讀 PyG 原始代碼

    size = self.__check_input__(edge_index, size)
    coll_dict = self.__collect__(self.__user_args__, edge_index, size,
                                 kwargs)

    msg_kwargs = self.inspector.distribute('message', coll_dict)
    out = self.message(**msg_kwargs)

    aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
    out = self.aggregate(out, **aggr_kwargs)

    update_kwargs = self.inspector.distribute('update', coll_dict)
    return self.update(out, **update_kwargs)        

在這段代碼中,首先是檢查節點數量和用戶自定義的輸入變量,然后依次執行 message, aggregateupdate 函數。

如果是自定義圖卷積,一般會重寫 messageupdate,這一點隨后再以 GCN 為例解釋,這里首先來看一下 aggregate 的實現

def aggregate(self, inputs: Tensor, index: Tensor,
              ptr: Optional[Tensor] = None,
              dim_size: Optional[int] = None) -> Tensor:
    if ptr is not None:
        ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
        return segment_csr(inputs, ptr, reduce=self.aggr)
    else:
        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                       reduce=self.aggr)

ptr 變量是針對鄰接矩陣 edge_indexSparseTensor的情況,此處暫且不論

inputsmessage計算得到的消息, index 就是待更新節點的索引,實際上就是 edge_index_i. 聚合計算通過 scatter 函數實現。scatter 具體實現參考鏈接

下面以 GCN 為例,我們來看一下 MessagePassing 的計算過程。

GCN 的計算公式如下

實際計算工程可以分為下面幾步

  • 1.在鄰接矩陣中增加自循環,即把鄰接矩陣的對角線上的元素設為1
  • 2.對節點特征矩陣做線性變換
  • 3.計算節點的歸一化系數,也就是節點度乘積的開方
  • 4.對節點特征做歸一化處理
  • 5.聚合(求和)節點特征得到新的節點嵌入

代碼如下

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

forward 函數中,首先是給節點邊增加自循環。設輸入變量如下

edge_index = torch.tensor([[0, 0, 2], [1, 2, 3]], dtype=torch.long)
x = torch.rand((4, 3)) 
conv = GCNConv(3, 8)

注意到默認消息傳遞方向為 source_to_target,此時edge_index[0]=x_j 為 source, edge_index[1]=x_i 為 target.

在 GCN 中,第一步是增加節點的自循環,add_self_loops 計算前后變化如下

# before add_self_loops
# edge_index=
tensor([[0, 0, 2],
        [1, 2, 3]])
# after add_self_loops
# edge_index=
tensor([[0, 0, 2, 0, 1, 2, 3],
        [1, 2, 3, 0, 1, 2, 3]])
# norm=
tensor([0.7071, 0.7071, 0.5000, 1.0000, 0.5000, 0.5000, 0.5000]

此處的 propagate 的輸出參數由 edge_index, x, norm , edge_indexpropagete 必須輸入的參數,x, norm 為用戶自定義參數。

__collect__ 會根據變量名稱來收集 message 需要的輸入參數。

在 GCN 中,norm 保持不變,x 將被映射到 x_j ,并且經過 __lift__ 函數,其值也會發生變化。__lift__ 函數如下

def __lift__(self, src, edge_index, dim):
    if isinstance(edge_index, Tensor):
        index = edge_index[dim]
        return src.index_select(self.node_dim, index)

在本例中,輸入的特征 shape=[4, 8],經過 __lift__ 后,節點特征 shape=[7, 8] . 經過 message 計算后,就可以執行 aggregateupdate 了。

總結

原文鏈接:https://blog.csdn.net/morgan777/article/details/121183287

欄目分類
最近更新