網站首頁 編程語言 正文
前言
在上一篇文章PyG搭建GCN前的準備:了解PyG中的數據格式中,大致了解了PyG中的數據格式,這篇文章主要是簡單搭建GCN來實現節點分類,主要目的是了解PyG中GCN的參數情況。
模型搭建
首先導入包:
from torch_geometric.nn import GCNConv
模型參數:
in_channels:輸入通道,比如節點分類中表示每個節點的特征數。
out_channels:輸出通道,最后一層GCNConv的輸出通道為節點類別數(節點分類)。
improved:如果為True表示自環增加,也就是原始鄰接矩陣加上2I而不是I,默認為False。
cached:如果為True,GCNConv在第一次對鄰接矩陣進行歸一化時會進行緩存,以后將不再重復計算。
add_self_loops:如果為False不再強制添加自環,默認為True。
normalize:默認為True,表示對鄰接矩陣進行歸一化。
bias:默認添加偏置。
于是模型搭建如下:
class GCN(torch.nn.Module):
def __init__(self, num_node_features, num_classes):
super(GCN, self).__init__()
self.conv1 = GCNConv(num_node_features, 16)
self.conv2 = GCNConv(16, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = F.softmax(x, dim=1)
return x
輸出一下模型:
data = Planetoid(root='/data/CiteSeer', name='CiteSeer')model = GCN(data.num_node_features, data.num_classes).to(device)print(model)GCN(
(conv1): GCNConv(3703, 16)
(conv2): GCNConv(16, 6)
)
輸出為:
GCN( (conv1): GCNConv(3703, 16) (conv2): GCNConv(16, 6))GCN(
(conv1): GCNConv(3703, 16)
(conv2): GCNConv(16, 6)
)
1. 前向傳播
查看官方文檔中GCNConv的輸入輸出要求:
可以發現,GCNConv中需要輸入的是節點特征矩陣x和鄰接關系edge_index,還有一個可選項edge_weight。因此我們首先:
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
此時我們不妨輸出一下x及其size:
tensor([[0.0000, 0.1630, 0.0000, ..., 0.0000, 0.0488, 0.0000],
[0.0000, 0.2451, 0.1614, ..., 0.0000, 0.0125, 0.0000],
[0.1175, 0.0262, 0.2141, ..., 0.2592, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.1825, 0.0000],
[0.0000, 0.1024, 0.0000, ..., 0.0498, 0.0000, 0.0000],
[0.0000, 0.3263, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
device='cuda:0', grad_fn=<FusedDropoutBackward0>)
torch.Size([3327, 16])
此時的x一共3327行,每一行表示一個節點經過第一層卷積更新后的狀態向量。
那么同理,由于:
self.conv2 = GCNConv(16, num_classes)
所以經過第二層卷積后:
x = self.conv2(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
此時得到的x的size應該為:
torch.Size([3327, 6])
即每個節點的維度為6的狀態向量。
由于我們需要進行6分類,所以最后需要加上一個softmax:
x = F.softmax(x, dim=1)
dim=1表示對每一行進行運算,最終每一行之和加起來為1,也就表示了該節點為每一類的概率。輸出此時的x:
tensor([[0.1607, 0.1727, 0.1607, 0.1607, 0.1607, 0.1846], [0.1654, 0.1654, 0.1654, 0.1654, 0.1654, 0.1731], [0.1778, 0.1622, 0.1733, 0.1622, 0.1622, 0.1622], ..., [0.1659, 0.1659, 0.1659, 0.1704, 0.1659, 0.1659], [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667], [0.1641, 0.1641, 0.1658, 0.1766, 0.1653, 0.1641]], device='cuda:0', grad_fn=<SoftmaxBackward0>)tensor([[0.1607, 0.1727, 0.1607, 0.1607, 0.1607, 0.1846],
[0.1654, 0.1654, 0.1654, 0.1654, 0.1654, 0.1731],
[0.1778, 0.1622, 0.1733, 0.1622, 0.1622, 0.1622],
...,
[0.1659, 0.1659, 0.1659, 0.1704, 0.1659, 0.1659],
[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
[0.1641, 0.1641, 0.1658, 0.1766, 0.1653, 0.1641]], device='cuda:0',
grad_fn=<SoftmaxBackward0>)
2. 反向傳播
在訓練時,我們首先利用前向傳播計算出輸出:
out = model(data)
out即為最終得到的每個節點的6個概率值,但在實際訓練中,我們只需要計算出訓練集的損失,所以損失函數這樣寫:
loss = loss_function(out[data.train_mask], data.y[data.train_mask])
然后計算梯度,反向更新!
3. 訓練
訓練的完整代碼:
def train(): optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) loss_function = torch.nn.CrossEntropyLoss().to(device) model.train() for epoch in range(500): out = model(data) optimizer.zero_grad() loss = loss_function(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() print('Epoch {:03d} loss {:.4f}'.format(epoch, loss.item()))def train():
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
loss_function = torch.nn.CrossEntropyLoss().to(device)
model.train()
for epoch in range(500):
out = model(data)
optimizer.zero_grad()
loss = loss_function(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
print('Epoch {:03d} loss {:.4f}'.format(epoch, loss.item()))
4. 測試
我們首先需要算出模型對所有節點的預測值:
model(data)
此時得到的是每個節點的6個概率值,我們需要在每一行上取其最大值:
model(data).max(dim=1)
輸出一下:
torch.return_types.max(
values=tensor([0.9100, 0.9071, 0.9786, ..., 0.4321, 0.4009, 0.8779], device='cuda:0',
grad_fn=<MaxBackward0>),
indices=tensor([3, 1, 5, ..., 3, 1, 5], device='cuda:0'))
返回的第一項是每一行的最大值,第二項為最大值在這一行中的索引,我們只需要取第二項,那么最終的預測值應該寫為:
_, pred = model(data).max(dim=1)
然后計算預測精度:
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / int(data.test_mask.sum())
print('GCN Accuracy: {:.4f}'.format(acc))
完整代碼
完整代碼中實現了論文中提到的四種數據集,代碼地址:PyG-GCN。
原文鏈接:https://blog.csdn.net/Cyril_KI/article/details/123457698
相關推薦
- 2022-06-08 CentOs7下docker簡單實踐,安裝nginx
- 2022-08-14 淺談Python任務自動化工具Tox基本用法_python
- 2023-03-29 Python之sklearn數據預處理中fit(),transform()與fit_transfor
- 2022-05-27 C++回溯算法深度優先搜索舉例分析_C 語言
- 2024-04-02 docker開機自啟設置
- 2022-09-21 Flask深入了解Jinja2引擎的用法_python
- 2022-06-22 golang?API請求隊列的實現_Golang
- 2022-12-22 Object?arrays?cannot?be?loaded?when?allow_pickle=F
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支