網(wǎng)站首頁 編程語言 正文
參考資料:
Pytorch這個(gè)深度學(xué)習(xí)框架在設(shè)計(jì)的時(shí)候嵌入了非常豐富的繼承機(jī)制。在通用的深度學(xué)習(xí)算法中使用到的組件其實(shí)都繼承于某一個(gè)父類,比如:Dataset,DataLoader,Model等其實(shí)都蘊(yùn)含了一個(gè)繼承機(jī)制。這篇隨筆打算梳理并剖析一下Pytorch里的這樣一種繼承現(xiàn)象。請(qǐng)注意,繼承后的子類的構(gòu)造方法第一行一定要調(diào)用super()方法哦。
torch.nn.Module
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
平時(shí)我們?cè)谏疃葘W(xué)習(xí)中提到的Model其實(shí)就是繼承自torch.nn.Module。最重要且繼承后必須重寫的方法是forward,這個(gè)方法直接規(guī)定Model的前向運(yùn)算方式。此外還有一些預(yù)定義的方法,比較重要的是:
- net.apply
- net.cuda
- net.train
- net.eval
- net.load_state_dict
- net.zero_grad
torch.utils.data.Dataset
官網(wǎng)并沒有給這個(gè)類示例,可能是覺得這個(gè)類比較簡(jiǎn)單。正如描述中所說,torch.utils.data.Dataset是來handle鍵值對(duì)形式的數(shù)據(jù)格式的。我們必須實(shí)現(xiàn)兩個(gè)函數(shù),__getitem__和__len__。前者輸入索引index返回對(duì)應(yīng)的數(shù)據(jù)(和label),后者返回?cái)?shù)據(jù)集總的大小(index的上限)。
補(bǔ)充一句,在官網(wǎng)的Doc中torch.utils.data.Dataset下面就是torch.utils.data.IterableDataset,這個(gè)數(shù)據(jù)集格式和上面Dataset的區(qū)別在于它是來handle可迭代的數(shù)據(jù)集類型。其只需要重寫一個(gè)__iter__函數(shù),留待日后有需要的時(shí)候研究。
torchvision.transforms
這里跑題提一下torchvision里面經(jīng)常用到的transforms,它本質(zhì)也是nn.Module(不信看源碼),其方便之處在于提供了豐富的內(nèi)置處理圖片的方法(transforms變換)。并且可以通過transforms.Compose方法把多個(gè)transform串序并到一起(類似nn.Sequential)。所以在繼承一個(gè)torch.utils.data.Dataset的時(shí)候不妨多利用transforms哦(explicitly specify transform)。
torch.utils.data.DataLoader
從形式上來看,DataLoader是Dataset套的一層包裝;從功能上來看,DataLoader才是最終提供給Model數(shù)據(jù)的人。這個(gè)組件基本不涉及繼承機(jī)制(很少人去改寫這個(gè)類),因此略過。
torch.nn.modules.loss
說完了Dataset和Model,不得不提的就是損失函數(shù)了,從torch.nn.modules.loss可以看出,所有的loss其實(shí)沒啥特別的,說白了也是一個(gè)nn.Module。只不過它的forward方法比較特殊,Model的forward方法是給他一個(gè)data_tensor,而Loss的forward方法是給他一個(gè)target_tensor和一個(gè)(Model預(yù)測(cè)的)input_tensor,返回值一般是一個(gè)常數(shù)。
torch.optim
torch.optim這個(gè)包下預(yù)置了很多Optimizer比如SGD,Adam。其用法如下:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
需要注意的是如果網(wǎng)絡(luò)要在GPU上訓(xùn)練,則optimizer和model綁定應(yīng)該在model轉(zhuǎn)移到GPU上之后。
torch.optim.lr_scheduler
深度學(xué)習(xí)在訓(xùn)練時(shí)一個(gè)動(dòng)態(tài)衰減的學(xué)習(xí)率是比較理想的。torch.optim.lr_scheduler提供了這樣一個(gè)功能。其用法如下:
model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)
for epoch in range(20):
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
scheduler.step()
注意scheduler.step要在optimizer.step之后。
原文鏈接:https://blog.csdn.net/weixin_43590796/article/details/121122853
相關(guān)推薦
- 2022-05-18 Golang?并發(fā)下的問題定位及解決方案_Golang
- 2023-10-11 Mybatis-Plus條件構(gòu)造器的select
- 2022-09-09 C++代碼和可執(zhí)行程序在x86和arm上的區(qū)別介紹_C 語言
- 2023-10-12 ant-design的Input輸入框取消選擇后的藍(lán)色背景以及如何取消提示
- 2024-07-13 IDEA無法使用@WebServlet()注解
- 2022-08-15 使用element中el-table設(shè)置type=“expand“展開行隱藏小箭頭的方法(列表單選、
- 2022-11-07 python中openpyxl庫用法詳解_python
- 2023-02-27 c++數(shù)組排序的5種方法實(shí)例代碼_C 語言
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支