網站首頁 編程語言 正文
Seaborn - 繪制多標簽的混淆矩陣、召回、精準、F1
導入seaborn\matplotlib\scipy\sklearn等包:
import seaborn as sns
from matplotlib import pyplot as plt
from scipy.special import softmax
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score
sns.set_theme(color_codes=True)
從dataframe中,獲取y_true(真實標簽)和y_pred(預測標簽):
y_true = df["target"]
y_pred = df['prediction']
計算驗證數據整體的準確率acc、精準率precision、召回率recall、F1,使用加權模式average=‘weighted’:
# 準確率acc,精準precision,召回recall,F1
acc = accuracy_score(df["target"], df['prediction'])
precision = precision_score(y_true, y_pred, average='weighted')
recall = recall_score(y_true, y_pred, average='weighted')
f1 = f1_score(y_true, y_pred, average='weighted')
print(f'[Info] acc: {acc}, precision: {precision}, recall: {recall}, f1: {f1}')
計算混淆矩陣:
# 橫坐標是真實類別數,縱坐標是預測類別數
cf_matrix = confusion_matrix(y_true, y_pred)
5類矩陣的繪制方案,混淆矩陣、百分比的混淆矩陣、召回矩陣、精準矩陣、F1矩陣:
- 混淆矩陣是計數,百分比的混淆矩陣是占比
- 召回矩陣是,每行的和是1,每行代表真實類別數,占比就是召回
- 精準矩陣是,每列的和是1,每列代表預測列表數,占比就是精準
- F1矩陣是按照 2PR/(P+R),注意為0的情況,需要補0,使用np.divide(a, b, out=np.zeros_like(a), where=(b != 0))
代碼如下:
# 橫坐標是真實類別數,縱坐標是預測類別數
cf_matrix = confusion_matrix(y_true, y_pred)
figure, axes = plt.subplots(2, 2, figsize=(16*1.25, 16))
# 混淆矩陣
ax = sns.heatmap(cf_matrix, annot=True, fmt='g', ax=axes[0][0], cmap='Blues')
ax.title.set_text("Confusion Matrix")
ax.set_xlabel("y_pred")
ax.set_ylabel("y_true")
# plt.savefig(csv_path.replace(".csv", "_cf_matrix.png"))
# plt.show()
# 混淆矩陣 - 百分比
cf_matrix = confusion_matrix(y_true, y_pred)
ax = sns.heatmap(cf_matrix / np.sum(cf_matrix), annot=True, ax=axes[0][1], fmt='.2%', cmap='Blues')
ax.title.set_text("Confusion Matrix (percent)")
ax.set_xlabel("y_pred")
ax.set_ylabel("y_true")
# plt.savefig(csv_path.replace(".csv", "_cf_matrix_p.png"))
# plt.show()
# 召回矩陣,行和為1
sum_true = np.expand_dims(np.sum(cf_matrix, axis=1), axis=1)
precision_matrix = cf_matrix / sum_true
ax = sns.heatmap(precision_matrix, annot=True, fmt='.2%', ax=axes[1][0], cmap='Blues')
ax.title.set_text("Precision Matrix")
ax.set_xlabel("y_pred")
ax.set_ylabel("y_true")
# plt.savefig(csv_path.replace(".csv", "_recall.png"))
# plt.show()
# 精準矩陣,列和為1
sum_pred = np.expand_dims(np.sum(cf_matrix, axis=0), axis=0)
recall_matrix = cf_matrix / sum_pred
ax = sns.heatmap(recall_matrix, annot=True, fmt='.2%', ax=axes[1][1], cmap='Blues')
ax.title.set_text("Recall Matrix")
ax.set_xlabel("y_pred")
ax.set_ylabel("y_true")
# plt.savefig(csv_path.replace(".csv", "_precision.png"))
# plt.show()
# 繪制4張圖
plt.autoscale(enable=False)
plt.savefig(csv_path.replace(".csv", "_all.png"), bbox_inches='tight', pad_inches=0.2)
plt.show()
# F1矩陣
a = 2 * precision_matrix * recall_matrix
b = precision_matrix + recall_matrix
f1_matrix = np.divide(a, b, out=np.zeros_like(a), where=(b != 0))
ax = sns.heatmap(f1_matrix, annot=True, fmt='.2%', cmap='Blues')
ax.title.set_text("F1 Matrix")
ax.set_xlabel("y_pred")
ax.set_ylabel("y_true")
plt.savefig(csv_path.replace(".csv", "_f1.png"))
plt.show()
輸出混淆矩陣、混淆矩陣(百分比)、召回矩陣、精準矩陣:
F1 Score:
原文鏈接:https://blog.csdn.net/caroline_wendy/article/details/125796474
相關推薦
- 2022-06-28 Python利用shutil模塊實現文件的裁剪與壓縮_python
- 2023-07-05 DateUtils 日期工具類
- 2022-09-17 C++?中封裝的含義和簡單實現方式_C 語言
- 2022-04-01 python+selenium對table表和分頁處理_python
- 2022-06-07 Android音視頻開發之MediaPlayer使用教程_Android
- 2022-10-27 python中namedtuple函數的用法解析_python
- 2022-05-20 解決cnpm : 無法加載文件 C:\Users\Administrator\AppData\Roa
- 2022-12-29 react中將html字符串渲染到頁面的實現方式_React
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支