网站首页 编程语言 正文
pytorch geometric的GNN、GCN节点分类
# -*- coding: utf-8 -*-
import os
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.datasets import Planetoid
import torch_geometric.nn as pyg_nn
import torch_geometric.transforms as T
# load dataset
def get_data(folder="node_classify/cora", data_name="cora"):
    """Load a Planetoid citation dataset.

    Args:
        folder: root directory where the dataset files are stored/downloaded.
        data_name: Planetoid dataset name (e.g. "cora").

    Returns:
        The ``Planetoid`` dataset, with ``T.NormalizeFeatures()`` attached as
        its transform.
    """
    feature_transform = T.NormalizeFeatures()
    return Planetoid(root=folder, name=data_name, transform=feature_transform)
# create the graph cnn model
class GraphCNN(nn.Module):
    """Two-layer GCN for node classification.

    ``conv1`` maps ``in_c`` -> ``hid_c`` features, followed by ReLU; ``conv2``
    maps ``hid_c`` -> ``out_c``.  The forward pass returns log-softmax scores
    over the ``out_c`` classes for every node.
    """

    def __init__(self, in_c, hid_c, out_c):
        super(GraphCNN, self).__init__()
        self.conv1 = pyg_nn.GCNConv(in_channels=in_c, out_channels=hid_c)
        self.conv2 = pyg_nn.GCNConv(in_channels=hid_c, out_channels=out_c)

    def forward(self, data):
        """Return per-node log-probabilities, shape [N, out_c].

        ``data`` must provide ``x`` (node features, [N, C]) and
        ``edge_index`` ([2, E]).
        """
        edges = data.edge_index
        hidden = F.relu(self.conv1(x=data.x, edge_index=edges))  # [N, hid_c]
        logits = self.conv2(x=hidden, edge_index=edges)  # [N, out_c]
        return F.log_softmax(logits, dim=1)
class OwnGCN(nn.Module):
    """SGConv -> APPNP -> APPNP -> SGConv stack for node classification.

    ``in_`` projects ``in_c`` -> ``hid_c`` features (SGConv, K=2).  The two
    APPNP layers (K=2, alpha=0.1) propagate the hidden features over the graph,
    each followed by ReLU.  ``out_`` projects ``hid_c`` -> ``out_c``
    (SGConv, K=2).  Dropout with p=0.1 is applied before each APPNP layer and
    before ``out_``, only while the module is in training mode.
    """

    def __init__(self, in_c, hid_c, out_c):
        super(OwnGCN, self).__init__()
        # Registration order (in_, conv1, conv2, out_) is kept so saved
        # state_dict keys stay the same.
        self.in_ = pyg_nn.SGConv(in_c, hid_c, K=2)
        self.conv1 = pyg_nn.APPNP(K=2, alpha=0.1)
        self.conv2 = pyg_nn.APPNP(K=2, alpha=0.1)
        self.out_ = pyg_nn.SGConv(hid_c, out_c, K=2)

    def forward(self, data):
        """Return per-node log-softmax class scores for ``data.x`` / ``data.edge_index``."""
        edges = data.edge_index
        h = self.in_(data.x, edges)
        for layer in (self.conv1, self.conv2):
            h = F.dropout(h, p=0.1, training=self.training)
            h = F.relu(layer(h, edges))
        h = F.dropout(h, p=0.1, training=self.training)
        return F.log_softmax(self.out_(h, edges), dim=1)
# TODO: implement your own GCN model here.
class YourOwnGCN(nn.Module):
    """Placeholder for a user-defined GCN model (not implemented yet).

    It is currently an empty ``nn.Module``: it registers no layers and does
    not define ``forward``.
    """
