網站首頁 編程語言 正文
PyTorch 多分類模型繪制 ROC、PR 曲線(代碼親測可用)
ROC曲線
示例代碼
import torch
import torch.nn as nn
import os
import numpy as np
from torchvision.datasets import ImageFolder
from utils.transform import get_transform_for_test
from senet.se_resnet import FineTuneSEResnet50
from scipy import interp
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn.metrics import roc_curve, auc, f1_score, precision_recall_curve, average_precision_score
# Restrict CUDA to physical GPU 0; this must happen before CUDA is initialized.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

gpu = "cuda:0"  # device string the model is moved to in the __main__ block
num_class = 113  # number of classes
# Root directory of the test set; test() reads images from <data_root>/test_images.
data_root = r'D:\TJU\GBDB\set113\set113_images\test1'
# Trained checkpoint; test() loads its 'state_dict' entry into the model.
test_weights_path = r"C:\Users\admin\Desktop\fsdownload\epoch_0278_top1_70.565_'checkpoint.pth.tar'"
# Normalization statistics passed to get_transform_for_test() inside test():
# mean=[0.948078, 0.93855226, 0.9332005], var=[0.14589554, 0.17054074, 0.18254866]
def test(model, test_path):
    """Evaluate ``model`` on the test set and plot multi-class ROC curves.

    Loads ``checkpoint['state_dict']`` from ``test_path`` into ``model`` and
    scores every image under ``<data_root>/test_images`` (read with
    ``ImageFolder``). It then computes:

    * a one-vs-rest ROC curve and AUC for each class;
    * a micro-averaged curve;
    * a macro-averaged curve.

    Finally it plots all of them, saves the figure to ``set113_roc.jpg`` and
    shows it.

    Args:
        model: network to evaluate. Inputs are sent to the device of the
            model's first parameter, so move the model there beforehand.
        test_path: path of a checkpoint dict that has a ``'state_dict'`` key.
    """
    # Run inference wherever the caller placed the model, instead of the
    # implicit current CUDA device that ``.cuda()`` would pick.
    device = next(model.parameters()).device

    # Load the test set and the trained weights.
    test_dir = os.path.join(data_root, 'test_images')
    transform_test = get_transform_for_test(mean=[0.948078, 0.93855226, 0.9332005],
                                            var=[0.14589554, 0.17054074, 0.18254866])
    test_dataset = ImageFolder(test_dir, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, drop_last=False, pin_memory=True, num_workers=1)
    # map_location remaps tensors saved on another device onto the model's device.
    checkpoint = torch.load(test_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    score_list = []  # one row of per-class scores per test image
    label_list = []  # ground-truth class index per test image
    # Inference only: don't build autograd graphs (saves memory and time).
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))  # expected (batch_size, num_class)
            # Raw model outputs are used as scores; use
            # torch.softmax(outputs, dim=1) instead to score with probabilities.
            score_list.extend(outputs.cpu().numpy())
            label_list.extend(labels.numpy())
    score_array = np.array(score_list)
    # One-hot encode the labels: row k has a 1 in column label_list[k].
    label_onehot = np.eye(num_class, dtype=np.float32)[np.asarray(label_list, dtype=np.int64)]
    print("score_array:", score_array.shape)  # (num_samples, num_class)
    print("label_onehot:", label_onehot.shape)  # (num_samples, num_class)

    # Per-class (one-vs-rest) FPR, TPR and AUC via sklearn.
    fpr_dict = dict()
    tpr_dict = dict()
    roc_auc_dict = dict()
    for i in range(num_class):
        fpr_dict[i], tpr_dict[i], _ = roc_curve(label_onehot[:, i], score_array[:, i])
        roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i])

    # Micro-average: treat every (sample, class) pair as one binary decision.
    fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(label_onehot.ravel(), score_array.ravel())
    roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"])

    # Macro-average: interpolate each class's TPR on the union of FPR points,
    # then average. A class with no positive (or no negative) test samples has
    # an undefined (NaN) ROC curve, which would poison the mean. Only classes
    # with both kinds of samples are therefore averaged.
    n_samples = label_onehot.shape[0]
    valid_classes = [i for i in range(num_class) if 0 < label_onehot[:, i].sum() < n_samples]
    if valid_classes:
        all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in valid_classes]))
        mean_tpr = np.zeros_like(all_fpr)
        for i in valid_classes:
            mean_tpr += np.interp(all_fpr, fpr_dict[i], tpr_dict[i])
        mean_tpr /= len(valid_classes)
    else:
        # No class has a defined ROC curve; the macro curve is undefined too.
        all_fpr = np.array([0.0, 1.0])
        mean_tpr = np.full_like(all_fpr, np.nan)
    fpr_dict["macro"] = all_fpr
    tpr_dict["macro"] = mean_tpr
    roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"])

    # Plot the averaged curves, then every per-class curve.
    plt.figure()
    lw = 2
    plt.plot(fpr_dict["micro"], tpr_dict["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc_dict["micro"]),
             color='deeppink', linestyle=':', linewidth=4)
    plt.plot(fpr_dict["macro"], tpr_dict["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc_dict["macro"]),
             color='navy', linestyle=':', linewidth=4)
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    for i, color in zip(range(num_class), colors):
        plt.plot(fpr_dict[i], tpr_dict[i], color=color, lw=lw,
                 label='ROC curve of class {0} (area = {1:0.2f})'
                       ''.format(i, roc_auc_dict[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)  # chance-level diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Some extension of Receiver operating characteristic to multi-class')
    plt.legend(loc="lower right")
    plt.savefig('set113_roc.jpg')
    plt.show()
if __name__ == '__main__':
    # Build the FineTuneSEResnet50 model and place it on the configured device.
    device = torch.device(gpu)
    seresnet = FineTuneSEResnet50(num_class=num_class).to(device)
    # Load the trained weights and plot the ROC curves on the test set.
    test(seresnet, test_weights_path)
運行結果:
PR曲線
示例代碼
import torch
import torch.nn as nn
import os
import numpy as np
from torchvision.datasets import ImageFolder
from utils.transform import get_transform_for_test
from senet.se_resnet import FineTuneSEResnet50
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, f1_score, precision_recall_curve, average_precision_score
# Restrict CUDA to physical GPU 0; this must happen before CUDA is initialized.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

gpu = "cuda:0"  # device string the model is moved to in the __main__ block
num_class = 113  # number of classes
# Root directory of the test set; test() reads images from <data_root>/test_images.
data_root = r'D:\TJU\GBDB\set113\set113_images\test1'
# Trained checkpoint; test() loads its 'state_dict' entry into the model.
test_weights_path = r"C:\Users\admin\Desktop\fsdownload\epoch_0278_top1_70.565_'checkpoint.pth.tar'"
# Normalization statistics passed to get_transform_for_test() inside test():
# mean=[0.948078, 0.93855226, 0.9332005], var=[0.14589554, 0.17054074, 0.18254866]
def test(model, test_path):
    """Evaluate ``model`` on the test set and plot a micro-averaged PR curve.

    Loads ``checkpoint['state_dict']`` from ``test_path`` into ``model`` and
    scores every image under ``<data_root>/test_images`` (read with
    ``ImageFolder``). It then computes:

    * one-vs-rest precision/recall curves and average precision (AP) for
      each class;
    * the micro-average over all classes.

    The results are printed, and the micro-averaged curve is saved to
    ``set113_pr_curve.jpg``.

    Args:
        model: network to evaluate. Inputs are sent to the device of the
            model's first parameter, so move the model there beforehand.
        test_path: path of a checkpoint dict that has a ``'state_dict'`` key.
    """
    # Run inference wherever the caller placed the model, instead of the
    # implicit current CUDA device that ``.cuda()`` would pick.
    device = next(model.parameters()).device

    # Load the test set and the trained weights.
    test_dir = os.path.join(data_root, 'test_images')
    transform_test = get_transform_for_test(mean=[0.948078, 0.93855226, 0.9332005],
                                            var=[0.14589554, 0.17054074, 0.18254866])
    test_dataset = ImageFolder(test_dir, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, drop_last=False, pin_memory=True, num_workers=1)
    # map_location remaps tensors saved on another device onto the model's device.
    checkpoint = torch.load(test_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    score_list = []  # one row of per-class scores per test image
    label_list = []  # ground-truth class index per test image
    # Inference only: don't build autograd graphs (saves memory and time).
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))  # expected (batch_size, num_class)
            # Raw model outputs (not softmax probabilities) are used as scores;
            # use torch.softmax(outputs, dim=1) instead to score with probabilities.
            score_list.extend(outputs.cpu().numpy())
            label_list.extend(labels.numpy())
    score_array = np.array(score_list)
    # One-hot encode the labels: row k has a 1 in column label_list[k].
    label_onehot = np.eye(num_class, dtype=np.float32)[np.asarray(label_list, dtype=np.int64)]
    print("score_array:", score_array.shape)  # (num_samples, num_class) raw model scores
    print("label_onehot:", label_onehot.shape)  # (num_samples, num_class) one-hot labels

    # Per-class (one-vs-rest) precision, recall and average precision via sklearn.
    precision_dict = dict()
    recall_dict = dict()
    average_precision_dict = dict()
    for i in range(num_class):
        precision_dict[i], recall_dict[i], _ = precision_recall_curve(label_onehot[:, i], score_array[:, i])
        average_precision_dict[i] = average_precision_score(label_onehot[:, i], score_array[:, i])
        print(precision_dict[i].shape, recall_dict[i].shape, average_precision_dict[i])

    # Micro-average: treat every (sample, class) pair as one binary decision.
    precision_dict["micro"], recall_dict["micro"], _ = precision_recall_curve(label_onehot.ravel(),
                                                                              score_array.ravel())
    average_precision_dict["micro"] = average_precision_score(label_onehot, score_array, average="micro")
    print('Average precision score, micro-averaged over all classes: {0:0.2f}'.format(average_precision_dict["micro"]))

    # Plot the micro-averaged PR curve as a step function.
    plt.figure()
    plt.step(recall_dict['micro'], precision_dict['micro'], where='post')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title(
        'Average precision score, micro-averaged over all classes: AP={0:0.2f}'
        .format(average_precision_dict["micro"]))
    plt.savefig("set113_pr_curve.jpg")
    # The figure is only saved; call plt.show() here to also display it.
if __name__ == '__main__':
    # Build the FineTuneSEResnet50 model and place it on the configured device.
    device = torch.device(gpu)
    seresnet = FineTuneSEResnet50(num_class=num_class).to(device)
    # Load the trained weights and plot the precision-recall curve on the test set.
    test(seresnet, test_weights_path)
運行結果:
原文鏈接:https://blog.csdn.net/Vertira/article/details/128482515
相關推薦
- 2022-04-02 .Net使用SuperSocket框架實現WebSocket前端_實用技巧
- 2022-09-15 C語言實現學生成績管理系統課程設計_C 語言
- 2022-12-30 React淺析Fragments使用方法_React
- 2022-03-04 Tue Dec 01 00:00:00 GMT+08:00 1998 轉成自定義字符串
- 2022-08-28 樹莓派設置wifi自動連接
- 2022-10-07 VsCode使用EmmyLua插件調試Unity工程Lua代碼的詳細步驟_C#教程
- 2022-12-05 Linux系統查看服務器帶寬及網絡使用情況的具體方法_服務器其它
- 2022-09-23 python pandas創建多層索引MultiIndex的6種方式_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支