網站首頁 編程語言 正文
1. 代碼講解
1.1 導庫
import os.path from os import listdir import numpy as np import pandas as pd from PIL import Image import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.nn import AdaptiveAvgPool2d from torch.utils.data.sampler import SubsetRandomSampler from torch.utils.data import Dataset import torchvision.transforms as transforms from sklearn.model_selection import train_test_split
1.2 標準化、transform、設置GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') normalize = transforms.Normalize( ? ?mean=[0.485, 0.456, 0.406], ? ?std=[0.229, 0.224, 0.225] ) transform = transforms.Compose([transforms.ToTensor(), normalize]) ?# 轉換
1.3 預處理數據
class DogDataset(Dataset): # 定義變量 ? ? def __init__(self, img_paths, img_labels, size_of_images): ? ? ? ? ? self.img_paths = img_paths ? ? ? ? self.img_labels = img_labels ? ? ? ? self.size_of_images = size_of_images # 多少長圖片 ? ? def __len__(self): ? ? ? ? return len(self.img_paths) # 打開每組圖片并處理每張圖片 ? ? def __getitem__(self, index): ? ? ? ? PIL_IMAGE = Image.open(self.img_paths[index]).resize(self.size_of_images) ? ? ? ? TENSOR_IMAGE = transform(PIL_IMAGE) ? ? ? ? label = self.img_labels[index] ? ? ? ? return TENSOR_IMAGE, label print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train'))) print(len(pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv'))) print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\test'))) train_paths = [] test_paths = [] labels = [] # 訓練集圖片路徑 train_paths_lir = r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train' for path in listdir(train_paths_lir): ? ? train_paths.append(os.path.join(train_paths_lir, path)) ? # 測試集圖片路徑 labels_data = pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv') labels_data = pd.DataFrame(labels_data) ? # 把字符標簽離散化,因為數據有120種狗,不離散化后面把數據給模型時會報錯:字符標簽過多。把字符標簽從0-119編號 size_mapping = {} value = 0 size_mapping = dict(labels_data['breed'].value_counts()) for kay in size_mapping: ? ? size_mapping[kay] = value ? ? value += 1 # print(size_mapping) labels = labels_data['breed'].map(size_mapping) labels = list(labels) # print(labels) print(len(labels)) # 劃分訓練集和測試集 X_train, X_test, y_train, y_test = train_test_split(train_paths, labels, test_size=0.2) train_set = DogDataset(X_train, y_train, (32, 32)) test_set = DogDataset(X_test, y_test, (32, 32)) train_loader = torch.utils.data.DataLoader(train_set, batch_size=64) test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)
1.4 建立模型
class LeNet(nn.Module): ? ? def __init__(self): ? ? ? ? super(LeNet, self).__init__() ? ? ? ? self.features = nn.Sequential( ? ? ? ? ? ? nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5), ? ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2, stride=2), ? ? ? ? ? ? nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5), ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2, stride=2) ? ? ? ? ) ? ? ? ? self.classifier = nn.Sequential( ? ? ? ? ? ? nn.Linear(16 * 5 * 5, 120), ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.Linear(120, 84), ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.Linear(84, 120) ? ? ? ? ) ? ? def forward(self, x): ? ? ? ? batch_size = x.shape[0] ? ? ? ? x = self.features(x) ? ? ? ? x = x.view(batch_size, -1) ? ? ? ? x = self.classifier(x) ? ? ? ? return x model = LeNet().to(device) criterion = nn.CrossEntropyLoss().to(device) optimizer = optim.Adam(model.parameters()) TRAIN_LOSS = [] ?# 損失 TRAIN_ACCURACY = [] ?# 準確率
1.5 訓練模型
def train(epoch): ? ? model.train() ? ? epoch_loss = 0.0 # 損失 ? ? correct = 0 ?# 精確率 ? ? for batch_index, (Data, Label) in enumerate(train_loader): ? ? # 扔到GPU中 ? ? ? ? Data = Data.to(device) ? ? ? ? Label = Label.to(device) ? ? ? ? output_train = model(Data) ? ? # 計算損失 ? ? ? ? loss_train = criterion(output_train, Label) ? ? ? ? epoch_loss = epoch_loss + loss_train.item() ? ? # 計算精確率 ? ? ? ? pred = torch.max(output_train, 1)[1] ? ? ? ? train_correct = (pred == Label).sum() ? ? ? ? correct = correct + train_correct.item() ? ? # 梯度歸零、反向傳播、更新參數 ? ? ? ? optimizer.zero_grad() ? ? ? ? loss_train.backward() ? ? ? ? optimizer.step() ? ? print('Epoch: ', epoch, 'Train_loss: ', epoch_loss / len(train_set), 'Train correct: ', correct / len(train_set))
1.6 測試模型
和訓練集差不多。
def test(): ? ? model.eval() ? ? correct = 0.0 ? ? test_loss = 0.0 ? ? with torch.no_grad(): ? ? ? ? for Data, Label in test_loader: ? ? ? ? ? ? Data = Data.to(device) ? ? ? ? ? ? Label = Label.to(device) ? ? ? ? ? ? test_output = model(Data) ? ? ? ? ? ? loss = criterion(test_output, Label) ? ? ? ? ? ? pred = torch.max(test_output, 1)[1] ? ? ? ? ? ? test_correct = (pred == Label).sum() ? ? ? ? ? ? correct = correct + test_correct.item() ? ? ? ? ? ? test_loss = test_loss + loss.item() ? ? print('Test_loss: ', test_loss / len(test_set), 'Test correct: ', correct / len(test_set))
1.7結果
epoch = 10 for n_epoch in range(epoch): ? ? train(n_epoch) test()
原文鏈接:https://blog.csdn.net/weixin_45758642/article/details/119764959
相關推薦
- 2022-10-13 Python?數據分析教程探索性數據分析_python
- 2022-03-19 Go?語言的?:=的具體使用_Golang
- 2022-11-09 Android開發實現圖片的上傳下載_Android
- 2022-04-09 Android項目中gradle的執行流程_Android
- 2022-10-07 C#如何實現調取釘釘考勤接口的功能_C#教程
- 2022-05-05 Python學習之字典的常用方法總結_python
- 2022-07-06 C語言實現字符串字符反向排列的方法詳解_C 語言
- 2021-12-07 Linux下Hbase安裝配置教程_Linux
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支