def analysis_data(dataset):
    """Print basic statistics about the first graph of ``dataset``.

    Args:
        dataset: a PyG-style dataset.  ``dataset[0]`` must expose
            ``num_nodes``, ``num_features``, ``num_edges``, the boolean
            ``train_mask`` / ``val_mask`` / ``test_mask`` tensors and
            ``is_undirected()``; ``dataset.num_classes`` must exist.

    Returns:
        None; the statistics are printed to stdout.
    """
    # Index the dataset only once: in PyG every ``dataset[i]`` access rebuilds
    # the graph object and re-applies the dataset's transform (get_data attaches
    # NormalizeFeatures), so repeated indexing repeats that work.
    graph = dataset[0]
    print("Basic Info: ", graph)
    print("# Nodes: ", graph.num_nodes)
    print("# Features: ", graph.num_features)
    print("# Edges: ", graph.num_edges)
    print("# Classes: ", dataset.num_classes)
    print("# Train samples: ", graph.train_mask.sum().item())
    print("# Valid samples: ", graph.val_mask.sum().item())
    print("# Test samples: ", graph.test_mask.sum().item())
    print("Undirected: ", graph.is_undirected())
def main(train=False):
    """Build an ``OwnGCN`` for Cora, optionally train it, then evaluate it.

    Args:
        train: when True, train for 500 epochs with Adam (lr=1e-2,
            weight_decay=1e-3) and save the whole model to
            ``node_classify/best.pth`` whenever validation accuracy improves.
            When False (the default, matching the previous behaviour where the
            training loop was disabled), an existing checkpoint at that path
            is required.

    The checkpoint is always reloaded at the end and its train/test accuracy
    is printed to stdout.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    checkpoint_path = "node_classify/best.pth"
    cora_dataset = get_data()
    # Alternative model:
    # my_net = GraphCNN(in_c=cora_dataset.num_features, hid_c=150, out_c=cora_dataset.num_classes)
    my_net = OwnGCN(in_c=cora_dataset.num_features, hid_c=300, out_c=cora_dataset.num_classes)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    my_net = my_net.to(device)
    data = cora_dataset[0].to(device)

    if train:
        optimizer = torch.optim.Adam(my_net.parameters(), lr=1e-2, weight_decay=1e-3)
        valid_number = data.val_mask.sum().item()
        best_valid_acc = -1.0
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        for epoch in range(500):
            # model train (one full-batch step on the training nodes)
            my_net.train()
            optimizer.zero_grad()
            output = my_net(data)
            loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()

            # Validate in eval mode (dropout off, no autograd graph) instead of
            # reusing the training-mode output, which is perturbed by dropout.
            my_net.eval()
            with torch.no_grad():
                _, prediction = my_net(data).max(dim=1)
            valid_correct = prediction[data.val_mask].eq(data.y[data.val_mask]).sum().item()
            valid_acc = valid_correct / valid_number
            print("Epoch: {:03d}".format(epoch + 1), "Loss: {:.04f}".format(loss.item()),
                  "Valid Accuracy:: {:.4f}".format(valid_acc))
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                # Pickle the whole module so torch.load below restores it directly.
                torch.save(my_net, checkpoint_path)

    # model test
    # map_location puts the checkpoint on the same device as ``data`` (a GPU
    # checkpoint otherwise fails to load on a CPU-only host).
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects pickled modules; pass weights_only=False there if needed.
    my_net = torch.load(checkpoint_path, map_location=device)
    my_net.eval()
    with torch.no_grad():
        _, prediction = my_net(data).max(dim=1)
    target = data.y
    test_correct = prediction[data.test_mask].eq(target[data.test_mask]).sum().item()
    test_number = data.test_mask.sum().item()
    train_correct = prediction[data.train_mask].eq(target[data.train_mask]).sum().item()
    train_number = data.train_mask.sum().item()
    print("==" * 20)
    print("Accuracy of Train Samples: {:.04f}".format(train_correct / train_number))
    print("Accuracy of Test Samples: {:.04f}".format(test_correct / test_number))
def test_main():
    """Evaluate the saved ``node_classify/best.pth`` model on Cora.

    Loads the pickled model, runs it once in eval mode on the first graph of
    the dataset and prints the accuracy on the train and test masks.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cora_dataset = get_data()
    data = cora_dataset[0].to(device)
    # map_location places the loaded model on the same device as ``data``.
    # Without it, a GPU-saved checkpoint fails to load on a CPU-only host, and
    # a CPU-saved one stays on CPU while ``data`` sits on the GPU.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects pickled modules; pass weights_only=False there if needed.
    my_net = torch.load("node_classify/best.pth", map_location=device)
    my_net.eval()
    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        _, prediction = my_net(data).max(dim=1)
    target = data.y
    test_correct = prediction[data.test_mask].eq(target[data.test_mask]).sum().item()
    test_number = data.test_mask.sum().item()
    train_correct = prediction[data.train_mask].eq(target[data.train_mask]).sum().item()
    train_number = data.train_mask.sum().item()
    print("==" * 20)
    print("Accuracy of Train Samples: {:.04f}".format(train_correct / train_number))
    print("Accuracy of Test Samples: {:.04f}".format(test_correct / test_number))
