網站首頁 編程語言 正文
引言
使用 python 繪制網絡訓練過程中的的 loss 曲線以及準確率變化曲線,這里的主要思想就時先把想要的損失值以及準確率值保存下來,保存到 .txt 文件中,待網絡訓練結束,我們再拿這存儲的數據繪制各種曲線。
其大致步驟為:數據讀取與存儲 - > loss曲線繪制 - > 準確率曲線繪制
一、數據讀取與存儲部分
我們首先要得到訓練時的數據,以損失值為例,網絡每迭代一次都會產生相應的 loss,那么我們就把每一次的損失值都存儲下來,存儲到列表,保存到 .txt 文件中。保存的文件如下圖所示:
[1.3817585706710815, 1.8422836065292358, 1.1619832515716553, 0.5217241644859314, 0.5221078991889954, 1.3544578552246094, 1.3334463834762573, 1.3866571187973022, 0.7603049278259277]
上圖為部分損失值,根據迭代次數而異,要是迭代了1萬次,這里就會有1萬個損失值。
而準確率值是每一個 epoch 產生一個值,要是訓練100個epoch,就有100個準確率值。
(那么問題來了,這里的損失值是怎么保存到文件中的呢? 很少有人講這個,也有一些小伙伴們來咨詢,這里就統一記錄一下,包括損失值和準確率值。)
首先,找到網絡訓練代碼,就是項目中的 main.py,或者 train.py ,在文件里先找到訓練部分,里面經常會有這樣一行代碼:
for epoch in range(resume_epoch, num_epochs): # 就是這一行
####
...
loss = criterion(outputs, labels.long()) # 損失樣例
...
epoch_acc = running_corrects.double() / trainval_sizes[phase] # 準確率樣例
...
###
從這一行開始就是訓練部分了,往下會找到類似的這兩句代碼,就是損失值和準確率值了。
這時候將以下代碼加入源代碼就可以了:
train_loss = []
train_acc = []
for epoch in range(resume_epoch, num_epochs): # 就是這一行
###
...
loss = criterion(outputs, labels.long()) # 損失樣例
train_loss.append(loss.item()) # 損失加入到列表中
...
epoch_acc = running_corrects.double() / trainval_sizes[phase] # 準確率樣例
train_acc.append(epoch_acc.item()) # 準確率加入到列表中
...
with open("./train_loss.txt", 'w') as train_los:
train_los.write(str(train_loss))
with open("./train_acc.txt", 'w') as train_ac:
train_ac.write(str(train_acc))
這樣就算完成了損失值和準確率值的數據存儲了!
二、繪制 loss 曲線
主要需要 numpy 庫和 matplotlib 庫,如果不會安裝可以自行百度,很簡單。
首先,將 .txt 文件中的存儲的數據讀取進來,以下是讀取函數:
import numpy as np
# 讀取存儲為txt文件的數據
def data_read(dir_path):
with open(dir_path, "r") as f:
raw_data = f.read()
data = raw_data[1:-1].split(", ") # [-1:1]是為了去除文件中的前后中括號"[]"
return np.asfarray(data, float)
然后,就是繪制 loss 曲線部分:
if __name__ == "__main__":
train_loss_path = r"E:\relate_code\Gaitpart-master\train_loss.txt" # 存儲文件路徑
y_train_loss = data_read(train_loss_path) # loss值,即y軸
x_train_loss = range(len(y_train_loss)) # loss的數量,即x軸
plt.figure()
# 去除頂部和右邊框框
ax = plt.axes()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel('iters') # x軸標簽
plt.ylabel('loss') # y軸標簽
# 以x_train_loss為橫坐標,y_train_loss為縱坐標,曲線寬度為1,實線,增加標簽,訓練損失,
# 默認顏色,如果想更改顏色,可以增加參數color='red',這是紅色。
plt.plot(x_train_loss, y_train_loss, linewidth=1, linestyle="solid", label="train loss")
plt.legend()
plt.title('Loss curve')
plt.show()
這樣就算把損失圖像畫出來了!如下:
三、繪制準確率曲線
有了上面的基礎,這就簡單很多了。
只是有一點要記住,上面的x軸是迭代次數,這里的是訓練輪次 epoch。
if __name__ == "__main__":
train_acc_path = r"E:\relate_code\Gaitpart-master\train_acc.txt" # 存儲文件路徑
y_train_acc = data_read(train_acc_path) # 訓練準確率值,即y軸
x_train_acc = range(len(y_train_acc)) # 訓練階段準確率的數量,即x軸
plt.figure()
# 去除頂部和右邊框框
ax = plt.axes()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel('epochs') # x軸標簽
plt.ylabel('accuracy') # y軸標簽
# 以x_train_acc為橫坐標,y_train_acc為縱坐標,曲線寬度為1,實線,增加標簽,訓練損失,
# 增加參數color='red',這是紅色。
plt.plot(x_train_acc, y_train_acc, color='red',linewidth=1, linestyle="solid", label="train acc")
plt.legend()
plt.title('Accuracy curve')
plt.show()
這樣就把準確率變化曲線畫出來了!如下:
以下是完整代碼,以繪制準確率曲線為例,并且將x軸換成了iters,和損失曲線保持一致,供參考:
import numpy as np
import matplotlib.pyplot as plt
# 讀取存儲為txt文件的數據
def data_read(dir_path):
with open(dir_path, "r") as f:
raw_data = f.read()
data = raw_data[1:-1].split(", ")
return np.asfarray(data, float)
# 不同長度數據,統一為一個標準,倍乘x軸
def multiple_equal(x, y):
x_len = len(x)
y_len = len(y)
times = x_len/y_len
y_times = [i * times for i in y]
return y_times
if __name__ == "__main__":
train_loss_path = r"E:\relate_code\Gaitpart-master\file_txt\train_loss.txt"
train_acc_path = r"E:\relate_code\Gaitpart-master\train_acc.txt"
y_train_loss = data_read(train_loss_path)
y_train_acc = data_read(train_acc_path)
x_train_loss = range(len(y_train_loss))
x_train_acc = multiple_equal(x_train_loss, range(len(y_train_acc)))
plt.figure()
# 去除頂部和右邊框框
ax = plt.axes()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel('iters')
plt.ylabel('accuracy')
# plt.plot(x_train_loss, y_train_loss, linewidth=1, linestyle="solid", label="train loss")
plt.plot(x_train_acc, y_train_acc, color='red', linestyle="solid", label="train accuracy")
plt.legend()
plt.title('Accuracy curve')
plt.show()
總結
原文鏈接:https://blog.csdn.net/WYKB_Mr_Q/article/details/125661871
相關推薦
- 2022-06-09 Python中文分詞庫jieba(結巴分詞)詳細使用介紹_python
- 2023-12-20 Git同時配置Gitee和GitHub
- 2022-07-28 Python?Flask實現圖片上傳與下載的示例詳解_python
- 2024-01-27 解決“該項目不在請確認該項目位置,然后重試” 文件無法刪除問題
- 2023-02-23 GoLang中的互斥鎖Mutex和讀寫鎖RWMutex使用教程_Golang
- 2022-12-04 React18?useState何時執行更新及微任務理解_React
- 2022-08-10 pandas溫差查詢案例的實現_python
- 2022-11-15 Apache?Doris的Bitmap索引和BloomFilter索引使用及注意事項_Linux
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支