網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
Pytorch卷積神經(jīng)網(wǎng)絡(luò)遷移學(xué)習(xí)的目標(biāo)及好處_python
作者:淺念念52 ? 更新時(shí)間: 2022-07-07 編程語(yǔ)言前言
在深度學(xué)習(xí)訓(xùn)練的過(guò)程中,隨著網(wǎng)絡(luò)層數(shù)的提升,我們訓(xùn)練的次數(shù),參數(shù)都會(huì)提高,訓(xùn)練時(shí)間相應(yīng)就會(huì)增加,我們今天來(lái)了解遷移學(xué)習(xí)
一、經(jīng)典的卷積神經(jīng)網(wǎng)絡(luò)
在pytorch官網(wǎng)中,我們可以看到許多經(jīng)典的卷積神經(jīng)網(wǎng)絡(luò)。
附官網(wǎng)鏈接:https://pytorch.org/
這里簡(jiǎn)單介紹一下經(jīng)典的卷積神經(jīng)發(fā)展歷程
1.首先可以說(shuō)是卷積神經(jīng)網(wǎng)絡(luò)的開(kāi)山之作Alexnet(12年的奪冠之作)這里簡(jiǎn)單說(shuō)一下缺點(diǎn) 卷積核大,步長(zhǎng)大,沒(méi)有填充層,大刀闊斧的提取特征,容易忽略一些重要的特征
2.第二個(gè)就是VGG網(wǎng)絡(luò),它的卷積核大小是3*3,有一個(gè)優(yōu)點(diǎn)是經(jīng)過(guò)池化層之后,通道數(shù)翻倍,可以更多的保留一些特征,這是VGG的一個(gè)特點(diǎn)
在接下來(lái)的一段時(shí)間中,出現(xiàn)了一個(gè)問(wèn)題,我們都知道,深度學(xué)習(xí)隨著訓(xùn)練次數(shù)的不斷增加,效果應(yīng)該是越來(lái)越好,但是這里出現(xiàn)了一個(gè)問(wèn)題,研究發(fā)現(xiàn)隨著VGG網(wǎng)絡(luò)的不斷提高,效果卻沒(méi)有原來(lái)的好,這時(shí)候人們就認(rèn)為,深度學(xué)習(xí)是不是只能發(fā)展到這里了,這時(shí)遇到了一個(gè)瓶頸。
3.接下來(lái)隨著殘差網(wǎng)絡(luò)(Resnet)的提出,解決了上面這個(gè)問(wèn)題,這個(gè)網(wǎng)絡(luò)的優(yōu)點(diǎn)是保留了原有的特征,假如經(jīng)過(guò)卷積之后提取的特征還沒(méi)有原圖的好,這時(shí)候保留原有的特征,就會(huì)解決這一問(wèn)題,下面就是resnet網(wǎng)絡(luò)模型
這是一些訓(xùn)練對(duì)比:
二、遷移學(xué)習(xí)的目標(biāo)
首先我們使用遷移學(xué)習(xí)的目標(biāo)就是用人家訓(xùn)練好的權(quán)重參數(shù),偏置參數(shù),來(lái)訓(xùn)練我們的模型。
三、好處
深度學(xué)習(xí)要訓(xùn)練的數(shù)據(jù)量是很大的,當(dāng)我們數(shù)據(jù)量少時(shí),我們訓(xùn)練的權(quán)重參數(shù)就不會(huì)那么的好,所以這時(shí)候我們就可以使用別人訓(xùn)練好的權(quán)重參數(shù),偏置參數(shù)來(lái)使用,會(huì)使我們的模型準(zhǔn)確率得到提高
四、步驟
遷移學(xué)習(xí)大致可以分為三步
1.加載模型
2.凍結(jié)層數(shù)
3.全連接層
五、代碼
這里使用的是resnet152
import torch
import torchvision as tv
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch
from torch.utils import data
from torch import optim
from torch.autograd import Variable
model_name='resnet'
featuer_extract=True
train_on_gpu=torch.cuda.is_available()
if not train_on_gpu:
print("沒(méi)有g(shù)pu")
else :
print("是gpu")
devic=torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
teature_extract=True
def set_paremeter_requires_grad(model,featuer_extract):
if featuer_extract:
for parm in model.parameters():
parm.requires_grad=False #不做訓(xùn)練
def initialize_model(model_name,num_classes,featuer_extract,use_pretrained=True):
model_ft = None
input_size = 0
if model_name=="resnet":
model_ft=tv.models.resnet152(pretrained=use_pretrained)#下載模型
set_paremeter_requires_grad(model_ft,featuer_extract) #凍結(jié)層數(shù)
num_ftrs=model_ft.fc.in_features #改動(dòng)全連接層
model_ft.fc=nn.Sequential(nn.Linear(num_ftrs,num_classes),
nn.LogSoftmax(dim=1))
input_size=224 #輸入維度
return model_ft,input_size
model_ft,iput_size=initialize_model(model_name,10,featuer_extract,use_pretrained=True)
model_ft=model_ft.to(devic)
params_to_updata=model_ft.parameters()
if featuer_extract:
params_to_updata=[]
for name,param in model_ft.named_parameters():
if param.requires_grad==True:
params_to_updata.append(param)
print("\t",name)
else:
for name,param in model_ft.parameters():
if param.requires_grad==True:
print("\t",name)
opt=optim.Adam(params_to_updata,lr=0.01)
loss=nn.NLLLoss()
if __name__ == '__main__':
transform = transforms.Compose([
# 圖像增強(qiáng)
transforms.Resize(1024),#裁剪
transforms.RandomHorizontalFlip(),#隨機(jī)水平翻轉(zhuǎn)
transforms.RandomCrop(224),#隨機(jī)裁剪
transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5), #亮度
# 轉(zhuǎn)變?yōu)閠ensor 正則化
transforms.ToTensor(), #轉(zhuǎn)換格式
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # 歸一化處理
])
trainset = tv.datasets.CIFAR10(
root=r'E:\桌面\資料\cv3\數(shù)據(jù)集\cifar-10-batches-py',
train=True,
download=True,
transform=transform
)
trainloader = data.DataLoader(
trainset,
batch_size=8,
drop_last=True,
shuffle=True, # 亂序
num_workers=4,
)
testset = tv.datasets.CIFAR10(
root=r'E:\桌面\資料\cv3\數(shù)據(jù)集\cifar-10-batches-py',
train=False,
download=True,
transform=transform
)
testloader = data.DataLoader(
testset,
batch_size=8,
drop_last=True,
shuffle=False,
num_workers=4
)
for epoch in range(3):
running_loss=0
for index,data in enumerate(trainloader,0):
inputs, labels = data
inputs = inputs.to(devic)
labels = labels.to(devic)
inputs, labels = Variable(inputs), Variable(labels)
opt.zero_grad()
h=model_ft(inputs)
loss1=loss(h,labels)
loss1.backward()
opt.step()
h+=loss1.item()
if index%10==9:
avg_loss=loss1/10.
running_loss=0
print('avg_loss',avg_loss)
if index%100==99 :
correct=0
total=0
for data in testloader:
images,labels=data
outputs=model_ft(Variable(images.cuda()))
_,predicted=torch.max(outputs.cpu(),1)
total+=labels.size(0)
bool_tensor=(predicted==labels)
correct+=bool_tensor.sum()
print('1000張測(cè)試集中的準(zhǔn)確率為%d %%'%(100*correct/total))
原文鏈接:https://blog.csdn.net/Lightismore/article/details/124476720
相關(guān)推薦
- 2022-09-24 Go?GORM?事務(wù)詳細(xì)介紹_Golang
- 2022-09-22 vant tab組件動(dòng)態(tài)改變van-tab title后顯示不全問(wèn)題,需要手動(dòng)滑動(dòng)
- 2022-06-18 C語(yǔ)言簡(jiǎn)明講解單引號(hào)與雙引號(hào)的使用_C 語(yǔ)言
- 2022-08-15 創(chuàng)建型設(shè)計(jì)模式之建造者模式
- 2022-05-23 ZooKeeper分布式協(xié)調(diào)服務(wù)設(shè)計(jì)核心概念及安裝配置_zabbix
- 2022-01-29 yii 關(guān)聯(lián)表外鍵用法
- 2022-04-12 npm ERR! missing script: dev
- 2022-09-30 Qt編寫(xiě)秒表功能_C 語(yǔ)言
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過(guò)濾器
- Spring Security概述快速入門(mén)
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支