網站首頁 編程語言 正文
Step 1. 獲取混淆矩陣
#首先定義一個 分類數*分類數 的空混淆矩陣 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds) # 使用torch.no_grad()可以顯著降低測試用例的GPU占用 with torch.no_grad(): for step, (imgs, targets) in enumerate(test_loader): # imgs: torch.Size([50, 3, 200, 200]) torch.FloatTensor # targets: torch.Size([50, 1]), torch.LongTensor 多了一維,所以我們要把其去掉 targets = targets.squeeze() # [50,1] -----> [50] # 將變量轉為gpu targets = targets.cuda() imgs = imgs.cuda() # print(step,imgs.shape,imgs.type(),targets.shape,targets.type()) out = model(imgs) #記錄混淆矩陣參數 conf_matrix = confusion_matrix(out, targets, conf_matrix) conf_matrix=conf_matrix.cpu()
混淆矩陣的求取用到了confusion_matrix函數,其定義如下:
def confusion_matrix(preds, labels, conf_matrix): preds = torch.argmax(preds, 1) for p, t in zip(preds, labels): conf_matrix[p, t] += 1 return conf_matrix
在當我們的程序執行結束 test_loader 后,我們可以得到本次數據的 混淆矩陣,接下來就要計算其 識別正確的個數以及混淆矩陣可視化:
conf_matrix=np.array(conf_matrix.cpu())# 將混淆矩陣從gpu轉到cpu再轉到np corrects=conf_matrix.diagonal(offset=0)#抽取對角線的每種分類的識別正確個數 per_kinds=conf_matrix.sum(axis=1)#抽取每個分類數據總的測試條數 print("混淆矩陣總元素個數:{0},測試集總個數:{1}".format(int(np.sum(conf_matrix)),test_num)) print(conf_matrix) # 獲取每種Emotion的識別準確率 print("每種情感總個數:",per_kinds) print("每種情感預測正確的個數:",corrects) print("每種情感的識別準確率為:{0}".format([rate*100 for rate in corrects/per_kinds]))
執行此步的輸出結果如下所示:
Step 2. 混淆矩陣可視化
對上邊求得的混淆矩陣可視化
# 繪制混淆矩陣 Emotion=8#這個數值是具體的分類數,大家可以自行修改 labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每種類別的標簽 # 顯示數據 plt.imshow(conf_matrix, cmap=plt.cm.Blues) # 在圖中標注數量/概率信息 thresh = conf_matrix.max() / 2 #數值顏色閾值,如果數值超過這個,就顏色加深。 for x in range(Emotion_kinds): for y in range(Emotion_kinds): # 注意這里的matrix[y, x]不是matrix[x, y] info = int(conf_matrix[y, x]) plt.text(x, y, info, verticalalignment='center', horizontalalignment='center', color="white" if info > thresh else "black") plt.tight_layout()#保證圖不重疊 plt.yticks(range(Emotion_kinds), labels) plt.xticks(range(Emotion_kinds), labels,rotation=45)#X軸字體傾斜45° plt.show() plt.close()
好了,以下就是最終的可視化的混淆矩陣啦:
其它分類指標的獲取
例如 F1分數、TP、TN、FP、FN、精確率、召回率 等指標, 待補充哈(因為暫時還沒用到)~
總結
原文鏈接:https://blog.csdn.net/weixin_38468077/article/details/121671139
相關推薦
- 2022-11-21 如何使用正則表達式對輸入數字進行匹配詳解_正則表達式
- 2024-07-15 linux系統管理高級命令(練習)(six day)
- 2022-05-15 uniapp穿透第三方uView組件樣式
- 2022-12-22 python實現將list拼接為一個字符串_python
- 2023-01-29 React基于路由的代碼分割技術詳解_React
- 2022-05-02 C語言如何實現一些算法或者函數你知道嗎_C 語言
- 2022-02-02 element ui el-dialog 居中,并且內容多的時候內部可以滾動
- 2023-03-04 Golang設計模式之組合模式講解_Golang
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支