網站首頁 編程語言 正文
含并行連結的網絡 GoogLeNet
在GoogleNet出現值前,流行的網絡結構使用的卷積核從1×1到11×11,卷積核的選擇并沒有太多的原因。GoogLeNet的提出,說明有時候使用多個不同大小的卷積核組合是有利的。
import torch
from torch import nn
from torch.nn import functional as F
1. Inception塊
Inception塊是 GoogLeNet 的基本組成單元。Inception 塊由四條并行的路徑組成,每個路徑使用不同大小的卷積核:
路徑1:使用 1×1 卷積層;
路徑2:先對輸出執行 1×1 卷積層,來減少通道數,降低模型復雜性,然后接 3×3 卷積層;
路徑3:先對輸出執行 1×1 卷積層,然后接 5×5 卷積層;
路徑4:使用 3×3 最大匯聚層,然后使用 1×1 卷積層;
在各自路徑中使用合適的 padding ,使得各個路徑的輸出擁有相同的高和寬,然后將每條路徑的輸出在通道維度上做連結,作為 Inception 塊的最終輸出.
class Inception(nn.Module):
def __init__(self, in_channels, out_channels):
super(Inception, self).__init__()
# 路徑1
c1, c2, c3, c4 = out_channels
self.route1_1 = nn.Conv2d(in_channels, c1, kernel_size=1)
# 路徑2
self.route2_1 = nn.Conv2d(in_channels, c2[0], kernel_size=1)
self.route2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
# 路徑3
self.route3_1 = nn.Conv2d(in_channels, c3[0], kernel_size=1)
self.route3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
# 路徑4
self.route4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
self.route4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)
def forward(self, x):
x1 = F.relu(self.route1_1(x))
x2 = F.relu(self.route2_2(F.relu(self.route2_1(x))))
x3 = F.relu(self.route3_2(F.relu(self.route3_1(x))))
x4 = F.relu(self.route4_2(self.route4_1(x)))
return torch.cat((x1, x2, x3, x4), dim=1)
2. 構造 GoogLeNet 網絡
順序定義 GoogLeNet 的模塊。
第一個模塊,順序使用三個卷積層。
# 模型的第一個模塊
b1 = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
nn.Conv2d(64, 64, kernel_size=1),
nn.ReLU(),
nn.Conv2d(64, 192, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
第二個模塊,使用兩個Inception模塊。
# Inception組成的第二個模塊
b2 = nn.Sequential(
Inception(192, (64, (96, 128), (16, 32), 32)),
Inception(256, (128, (128, 192), (32, 96), 64)),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
第三個模塊,串聯五個Inception模塊。
# Inception組成的第三個模塊
b3 = nn.Sequential(
Inception(480, (192, (96, 208), (16, 48), 64)),
Inception(512, (160, (112, 224), (24, 64), 64)),
Inception(512, (128, (128, 256), (24, 64), 64)),
Inception(512, (112, (144, 288), (32, 64), 64)),
Inception(528, (256, (160, 320), (32, 128), 128)),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
第四個模塊,傳來兩個Inception模塊。
GoogLeNet使用 avg pooling layer 代替了 fully-connected layer。一方面降低了維度,另一方面也可以視為對低層特征的組合。
# Inception組成的第四個模塊
b4 = nn.Sequential(
Inception(832, (256, (160, 320), (32, 128), 128)),
Inception(832, (384, (192, 384), (48, 128), 128)),
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten()
)
net = nn.Sequential(b1, b2, b3, b4, nn.Linear(1024, 10))
x = torch.randn(1, 1, 96, 96)
for layer in net:
x = layer(x)
print(layer.__class__.__name__, "output shape: ", x.shape)
輸出:
Sequential output shape: ?torch.Size([1, 192, 28, 28])
Sequential output shape: ?torch.Size([1, 480, 14, 14])
Sequential output shape: ?torch.Size([1, 832, 7, 7])
Sequential output shape: ?torch.Size([1, 1024])
Linear output shape: ?torch.Size([1, 10])
3. FashionMNIST訓練測試
def load_datasets_Cifar10(batch_size, resize=None):
trans = [transforms.ToTensor()]
if resize:
transform = trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=trans, download=True)
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=trans, download=True)
print("Cifar10 下載完成...")
return (torch.utils.data.DataLoader(train_data, batch_size, shuffle=True),
torch.utils.data.DataLoader(test_data, batch_size, shuffle=False))
def load_datasets_FashionMNIST(batch_size, resize=None):
trans = [transforms.ToTensor()]
if resize:
transform = trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
train_data = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
test_data = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
print("FashionMNIST 下載完成...")
return (torch.utils.data.DataLoader(train_data, batch_size, shuffle=True),
torch.utils.data.DataLoader(test_data, batch_size, shuffle=False))
def load_datasets(dataset, batch_size, resize):
if dataset == "Cifar10":
return load_datasets_Cifar10(batch_size, resize=resize)
else:
return load_datasets_FashionMNIST(batch_size, resize=resize)
train_iter, test_iter = load_datasets("", 128, 96) # Cifar10
訓練結果:
原文鏈接:https://blog.csdn.net/weixin_43276033/article/details/124545743
相關推薦
- 2022-04-27 .Net?Core中使用MongoDB搭建集群與項目實戰_基礎應用
- 2022-04-28 淺析python中特殊文件和特殊函數_python
- 2022-05-29 Docker向數據卷Volume寫入數據_docker
- 2022-12-30 React?Refs轉發實現流程詳解_React
- 2022-12-02 基于Go語言實現類似tree命令的小程序_Golang
- 2022-09-02 Docker資源限制Cgroup的深入理解_docker
- 2023-07-07 JdbcTemplate基本使用
- 2022-04-15 Python3之字符串比較_重寫cmp函數方式_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支