if __name__ == '__main__':
    # Evaluate the saved checkpoint.  Other entry points:
    #   main()                     -> build the OwnGCN model and evaluate it
    #   analysis_data(get_data())  -> print dataset statistics
    test_main()
pytorch下GCN代码解读
def main():
    """Print a greeting (placeholder snippet)."""
    greeting = "hello world"
    print(greeting)


main()
import os.path as osp
import argparse
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, ChebConv # noqa
# The GCN extracts features from the graph, which are then used for classification.
# --- Dataset initialization ---
parser = argparse.ArgumentParser()
parser.add_argument('--use_gdc', action='store_true',
                    help='Use GDC preprocessing.')
args = parser.parse_args()  # args.use_gdc: whether to apply GDC preprocessing
dataset = 'CiteSeer'  # name of the Planetoid dataset used for training
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)  # dataset directory: ../data/<name> relative to this script
# `dataset` is rebound here from the name string to the Planetoid dataset
# object; node features go through the NormalizeFeatures transform.
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]  # the (single) graph used for training
#print(data)
# --- Preprocessing ---
if args.use_gdc:
    # Graph diffusion convolution (GDC) as preprocessing: PPR diffusion
    # (alpha=0.05) followed by top-k sparsification (k=128, dim=0).
    gdc = T.GDC(self_loop_weight=1, normalization_in='sym',
                normalization_out='col',
                diffusion_kwargs=dict(method='ppr', alpha=0.05),
                sparsification_kwargs=dict(method='topk', k=128,
                                           dim=0), exact=True)
    data = gdc(data)  # replaces the graph with the diffused/sparsified one
# --- Model definition ---
class Net(torch.nn.Module):
    """Two-layer GCN on the module-level ``data`` graph.

    GCNConv(num_features -> 16) -> ReLU -> dropout -> GCNConv(16 -> num_classes)
    -> log_softmax, i.e. Z = log_softmax(A * ReLU(A * X * W0) * W1), where A is
    the (normalized) adjacency, X the feature matrix and W0/W1 the learned
    weights.  Z holds per-node log-probabilities used for classification.

    Reads the module-level ``dataset`` (feature/class counts) and ``args``
    (``use_gdc``); ``forward`` reads the module-level ``data``.
    """

    def __init__(self):
        super().__init__()
        # GCN normalization is switched off when the graph was already
        # preprocessed with GDC (--use_gdc).
        normalize = not args.use_gdc
        # cached=True: per the PyG docs, the normalized adjacency is computed
        # on the first call and reused afterwards.
        self.conv1 = GCNConv(dataset.num_features, 16, cached=True,
                             normalize=normalize)
        self.conv2 = GCNConv(16, dataset.num_classes, cached=True,
                             normalize=normalize)
        # ChebConv alternative:
        # self.conv1 = ChebConv(dataset.num_features, 16, K=2)
        # self.conv2 = ChebConv(16, dataset.num_classes, K=2)

    def forward(self):
        """Return log-softmax class scores for every node of ``data``."""
        features = data.x
        edge_index = data.edge_index
        # Presumably None unless GDC preprocessing attached edge weights.
        edge_weight = data.edge_attr
        hidden = F.relu(self.conv1(features, edge_index, edge_weight))
        hidden = F.dropout(hidden, training=self.training)  # regularization
        scores = self.conv2(hidden, edge_index, edge_weight)
        return F.log_softmax(scores, dim=1)
