網站首頁 編程語言 正文
前言
在深度學習訓練的過程中,隨著網絡層數的提升,我們訓練的次數,參數都會提高,訓練時間相應就會增加,我們今天來了解遷移學習
一、經典的卷積神經網絡
在pytorch官網中,我們可以看到許多經典的卷積神經網絡。
附官網鏈接:https://pytorch.org/
這里簡單介紹一下經典的卷積神經發展歷程
1.首先可以說是卷積神經網絡的開山之作Alexnet(12年的奪冠之作)這里簡單說一下缺點 卷積核大,步長大,沒有填充層,大刀闊斧的提取特征,容易忽略一些重要的特征
2.第二個就是VGG網絡,它的卷積核大小是3*3,有一個優點是經過池化層之后,通道數翻倍,可以更多的保留一些特征,這是VGG的一個特點
在接下來的一段時間中,出現了一個問題,我們都知道,深度學習隨著訓練次數的不斷增加,效果應該是越來越好,但是這里出現了一個問題,研究發現隨著VGG網絡的不斷提高,效果卻沒有原來的好,這時候人們就認為,深度學習是不是只能發展到這里了,這時遇到了一個瓶頸。
3.接下來隨著殘差網絡(Resnet)的提出,解決了上面這個問題,這個網絡的優點是保留了原有的特征,假如經過卷積之后提取的特征還沒有原圖的好,這時候保留原有的特征,就會解決這一問題,下面就是resnet網絡模型
這是一些訓練對比:
二、遷移學習的目標
首先我們使用遷移學習的目標就是用人家訓練好的權重參數,偏置參數,來訓練我們的模型。
三、好處
深度學習要訓練的數據量是很大的,當我們數據量少時,我們訓練的權重參數就不會那么的好,所以這時候我們就可以使用別人訓練好的權重參數,偏置參數來使用,會使我們的模型準確率得到提高
四、步驟
遷移學習大致可以分為三步
1.加載模型
2.凍結層數
3.全連接層
五、代碼
這里使用的是resnet152
import torch
import torchvision as tv
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch
from torch.utils import data
from torch import optim
from torch.autograd import Variable
model_name='resnet'
featuer_extract=True
train_on_gpu=torch.cuda.is_available()
if not train_on_gpu:
print("沒有gpu")
else :
print("是gpu")
devic=torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
teature_extract=True
def set_paremeter_requires_grad(model,featuer_extract):
if featuer_extract:
for parm in model.parameters():
parm.requires_grad=False #不做訓練
def initialize_model(model_name,num_classes,featuer_extract,use_pretrained=True):
model_ft = None
input_size = 0
if model_name=="resnet":
model_ft=tv.models.resnet152(pretrained=use_pretrained)#下載模型
set_paremeter_requires_grad(model_ft,featuer_extract) #凍結層數
num_ftrs=model_ft.fc.in_features #改動全連接層
model_ft.fc=nn.Sequential(nn.Linear(num_ftrs,num_classes),
nn.LogSoftmax(dim=1))
input_size=224 #輸入維度
return model_ft,input_size
model_ft,iput_size=initialize_model(model_name,10,featuer_extract,use_pretrained=True)
model_ft=model_ft.to(devic)
params_to_updata=model_ft.parameters()
if featuer_extract:
params_to_updata=[]
for name,param in model_ft.named_parameters():
if param.requires_grad==True:
params_to_updata.append(param)
print("\t",name)
else:
for name,param in model_ft.parameters():
if param.requires_grad==True:
print("\t",name)
opt=optim.Adam(params_to_updata,lr=0.01)
loss=nn.NLLLoss()
if __name__ == '__main__':
transform = transforms.Compose([
# 圖像增強
transforms.Resize(1024),#裁剪
transforms.RandomHorizontalFlip(),#隨機水平翻轉
transforms.RandomCrop(224),#隨機裁剪
transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5), #亮度
# 轉變為tensor 正則化
transforms.ToTensor(), #轉換格式
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # 歸一化處理
])
trainset = tv.datasets.CIFAR10(
root=r'E:\桌面\資料\cv3\數據集\cifar-10-batches-py',
train=True,
download=True,
transform=transform
)
trainloader = data.DataLoader(
trainset,
batch_size=8,
drop_last=True,
shuffle=True, # 亂序
num_workers=4,
)
testset = tv.datasets.CIFAR10(
root=r'E:\桌面\資料\cv3\數據集\cifar-10-batches-py',
train=False,
download=True,
transform=transform
)
testloader = data.DataLoader(
testset,
batch_size=8,
drop_last=True,
shuffle=False,
num_workers=4
)
for epoch in range(3):
running_loss=0
for index,data in enumerate(trainloader,0):
inputs, labels = data
inputs = inputs.to(devic)
labels = labels.to(devic)
inputs, labels = Variable(inputs), Variable(labels)
opt.zero_grad()
h=model_ft(inputs)
loss1=loss(h,labels)
loss1.backward()
opt.step()
h+=loss1.item()
if index%10==9:
avg_loss=loss1/10.
running_loss=0
print('avg_loss',avg_loss)
if index%100==99 :
correct=0
total=0
for data in testloader:
images,labels=data
outputs=model_ft(Variable(images.cuda()))
_,predicted=torch.max(outputs.cpu(),1)
total+=labels.size(0)
bool_tensor=(predicted==labels)
correct+=bool_tensor.sum()
print('1000張測試集中的準確率為%d %%'%(100*correct/total))
原文鏈接:https://blog.csdn.net/Lightismore/article/details/124476720
相關推薦
- 2023-02-18 Go語言IO輸入輸出底層原理及文件操作API_Golang
- 2022-06-12 C#實現基于任務的異步編程模式_C#教程
- 2022-07-02 webpack 配置file-loader統一字體打包文件輸出目錄后dist下仍然有字體打包文件
- 2023-01-20 實現.Net7下數據庫定時檢查的方法詳解_實用技巧
- 2024-07-15 Redis底層數據結構-鏈表
- 2022-06-22 android實現注冊登錄程序_Android
- 2022-11-02 Python+eval函數實現動態地計算數學表達式詳解_python
- 2022-10-05 Android?Flutter實現原理淺析_Android
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支