網站首頁 編程語言 正文
PyTorch中實現卷積的重要基礎函數
1、nn.Conv2d:
nn.Conv2d在pytorch中用于實現卷積。
nn.Conv2d(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
)
1、in_channels為輸入通道數。
2、out_channels為輸出通道數。
3、kernel_size為卷積核大小。
4、stride為步數。
5、padding為padding情況。
6、dilation表示空洞卷積情況。
2、nn.MaxPool2d(kernel_size=2)
nn.MaxPool2d在pytorch中用于實現最大池化。
具體使用方式如下:
MaxPool2d(kernel_size,
stride=None,
padding=0,
dilation=1,
return_indices=False,
ceil_mode=False)
1、kernel_size為池化核的大小
2、stride為步長
3、padding為填充情況
3、nn.ReLU()
nn.ReLU()用來實現Relu函數,實現非線性。
4、x.view()
x.view用于reshape特征層的形狀。
全部代碼
這是一個簡單的CNN模型,用于預測mnist手寫體。
import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
# 循環世代
EPOCH = 20
BATCH_SIZE = 50
# 下載mnist數據集
train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=True,)
# (60000, 28, 28)
print(train_data.train_data.size())
# (60000)
print(train_data.train_labels.size())
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 測試集
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# (2000, 1, 28, 28)
# 標準化
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.test_labels[:2000]
# 建立pytorch神經網絡
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
#----------------------------#
# 第一部分卷積
#----------------------------#
self.conv1 = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=32,
kernel_size=5,
stride=1,
padding=2,
dilation=1
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
#----------------------------#
# 第二部分卷積
#----------------------------#
self.conv2 = nn.Sequential(
nn.Conv2d(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
dilation=1
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
#----------------------------#
# 全連接+池化+全連接
#----------------------------#
self.ful1 = nn.Linear(64 * 7 * 7, 512)
self.drop = nn.Dropout(0.5)
self.ful2 = nn.Sequential(nn.Linear(512, 10),nn.Softmax())
#----------------------------#
# 前向傳播
#----------------------------#
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
x = self.ful1(x)
x = self.drop(x)
output = self.ful2(x)
return output
cnn = CNN()
# 指定優化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
# 指定loss函數
loss_func = nn.CrossEntropyLoss()
for epoch in range(EPOCH):
for step, (b_x, b_y) in enumerate(train_loader):
#----------------------------#
# 計算loss并修正權值
#----------------------------#
output = cnn(b_x)
loss = loss_func(output, b_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
#----------------------------#
# 打印
#----------------------------#
if step % 50 == 0:
test_output = cnn(test_x)
pred_y = torch.max(test_output, 1)[1].data.numpy()
accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
print('Epoch: %2d'% epoch, ', loss: %.4f' % loss.data.numpy(), ', accuracy: %.4f' % accuracy)
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/103658845
相關推薦
- 2022-05-22 Python數據結構之隊列詳解_python
- 2024-01-30 深入理解Scrapy中XPath的`following-sibling`選擇器
- 2021-11-09 Android如何實現時間線效果(下)_Android
- 2022-06-09 Python字符串的轉義字符_python
- 2022-03-29 python中format函數與round函數的區別_python
- 2022-01-28 laravel try異常abort只報出最外層
- 2022-12-07 C語言?如何求兩整數的最大公約數與最小公倍數_C 語言
- 2022-05-17 ubuntu安裝curl時,出現configure: error: select TLS backe
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支