# Pick the device, then move the model and the graph onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = data.to(device)
# Adam with learning rate 0.01; L2 weight decay (5e-4) is applied only to the
# first convolution's parameters, the second layer is not decayed.
optimizer = torch.optim.Adam(
    [
        {'params': model.conv1.parameters(), 'weight_decay': 5e-4},
        {'params': model.conv2.parameters(), 'weight_decay': 0},
    ],
    lr=0.01,
)
# --- Training ---
def train():
    """Run one full-batch optimization step on the training nodes."""
    model.train()  # enable dropout
    optimizer.zero_grad()
    log_probs = model()
    loss = F.nll_loss(log_probs[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
@torch.no_grad()  # evaluation only: no gradients are tracked
def test():
    """Return ``[train_acc, val_acc, test_acc]`` for the current model.

    Each value is the fraction of nodes selected by the corresponding boolean
    mask whose arg-max prediction equals ``data.y``.
    """
    model.eval()  # disable dropout
    logits = model()

    def accuracy(mask):
        _, predicted = logits[mask].max(dim=1)
        return predicted.eq(data.y[mask]).sum().item() / mask.sum().item()

    return [accuracy(mask) for _, mask in data('train_mask', 'val_mask', 'test_mask')]
best_val_acc = test_acc = 0
log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
# Train for 200 epochs.  The reported test accuracy is the one measured at the
# epoch with the best validation accuracy so far (model selection on val).
for epoch in range(1, 201):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc, test_acc = val_acc, tmp_test_acc
    # Note: the "Val" column shows the best validation accuracy so far,
    # not the current epoch's value.
    print(log.format(epoch, train_acc, best_val_acc, test_acc))
总结
原文链接:https://blog.csdn.net/qq_38574975/article/details/107443725
相关推荐
- 2022-09-10 Python入门之模块和包用法详解_python
- 2021-12-03 Apache Log4j2 报核弹级漏洞快速修复方法_Linux
- 2022-04-08 从头学习C语言之字符串处理函数_C 语言
- 2022-07-17 使用非root用户安装及启动docker的问题(rootless模式运行)_docker
- 2023-11-13 Linux Ubuntu修改用户名和主机名
- 2022-07-13 redis搭建哨兵集群的实现步骤_Redis
- 2022-03-15 更新Android Studio 4.0 启动模拟器提示 unable to locate adb
- 2022-11-10 C语言strlen函数全方位讲解_C 语言
- 最近更新
-
- window11 系统安装 yarn
- 超详细win安装深度学习环境2025年最新版(
- Linux 中运行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存储小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基础操作-- 运算符,流程控制 Flo
- 1. Int 和Integer 的区别,Jav
- spring @retryable不生效的一种
- Spring Security之认证信息的处理
- Spring Security之认证过滤器
- Spring Security概述快速入门
- Spring Security之配置体系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置权
- redisson分布式锁中waittime的设
- maven:解决release错误:Artif
- restTemplate使用总结
- Spring Security之安全异常处理
- MybatisPlus优雅实现加密?
- Spring ioc容器与Bean的生命周期。
- 【探索SpringCloud】服务发现-Nac
- Spring Security之基于HttpR
- Redis 底层数据结构-简单动态字符串(SD
- arthas操作spring被代理目标对象命令
- Spring中的单例模式应用详解
- 聊聊消息队列,发送消息的4种方式
- bootspring第三方资源配置管理
- GIT同步修改后的远程分支