日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網站首頁 編程語言 正文

Pytorch模型訓練和模型驗證

作者:MoxiMoses 更新時間: 2022-09-26 編程語言

文章目錄

  • 前言
  • 模型訓練套路
    • 1.準備數據集
    • 2.訓練數據集和測試數據集的長度
    • 3.搭建網絡模型
    • 4.創建網絡模型、損失函數以及優化器
    • 5.添加tensorboard
    • 6.設置訓練網絡的一些參數
    • 7.開始訓練模型
    • 8.查看tensorboard的結果
  • 模型驗證套路
    • 1.輸入圖片
    • 2.加載網絡模型
    • 3.驗證結果
  • 總結


前言

本周主要學習了Pytorch的使用,用Dataset讀取文件中的數據,DataLoader對Dataset讀取的數據進行分批次打包,tensorboard實現了對訓練Loss和測試Loss的可視化以及完成模型訓練和模型驗證。


模型訓練套路

1.準備數據集

我們需要準備訓練數據集和測試數據集,在Pytorch中,讀取數據集需要用到Dataset和DataLoader兩個類,Dataset負責對數據的讀取,讀取的內容是每一個數據和它對應的標簽;DataLoader負責對Dataset讀取的數據進行打包,然后分批次送入神經網絡。

import time

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import *


train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)

test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)


# 利用 DataLoader 來加載數據集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

2.訓練數據集和測試數據集的長度

# length 長度
train_data_size = len(train_data)
test_data_size = len(test_data)

# if train_data_size = 10,訓練數據集的長度為:10
print("訓練數據集的長度為:{}".format(train_data_size))
print("測試數據集的長度為:{}".format(test_data_size))

得到訓練數據集和測試數據的長度
在這里插入圖片描述

3.搭建網絡模型

class Mose(nn.Module):
    def __init__(self):
        super(Mose, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

4.創建網絡模型、損失函數以及優化器

# 創建網絡模型
mose = Mose()

# 損失函數
loss_fn = nn.CrossEntropyLoss()

# 優化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(mose.parameters(), lr=learning_rate)

5.添加tensorboard

通過tensorboard記錄訓練過程與測試過程的變化

writer = SummaryWriter("../logs_train")

6.設置訓練網絡的一些參數

定義訓練的次數、測試的次數、訓練的輪數、開始時間以及結束時間

# 記錄訓練的次數
total_train_step = 0
# 記錄測試的次數
total_test_step = 0
# 訓練的輪數
epoch = 10
# 開始訓練時間
start_time = time.time()

7.開始訓練模型

每一輪訓練結束后,保存訓練好的模型。

for i in range(epoch):
    print("--------第 {} 輪訓練開始--------".format(i+1))

    # 訓練步驟開始
    mose.train()
    for data in train_dataloader:
        imgs, targets = data
        ouputs = mose(imgs)
        loss = loss_fn(ouputs, targets)

        # 優化器優化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            end_time = time.time()
            print(end_time - start_time)
            print("訓練次數:{},Loss:{}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # 測試步驟開始
    mose.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            ouputs = mose(imgs)
            loss = loss_fn(ouputs, targets)
            total_test_loss = total_test_loss + loss
            accuracy = (ouputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy

        print("整體測試集上的Loss:{}".format(total_test_loss))
        print("整體測試集上的正確率:{}".format(total_accuracy/test_data_size))
        writer.add_scalar("test_loss", total_test_loss, total_test_step)
        writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
        total_test_step = total_test_step + 1

        torch.save(mose, "mose_{}.pth".format(i))
        print("模型已保存")

writer.close()

查看其中某一輪的訓練結果
在這里插入圖片描述

8.查看tensorboard的結果

查看train_loss和test_loss
在這里插入圖片描述

模型驗證套路

1.輸入圖片

讀取一張圖片,把圖片轉成PIL類型,并對圖片進行Resize,把它轉化成32 * 32的圖片。
在這里插入圖片描述

image_path = "../imgs/dog.jpeg"
image = Image.open(image_path)
print(image)

transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])

image = transform(image)
print(image.shape)

查看輸出結果
在這里插入圖片描述

2.加載網絡模型

之前的訓練過程中,我們保存了網絡模型,所以現在我們需要去加載網絡模型。

model = torch.load("mose_0.pth", map_location=torch.device('cpu'))

查看輸出結果
在這里插入圖片描述

3.驗證結果

把圖片從3維轉換成4維,并用模型去加載圖片得到結果。

image = torch.reshape(image, (1, 3, 32, 32))
model.eval()
with torch.no_grad():
    output = model(image)
print(output)

print(output.argmax(1))

查看輸出結果
在這里插入圖片描述
查看數據集的類別
在這里插入圖片描述
得到的結果是5類別,說明驗證結果正確。


總結

在本周的學習中,實現了模型訓練和模型驗證,并且掌握其中的訓練套路和驗證套路,學會了tensorboard實現可視化。

原文鏈接:https://blog.csdn.net/peaunt1/article/details/126917780

欄目分類
最近更新