網(wǎng)站首頁 編程語言 正文
LeNet網(wǎng)絡(luò)
LeNet網(wǎng)絡(luò)過卷積層時候保持分辨率不變,過池化層時候分辨率變小。實現(xiàn)如下
from PIL import Image import cv2 import matplotlib.pyplot as plt import torchvision from torchvision import transforms import torch from torch.utils.data import DataLoader import torch.nn as nn import numpy as np import tqdm as tqdm class LeNet(nn.Module): ? ? def __init__(self) -> None: ? ? ? ? super().__init__() ? ? ? ? self.sequential = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(), ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2,stride=2), ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(), ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2,stride=2), ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Flatten(), ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(16*25,120),nn.Sigmoid(), ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(120,84),nn.Sigmoid(), ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(84,10)) ? ? ? ?? ? ?? ? ? def forward(self,x): ? ? ? ? return self.sequential(x) class MLP(nn.Module): ? ? def __init__(self) -> None: ? ? ? ? super().__init__() ? ? ? ? self.sequential = nn.Sequential(nn.Flatten(), ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(28*28,120),nn.Sigmoid(), ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(120,84),nn.Sigmoid(), ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(84,10)) ? ? ? ?? ? ?? ? ? def forward(self,x): ? ? ? ? return self.sequential(x) epochs = 15 batch = 32 lr=0.9 loss = nn.CrossEntropyLoss() model = LeNet() optimizer = torch.optim.SGD(model.parameters(),lr) device = torch.device('cuda') root = r"./" trans_compose ?= transforms.Compose([transforms.ToTensor(), ? ? ? ? ? ? ? ? ? ? ]) train_data = torchvision.datasets.MNIST(root,train=True,transform=trans_compose,download=True) test_data = torchvision.datasets.MNIST(root,train=False,transform=trans_compose,download=True) train_loader = DataLoader(train_data,batch_size=batch,shuffle=True) test_loader = DataLoader(test_data,batch_size=batch,shuffle=False) model.to(device) loss.to(device) # model.apply(init_weights) for epoch in range(epochs): ? train_loss = 0 ? test_loss = 0 ? correct_train = 0 ? correct_test = 0 ? for index,(x,y) in enumerate(train_loader): ? ? x = x.to(device) ? ? y = y.to(device) ? ? predict = model(x) ? ? L = loss(predict,y) ? ? optimizer.zero_grad() ? ? L.backward() ? ? optimizer.step() ? ? train_loss = train_loss + L ? ? correct_train += (predict.argmax(dim=1)==y).sum() ? acc_train = correct_train/(batch*len(train_loader)) ? with torch.no_grad(): ? ? for index,(x,y) in enumerate(test_loader): ? ? ? [x,y] = [x.to(device),y.to(device)] ? ? ? predict = model(x) ? ? ? L1 = loss(predict,y) ? ? ? test_loss = test_loss + L1 ? ? ? correct_test += (predict.argmax(dim=1)==y).sum() ? ? acc_test = correct_test/(batch*len(test_loader)) ? print(f'epoch:{epoch},train_loss:{train_loss/batch},test_loss:{test_loss/batch},acc_train:{acc_train},acc_test:{acc_test}')
訓(xùn)練結(jié)果
epoch:12,train_loss:2.235553741455078,test_loss:0.3947642743587494,acc_train:0.9879833459854126,acc_test:0.9851238131523132
epoch:13,train_loss:2.028963804244995,test_loss:0.3220392167568207,acc_train:0.9891499876976013,acc_test:0.9875199794769287
epoch:14,train_loss:1.8020273447036743,test_loss:0.34837451577186584,acc_train:0.9901833534240723,acc_test:0.98702073097229
泛化能力測試
找了一張圖片,將其分割成只含一個數(shù)字的圖片進行測試
images_np = cv2.imread("/content/R-C.png",cv2.IMREAD_GRAYSCALE) h,w = images_np.shape images_np = np.array(255*torch.ones(h,w))-images_np#圖片反色 images = Image.fromarray(images_np) plt.figure(1) plt.imshow(images) test_images = [] for i in range(10): ? for j in range(16): ? ? test_images.append(images_np[h//10*i:h//10+h//10*i,w//16*j:w//16*j+w//16]) sample = test_images[77] sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28)) predict = model(sample_tensor) output = predict.argmax() print(output) plt.figure(2) plt.imshow(np.array(sample_tensor.squeeze().to('cpu')))
此時預(yù)測結(jié)果為4,預(yù)測正確。從這段代碼中可以看到有一個反色的步驟,若不反色,結(jié)果會受到影響,如下圖所示,預(yù)測為0,錯誤。
模型用于輸入的圖片是單通道的黑白圖片,這里由于可視化出現(xiàn)了黃色,但實際上是黑白色,反色操作說明了數(shù)據(jù)的預(yù)處理十分的重要,很多數(shù)據(jù)如果是不清理過是無法直接用于推理的。
將所有用來泛化性測試的圖片進行準(zhǔn)確率測試:
correct = 0 i = 0 cnt = 1 for sample in test_images: ? sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device) ? sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28)) ? predict = model(sample_tensor) ? output = predict.argmax() ? if(output==i): ? ? correct+=1 ? if(cnt%16==0): ? ? i+=1 ? cnt+=1 acc_g = correct/len(test_images) print(f'acc_g:{acc_g}')
如果不反色,acc_g=0.15
acc_g:0.50625
原文鏈接:https://blog.csdn.net/weixin_44823313/article/details/122581741
相關(guān)推薦
- 2022-07-07 Python如何在列表尾部添加元素_python
- 2022-11-30 golang中的defer函數(shù)理解_Golang
- 2022-04-23 .NET?Core使用APB?vNext框架入門教程_實用技巧
- 2022-03-29 python語法?range()?序列類型range_python
- 2022-08-31 C++?OpenCV裁剪圖片時發(fā)生報錯的解決方式_C 語言
- 2022-06-12 Python語法學(xué)習(xí)之線程的創(chuàng)建與常用方法詳解_python
- 2022-07-06 YOLOv5目標(biāo)檢測之a(chǎn)nchor設(shè)定_python
- 2022-09-02 C#實現(xiàn)裝飾器模式_C#教程
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支