網站首頁 編程語言 正文
1. 代碼講解
1.1 導庫
import os.path from os import listdir import numpy as np import pandas as pd from PIL import Image import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.nn import AdaptiveAvgPool2d from torch.utils.data.sampler import SubsetRandomSampler from torch.utils.data import Dataset import torchvision.transforms as transforms from sklearn.model_selection import train_test_split
1.2 標準化、transform、設置GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') normalize = transforms.Normalize( ? ?mean=[0.485, 0.456, 0.406], ? ?std=[0.229, 0.224, 0.225] ) transform = transforms.Compose([transforms.ToTensor(), normalize]) ?# 轉換
1.3 預處理數據
class DogDataset(Dataset): # 定義變量 ? ? def __init__(self, img_paths, img_labels, size_of_images): ? ? ? ? ? self.img_paths = img_paths ? ? ? ? self.img_labels = img_labels ? ? ? ? self.size_of_images = size_of_images # 多少長圖片 ? ? def __len__(self): ? ? ? ? return len(self.img_paths) # 打開每組圖片并處理每張圖片 ? ? def __getitem__(self, index): ? ? ? ? PIL_IMAGE = Image.open(self.img_paths[index]).resize(self.size_of_images) ? ? ? ? TENSOR_IMAGE = transform(PIL_IMAGE) ? ? ? ? label = self.img_labels[index] ? ? ? ? return TENSOR_IMAGE, label print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train'))) print(len(pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv'))) print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\test'))) train_paths = [] test_paths = [] labels = [] # 訓練集圖片路徑 train_paths_lir = r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train' for path in listdir(train_paths_lir): ? ? train_paths.append(os.path.join(train_paths_lir, path)) ? # 測試集圖片路徑 labels_data = pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv') labels_data = pd.DataFrame(labels_data) ? # 把字符標簽離散化,因為數據有120種狗,不離散化后面把數據給模型時會報錯:字符標簽過多。把字符標簽從0-119編號 size_mapping = {} value = 0 size_mapping = dict(labels_data['breed'].value_counts()) for kay in size_mapping: ? ? size_mapping[kay] = value ? ? value += 1 # print(size_mapping) labels = labels_data['breed'].map(size_mapping) labels = list(labels) # print(labels) print(len(labels)) # 劃分訓練集和測試集 X_train, X_test, y_train, y_test = train_test_split(train_paths, labels, test_size=0.2) train_set = DogDataset(X_train, y_train, (32, 32)) test_set = DogDataset(X_test, y_test, (32, 32)) train_loader = torch.utils.data.DataLoader(train_set, batch_size=64) test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)
1.4 建立模型
class LeNet(nn.Module): ? ? def __init__(self): ? ? ? ? super(LeNet, self).__init__() ? ? ? ? self.features = nn.Sequential( ? ? ? ? ? ? nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5), ? ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2, stride=2), ? ? ? ? ? ? nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5), ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2, stride=2) ? ? ? ? ) ? ? ? ? self.classifier = nn.Sequential( ? ? ? ? ? ? nn.Linear(16 * 5 * 5, 120), ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.Linear(120, 84), ? ? ? ? ? ? nn.ReLU(), ? ? ? ? ? ? nn.Linear(84, 120) ? ? ? ? ) ? ? def forward(self, x): ? ? ? ? batch_size = x.shape[0] ? ? ? ? x = self.features(x) ? ? ? ? x = x.view(batch_size, -1) ? ? ? ? x = self.classifier(x) ? ? ? ? return x model = LeNet().to(device) criterion = nn.CrossEntropyLoss().to(device) optimizer = optim.Adam(model.parameters()) TRAIN_LOSS = [] ?# 損失 TRAIN_ACCURACY = [] ?# 準確率
1.5 訓練模型
def train(epoch): ? ? model.train() ? ? epoch_loss = 0.0 # 損失 ? ? correct = 0 ?# 精確率 ? ? for batch_index, (Data, Label) in enumerate(train_loader): ? ? # 扔到GPU中 ? ? ? ? Data = Data.to(device) ? ? ? ? Label = Label.to(device) ? ? ? ? output_train = model(Data) ? ? # 計算損失 ? ? ? ? loss_train = criterion(output_train, Label) ? ? ? ? epoch_loss = epoch_loss + loss_train.item() ? ? # 計算精確率 ? ? ? ? pred = torch.max(output_train, 1)[1] ? ? ? ? train_correct = (pred == Label).sum() ? ? ? ? correct = correct + train_correct.item() ? ? # 梯度歸零、反向傳播、更新參數 ? ? ? ? optimizer.zero_grad() ? ? ? ? loss_train.backward() ? ? ? ? optimizer.step() ? ? print('Epoch: ', epoch, 'Train_loss: ', epoch_loss / len(train_set), 'Train correct: ', correct / len(train_set))
1.6 測試模型
和訓練集差不多。
def test(): ? ? model.eval() ? ? correct = 0.0 ? ? test_loss = 0.0 ? ? with torch.no_grad(): ? ? ? ? for Data, Label in test_loader: ? ? ? ? ? ? Data = Data.to(device) ? ? ? ? ? ? Label = Label.to(device) ? ? ? ? ? ? test_output = model(Data) ? ? ? ? ? ? loss = criterion(test_output, Label) ? ? ? ? ? ? pred = torch.max(test_output, 1)[1] ? ? ? ? ? ? test_correct = (pred == Label).sum() ? ? ? ? ? ? correct = correct + test_correct.item() ? ? ? ? ? ? test_loss = test_loss + loss.item() ? ? print('Test_loss: ', test_loss / len(test_set), 'Test correct: ', correct / len(test_set))
1.7結果
epoch = 10 for n_epoch in range(epoch): ? ? train(n_epoch) test()
原文鏈接:https://blog.csdn.net/weixin_45758642/article/details/119764959
相關推薦
- 2022-12-04 python中的list字符串元素排序_python
- 2022-06-12 GitHub?AI編程工具copilot在Pycharm的應用_python
- 2021-10-24 Linux多線程中fork與互斥鎖過程示例_Linux
- 2023-07-03 前端面試中遇到的垂直居中問題
- 2022-06-01 Android調用外置攝像頭的方法_Android
- 2022-12-03 C?++迭代器iterator在string中使用方法介紹_C 語言
- 2022-04-03 Android?PopUpWindow實現卡片式彈窗_Android
- 2022-08-16 Hive導入csv文件示例_數據庫其它
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支