網站首頁 編程語言 正文
文章目錄
- 前言
- 模型訓練套路
- 1.準備數據集
- 2.訓練數據集和測試數據集的長度
- 3.搭建網絡模型
- 4.創建網絡模型、損失函數以及優化器
- 5.添加tensorboard
- 6.設置訓練網絡的一些參數
- 7.開始訓練模型
- 8.查看tensorboard的結果
- 模型驗證套路
- 1.輸入圖片
- 2.加載網絡模型
- 3.驗證結果
- 總結
前言
本周主要學習了Pytorch的使用,用Dataset讀取文件中的數據,DataLoader對Dataset讀取的數據進行分批次打包,tensorboard實現了對訓練Loss和測試Loss的可視化以及完成模型訓練和模型驗證。
模型訓練套路
1.準備數據集
我們需要準備訓練數據集和測試數據集,在Pytorch中,讀取數據集需要用到Dataset和DataLoader兩個類,Dataset負責對數據的讀取,讀取的內容是每一個數據和它對應的標簽;DataLoader負責對Dataset讀取的數據進行打包,然后分批次送入神經網絡。
import time
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import *
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
download=True)
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
download=True)
# 利用 DataLoader 來加載數據集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
2.訓練數據集和測試數據集的長度
# length 長度
train_data_size = len(train_data)
test_data_size = len(test_data)
# if train_data_size = 10,訓練數據集的長度為:10
print("訓練數據集的長度為:{}".format(train_data_size))
print("測試數據集的長度為:{}".format(test_data_size))
得到訓練數據集和測試數據的長度
3.搭建網絡模型
class Mose(nn.Module):
def __init__(self):
super(Mose, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 32, 5, 1, 2),
nn.MaxPool2d(2),
nn.Conv2d(32, 32, 5, 1, 2),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, 1, 2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(64 * 4 * 4, 64),
nn.Linear(64, 10)
)
def forward(self, x):
x = self.model(x)
return x
4.創建網絡模型、損失函數以及優化器
# 創建網絡模型
mose = Mose()
# 損失函數
loss_fn = nn.CrossEntropyLoss()
# 優化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(mose.parameters(), lr=learning_rate)
5.添加tensorboard
通過tensorboard記錄訓練過程與測試過程的變化
writer = SummaryWriter("../logs_train")
6.設置訓練網絡的一些參數
定義訓練的次數、測試的次數、訓練的輪數、開始時間以及結束時間
# 記錄訓練的次數
total_train_step = 0
# 記錄測試的次數
total_test_step = 0
# 訓練的輪數
epoch = 10
# 開始訓練時間
start_time = time.time()
7.開始訓練模型
每一輪訓練結束后,保存訓練好的模型。
for i in range(epoch):
print("--------第 {} 輪訓練開始--------".format(i+1))
# 訓練步驟開始
mose.train()
for data in train_dataloader:
imgs, targets = data
ouputs = mose(imgs)
loss = loss_fn(ouputs, targets)
# 優化器優化模型
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_train_step = total_train_step + 1
if total_train_step % 100 == 0:
end_time = time.time()
print(end_time - start_time)
print("訓練次數:{},Loss:{}".format(total_train_step, loss.item()))
writer.add_scalar("train_loss", loss.item(), total_train_step)
# 測試步驟開始
mose.eval()
total_test_loss = 0
total_accuracy = 0
with torch.no_grad():
for data in test_dataloader:
imgs, targets = data
ouputs = mose(imgs)
loss = loss_fn(ouputs, targets)
total_test_loss = total_test_loss + loss
accuracy = (ouputs.argmax(1) == targets).sum()
total_accuracy = total_accuracy + accuracy
print("整體測試集上的Loss:{}".format(total_test_loss))
print("整體測試集上的正確率:{}".format(total_accuracy/test_data_size))
writer.add_scalar("test_loss", total_test_loss, total_test_step)
writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
total_test_step = total_test_step + 1
torch.save(mose, "mose_{}.pth".format(i))
print("模型已保存")
writer.close()
查看其中某一輪的訓練結果
8.查看tensorboard的結果
查看train_loss和test_loss
模型驗證套路
1.輸入圖片
讀取一張圖片,把圖片轉成PIL類型,并對圖片進行Resize,把它轉化成32 * 32的圖片。
image_path = "../imgs/dog.jpeg"
image = Image.open(image_path)
print(image)
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
torchvision.transforms.ToTensor()])
image = transform(image)
print(image.shape)
查看輸出結果
2.加載網絡模型
之前的訓練過程中,我們保存了網絡模型,所以現在我們需要去加載網絡模型。
model = torch.load("mose_0.pth", map_location=torch.device('cpu'))
查看輸出結果
3.驗證結果
把圖片從3維轉換成4維,并用模型去加載圖片得到結果。
image = torch.reshape(image, (1, 3, 32, 32))
model.eval()
with torch.no_grad():
output = model(image)
print(output)
print(output.argmax(1))
查看輸出結果
查看數據集的類別
得到的結果是5類別,說明驗證結果正確。
總結
在本周的學習中,實現了模型訓練和模型驗證,并且掌握其中的訓練套路和驗證套路,學會了tensorboard實現可視化。
原文鏈接:https://blog.csdn.net/peaunt1/article/details/126917780
相關推薦
- 2023-06-03 C++11學習之右值引用和移動語義詳解_C 語言
- 2023-01-27 Python?Flask利用SocketIO庫實現圖表的繪制_python
- 2022-03-15 redis編譯報致命錯誤:jemalloc/jemalloc.h:沒有那個文件或目錄
- 2022-06-06 flutter 布局管理詳解
- 2022-10-06 詳解hive常見表結構_數據庫其它
- 2022-03-14 ffmpeg開發讀取目錄列表
- 2022-04-08 Swift使用表格組件實現單列表_Swift
- 2022-03-19 C++?OpenCV技術實戰之身份證離線識別_C 語言
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支