網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
PyTorch詳解經(jīng)典網(wǎng)絡(luò)ResNet實(shí)現(xiàn)流程_python
作者:峽谷的小魚 ? 更新時(shí)間: 2022-06-30 編程語(yǔ)言簡(jiǎn)述
GoogleNet 和 VGG 等網(wǎng)絡(luò)證明了,更深度的網(wǎng)絡(luò)可以抽象出表達(dá)能力更強(qiáng)的特征,進(jìn)而獲得更強(qiáng)的分類能力。在深度網(wǎng)絡(luò)中,隨之網(wǎng)絡(luò)深度的增加,每層輸出的特征圖分辨率主要是高和寬越來(lái)越小,而深度逐漸增加。
深度的增加理論上能夠提升網(wǎng)絡(luò)的表達(dá)能力,但是對(duì)于優(yōu)化來(lái)說(shuō)就會(huì)產(chǎn)生梯度消失的問(wèn)題。在深度網(wǎng)絡(luò)中,反向傳播時(shí),梯度從輸出端向數(shù)據(jù)端逐層傳播,傳播過(guò)程中,梯度的累乘使得近數(shù)據(jù)段接近0值,使得網(wǎng)絡(luò)的訓(xùn)練失效。
為了解決梯度消失問(wèn)題,可以在網(wǎng)絡(luò)中加入BatchNorm,激活函數(shù)換成ReLU,一定程度緩解了梯度消失問(wèn)題。
深度增加的另一個(gè)問(wèn)題就是網(wǎng)絡(luò)的退化(Degradation of deep network)問(wèn)題。即,在現(xiàn)有網(wǎng)絡(luò)的基礎(chǔ)上,增加網(wǎng)絡(luò)的深度,理論上,只有訓(xùn)練到最佳情況,新網(wǎng)絡(luò)的性能應(yīng)該不會(huì)低于淺層的網(wǎng)絡(luò)。因?yàn)?,只要將新增加的層學(xué)習(xí)成恒等映射(identity mapping)就可以。換句話說(shuō),淺網(wǎng)絡(luò)的解空間是深的網(wǎng)絡(luò)的解空間的子集。但是由于Degradation問(wèn)題,更深的網(wǎng)絡(luò)并不一定好于淺層網(wǎng)絡(luò)。
Residual模塊的想法就是認(rèn)為的讓網(wǎng)絡(luò)實(shí)現(xiàn)這種恒等映射。如圖,殘差結(jié)構(gòu)在兩層卷積的基礎(chǔ)上,并行添加了一個(gè)分支,將輸入直接加到最后的ReLU激活函數(shù)之前,如果兩層卷積改變大量輸入的分辨率和通道數(shù),為了能夠相加,可以在添加的分支上使用1x1卷積來(lái)匹配尺寸。
殘差結(jié)構(gòu)
ResNet網(wǎng)絡(luò)有兩種殘差塊,一種是兩個(gè)3x3卷積,一種是1x1,3x3,1x1三個(gè)卷積網(wǎng)絡(luò)串聯(lián)成殘差模塊。
PyTorch 實(shí)現(xiàn):
class Residual_1(nn.Module):
r"""
18-layer, 34-layer 殘差塊
1. 使用了類似VGG的3×3卷積層設(shè)計(jì);
2. 首先使用兩個(gè)相同輸出通道數(shù)的3×3卷積層,后接一個(gè)批量規(guī)范化和ReLU激活函數(shù);
3. 加入跨過(guò)卷積層的通路,加到最后的ReLU激活函數(shù)前;
4. 如果要匹配卷積后的輸出的尺寸和通道數(shù),可以在加入的跨通路上使用1×1卷積;
"""
def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
r"""
parameters:
input_channels: 輸入的通道上數(shù)
num_channels: 輸出的通道數(shù)
use_1x1conv: 是否需要使用1x1卷積控制尺寸
stride: 第一個(gè)卷積的步長(zhǎng)
"""
super().__init__()
# 3×3卷積,strides控制分辨率是否縮小
self.conv1 = nn.Conv2d(input_channels,
num_channels,
kernel_size=3,
padding=1,
stride=strides)
# 3×3卷積,不改變分辨率
self.conv2 = nn.Conv2d(num_channels,
num_channels,
kernel_size=3,
padding=1)
# 使用 1x1 卷積變換輸入的分辨率和通道
if use_1x1conv:
self.conv3 = nn.Conv2d(input_channels,
num_channels,
kernel_size=1,
stride=strides)
else:
self.conv3 = None
# 批量規(guī)范化層
self.bn1 = nn.BatchNorm2d(num_channels)
self.bn2 = nn.BatchNorm2d(num_channels)
def forward(self, X):
Y = F.relu(self.bn1(self.conv1(X)))
Y = self.bn2(self.conv2(Y))
if self.conv3:
X = self.conv3(X)
# print(X.shape)
Y += X
return F.relu(Y)
class Residual_2(nn.Module):
r"""
50-layer, 101-layer, 152-layer 殘差塊
1. 首先使用1x1卷積,ReLU激活函數(shù);
2. 然后用3×3卷積層,在接一個(gè)批量規(guī)范化,ReLU激活函數(shù);
3. 再接1x1卷積層;
4. 加入跨過(guò)卷積層的通路,加到最后的ReLU激活函數(shù)前;
5. 如果要匹配卷積后的輸出的尺寸和通道數(shù),可以在加入的跨通路上使用1×1卷積;
"""
def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
r"""
parameters:
input_channels: 輸入的通道上數(shù)
num_channels: 輸出的通道數(shù)
use_1x1conv: 是否需要使用1x1卷積控制尺寸
stride: 第一個(gè)卷積的步長(zhǎng)
"""
super().__init__()
# 1×1卷積,strides控制分辨率是否縮小
self.conv1 = nn.Conv2d(input_channels,
num_channels,
kernel_size=1,
padding=1,
stride=strides)
# 3×3卷積,不改變分辨率
self.conv2 = nn.Conv2d(num_channels,
num_channels,
kernel_size=3,
padding=1)
# 1×1卷積,strides控制分辨率是否縮小
self.conv3 = nn.Conv2d(input_channels,
num_channels,
kernel_size=1,
padding=1)
# 使用 1x1 卷積變換輸入的分辨率和通道
if use_1x1conv:
self.conv3 = nn.Conv2d(input_channels,
num_channels,
kernel_size=1,
stride=strides)
else:
self.conv3 = None
# 批量規(guī)范化層
self.bn1 = nn.BatchNorm2d(num_channels)
self.bn2 = nn.BatchNorm2d(num_channels)
def forward(self, X):
Y = F.relu(self.bn1(self.conv1(X)))
Y = F.relu(self.bn2(self.conv2(Y)))
Y = self.conv3(Y)
if self.conv3:
X = self.conv3(X)
# print(X.shape)
Y += X
return F.relu(Y)
ResNet有不同的網(wǎng)絡(luò)層數(shù),比較常用的是50-layer,101-layer,152-layer。他們都是由上述的殘差模塊堆疊在一起實(shí)現(xiàn)的。
以18-layer為例,層數(shù)是指:首先,conv_1 的一層7x7卷積,然后conv_2~conv_5四個(gè)模塊,每個(gè)模塊兩個(gè)殘差塊,每個(gè)殘差塊有兩層的3x3卷積組成,共4×2×2=16層,最后是一層分類層(fc),加總一起共1+16+1=18層。
18-layer 實(shí)現(xiàn)
首先定義由殘差結(jié)構(gòu)組成的模塊:
# ResNet模塊
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
r"""殘差塊組成的模塊"""
blk = []
for i in range(num_residuals):
if i == 0 and not first_block:
blk.append(Residual_1(input_channels,
num_channels,
use_1x1conv=True,
strides=2))
else:
blk.append(Residual_1(num_channels, num_channels))
return blk
定義18-layer的最開始的層:
# ResNet的前兩層:
# 1. 輸出通道數(shù)64, 步幅為2的7x7卷積層
# 2. 步幅為2的3x3最大匯聚層
conv_1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
定義殘差組模塊:
# ResNet模塊
conv_2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
conv_3 = nn.Sequential(*resnet_block(64, 128, 2))
conv_4 = nn.Sequential(*resnet_block(128, 256, 2))
conv_5 = nn.Sequential(*resnet_block(256, 512, 2))
ResNet 18-layer模型:
net = nn.Sequential(conv_1, conv_2, conv_3, conv_4, conv_5,
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(),
nn.Linear(512, 10))
# 觀察模型各層的輸出尺寸
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
X = layer(X)
print(layer.__class__.__name__,'output shape:\t', X.shape)
輸出:
Sequential output shape:?? ? torch.Size([1, 64, 56, 56])
Sequential output shape:?? ? torch.Size([1, 64, 56, 56])
Sequential output shape:?? ? torch.Size([1, 128, 28, 28])
Sequential output shape:?? ? torch.Size([1, 256, 14, 14])
Sequential output shape:?? ? torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape:?? ? torch.Size([1, 512, 1, 1])
Flatten output shape:?? ? torch.Size([1, 512])
Linear output shape:?? ? torch.Size([1, 10])
在數(shù)據(jù)集訓(xùn)練
def load_datasets_Cifar10(batch_size, resize=None):
trans = [transforms.ToTensor()]
if resize:
transform = trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=trans, download=True)
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=trans, download=True)
print("Cifar10 下載完成...")
return (torch.utils.data.DataLoader(train_data, batch_size, shuffle=True),
torch.utils.data.DataLoader(test_data, batch_size, shuffle=False))
def load_datasets_FashionMNIST(batch_size, resize=None):
trans = [transforms.ToTensor()]
if resize:
transform = trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
train_data = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
test_data = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
print("FashionMNIST 下載完成...")
return (torch.utils.data.DataLoader(train_data, batch_size, shuffle=True),
torch.utils.data.DataLoader(test_data, batch_size, shuffle=False))
def load_datasets(dataset, batch_size, resize):
if dataset == "Cifar10":
return load_datasets_Cifar10(batch_size, resize=resize)
else:
return load_datasets_FashionMNIST(batch_size, resize=resize)
train_iter, test_iter = load_datasets("", 128, 224) # Cifar10
原文鏈接:https://blog.csdn.net/weixin_43276033/article/details/124564891
相關(guān)推薦
- 2022-09-09 Python?OpenCV?Canny邊緣檢測(cè)算法的原理實(shí)現(xiàn)詳解_python
- 2022-05-15 Redis中有序集合的內(nèi)部實(shí)現(xiàn)方式的詳細(xì)介紹_Redis
- 2021-12-13 C語(yǔ)言數(shù)據(jù)結(jié)構(gòu)與算法之鏈表(二)_C 語(yǔ)言
- 2022-12-23 利用Python實(shí)現(xiàn)文件讀取與輸入以及數(shù)據(jù)存儲(chǔ)與讀取的常用命令_python
- 2023-07-10 Spring事務(wù)的傳播機(jī)制
- 2022-09-22 get方法和post方法的區(qū)別
- 2022-12-30 解決React報(bào)錯(cuò)useNavigate()?may?be?used?only?in?context
- 2022-08-02 Python+Selenium實(shí)現(xiàn)瀏覽器標(biāo)簽頁(yè)的切換_python
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過(guò)濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支