日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網站首頁 編程語言 正文

圖神經網絡GNN算法基本原理詳解_python

作者:Cyril_KI ? 更新時間: 2022-07-04 編程語言

前言

本文結合一個具體的無向圖來對最簡單的一種GNN進行推導。本文第一部分是數據介紹,第二部分為推導過程中需要用的變量的定義,第三部分是GNN的具體推導過程,最后一部分為自己對GNN的一些看法與總結。

1. 數據

利用networkx簡單生成一個無向圖:

# -*- coding: utf-8 -*-
"""
@Time : 2021/12/21 11:23
@Author :KI 
@File :gnn_basic.py
@Motto:Hungry And Humble
"""
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
G = nx.Graph()
node_features = [[2, 3], [4, 7], [3, 7], [4, 5], [5, 5]]
edges = [(1, 2), (1, 3), (2, 4), (2, 5), (1, 3), (3, 5), (3, 4)]
edge_features = [[1, 3], [4, 1], [1, 5], [5, 3], [5, 6], [5, 4], [4, 3]]
colors = []
edge_colors = []
# add nodes
for i in range(1, len(node_features) + 1):
    G.add_node(i, feature=str(i) + ':(' + str(node_features[i-1][0]) + ',' + str(node_features[i-1][1]) + ')')
    colors.append('#DCBB8A')
# add edges
for i in range(1, len(edge_features) + 1):
    G.add_edge(edges[i-1][0], edges[i-1][1], feature='(' + str(edge_features[i-1][0]) + ',' + str(edge_features[i-1][1]) + ')')
    edge_colors.append('#3CA9C4')
# draw
fig, ax = plt.subplots()
pos = nx.spring_layout(G)
nx.draw(G, pos=pos, node_size=2000, node_color=colors, edge_color='black')
node_labels = nx.get_node_attributes(G, 'feature')
nx.draw_networkx_labels(G, pos=pos, labels=node_labels, node_size=2000, node_color=colors, font_color='r', font_size=14)
edge_labels = nx.get_edge_attributes(G, 'feature')
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=14, font_color='#7E8877')
ax.set_facecolor('deepskyblue')
ax.axis('off')
fig.set_facecolor('deepskyblue')
plt.show()

如下所示:

其中,每一個節點都有自己的一些特征,比如在社交網絡中,每個節點(用戶)有性別以及年齡等特征。

5個節點的特征向量依次為:

[[2, 3], [4, 7], [3, 7], [4, 5], [5, 5]]

同樣,6條邊的特征向量為:

[[1, 3], [4, 1], [1, 5], [5, 3], [5, 6], [5, 4], [4, 3]]

2. 變量定義

特征向量實際上也就是節點或者邊的標簽,這個是圖本身的屬性,一直保持不變。

3. GNN算法

GNN算法的完整描述如下:Forward向前計算狀態,Backward向后計算梯度,主函數通過向前和向后迭代調用來最小化損失。

主函數中:

上述描述只是一個總體的概述,可以略過先不看。

3.1 Forward

早期的GNN都是RecGNN,即循環GNN。這種類型的GNN基于信息傳播機制: GNN通過不斷交換鄰域信息來更新節點狀態,直到達到穩定均衡。節點的狀態向量 x 由以下 f w ?函數來進行周期性更新:

?

解析上述公式:對于節點 n ,假設為節點1,更新其狀態需要以下數據參與:

這里的fw只是形式化的定義,不同的GNN有不同的定義,如隨機穩態嵌入(SSE)中定義如下:

由更新公式可知,當所有節點的狀態都趨于穩定狀態時,此時所有節點的狀態向量都包含了其鄰居節點和相連邊的信息。

這與圖嵌入有些類似:如果是節點嵌入,我們最終得到的是一個節點的向量表示,而這些向量是根據隨機游走序列得到的,隨機游走序列中又包括了節點的鄰居信息, 因此節點的向量表示中包含了連接信息。

證明上述更新過程能夠收斂需要用到不動點理論,這里簡單描述下:

如果我們有以下更新公式:

GNN的Foward描述如下:

解釋:

3.2 Backward

在節點嵌入中,我們最終得到了每個節點的表征向量,此時我們就能利用這些向量來進行聚類、節點分類、鏈接預測等等。

GNN中類似,得到這些節點狀態向量的最終形式不是我們的目的,我們的目的是利用這些節點狀態向量來做一些實際的應用,比如節點標簽預測。

因此,如果想要預測的話,我們就需要一個輸出函數來對節點狀態進行變換,得到我們要想要的東西:

最容易想到的就是將節點狀態向量經過一個前饋神經網絡得到輸出,也就是說 g w g_w gw?可以是一個FNN,同樣的, f w f_w fw?也可以是一個FNN:

我們利用 g w g_w gw?函數對節點 n n n收斂后的狀態向量 x n x_n xn?以及其特征向量 l n l_n ln?進行變換,就能得到我們想要的輸出,比如某一類別,某一具體的數值等等。

在BP算法中,我們有了輸出后,就能算出損失,然后利用損失反向傳播算出梯度,最后再利用梯度下降法對神經網絡的參數進行更新。

對于某一節點的損失(比如回歸)我們可以簡單定義如下:

有了z(t)后,我們就能求導了:

z(t)的求解方法在Backward中有描述:

因此,在Backward中需要計算以下導數:

4.總結與展望

本文所講的GNN是最原始的GNN,此時的GNN存在著不少的問題,比如對不動點隱藏狀態的更新比較低效。

由于CNN在CV領域的成功,許多重新定義圖形數據卷積概念的方法被提了出來,圖卷積神經網絡ConvGNN也被提了出來,ConvGNN被分為兩大類:頻域方法(spectral-based method )和空間域方法(spatial-based method)。2009年,Micheli在繼承了來自RecGNN的消息傳遞思想的同時,在架構上復合非遞歸層,首次解決了圖的相互依賴問題。在過去的幾年里還開發了許多替代GNN,包括GAE和STGNN。這些學習框架可以建立在RecGNN、ConvGNN或其他用于圖形建模的神經架構上。

GNN是用于圖數據的深度學習架構,它將端到端學習與歸納推理相結合,業界普遍認為其有望解決深度學習無法處理的因果推理、可解釋性等一系列瓶頸問題,是未來3到5年的重點方向。

因此,不僅僅是GNN,圖領域的相關研究都是比較有前景的,這方面的應用也十分廣泛,比如推薦系統、計算機視覺、物理/化學(生命科學)、藥物發現等等。

原文鏈接:https://blog.csdn.net/Cyril_KI/article/details/122058881

欄目分類
最近更新