網(wǎng)站首頁 編程語言 正文
前言
目前機(jī)器學(xué)習(xí)框架有兩大方向,Pytorch和Tensorflow 2。對于機(jī)器學(xué)習(xí)的小白的我來說,直觀的感受是Tensorflow的框架更加傻瓜式,在這個框架下只需要定義神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)、輸入和輸出,然后直接使用其框架下的各種框架函數(shù)即可。而對于Pytorch來說,則使用者能操作、定義的細(xì)節(jié)更多,但與此同時(shí)使用難度也會更高。
通過各種資料也顯示,在學(xué)術(shù)研究范圍內(nèi),越來越多的人使用Pytorch,其實(shí)Tensorflow也不錯,但對于普通小白來說入手更快,應(yīng)用也更快。本著全面發(fā)展,多嘗試的心態(tài),開始Pytorch學(xué)習(xí)。
小編將從自身的理解習(xí)慣開始不斷更新這篇博文:
1.Pytorch簡介
Pytorch就是一個神經(jīng)網(wǎng)絡(luò)框架,使用Pytorch可以跳過很多不必要的底層工作,很多通用的方法、數(shù)據(jù)結(jié)構(gòu)都已經(jīng)實(shí)現(xiàn)供我們調(diào)用,從而可以讓我們將精力集中在改進(jìn)數(shù)據(jù)質(zhì)量、網(wǎng)絡(luò)結(jié)構(gòu)和評估方法上去。
使用和訓(xùn)練神經(jīng)網(wǎng)絡(luò)從思考順序上來說無非就三個階段:
1)構(gòu)思神經(jīng)網(wǎng)絡(luò)的輸入、輸出和網(wǎng)絡(luò)結(jié)構(gòu),其中輸入輸出非常關(guān)鍵。
2)訓(xùn)練數(shù)據(jù)集(粗糙的原始數(shù)據(jù))。
3)如何將訓(xùn)練數(shù)據(jù)集轉(zhuǎn)換成神經(jīng)網(wǎng)絡(luò)能夠接受并且能夠正確輸出的結(jié)構(gòu)。
4)訓(xùn)練神經(jīng)網(wǎng)絡(luò)并進(jìn)行預(yù)測。
2.Pytorch定義神經(jīng)網(wǎng)絡(luò)的輸入輸出和結(jié)構(gòu)
使用Pytorch定義神經(jīng)網(wǎng)絡(luò)非常通用的格式:
class NN(nn.Module):
def __init__(self):
super(NN,self).__init__()#繼承tOrch中已經(jīng)寫好的類,包含神經(jīng)網(wǎng)絡(luò)其余所有通用必要方法函數(shù)。
self.flatten=nn.Flatten()#加入展平函數(shù)。
self.net=nn.Sequential(#調(diào)用Sequential方法定義神經(jīng)網(wǎng)絡(luò)。
nn.Linear(100*3,100*3),
nn.ReLU(),
nn.Linear(100*3,100*3),
nn.ReLU(),
nn.Linear(100*3,27)
)
def forward(self,x):#自定義神經(jīng)網(wǎng)絡(luò)的前向傳播函數(shù),本文使用了正常的前向傳播函數(shù),但最終的結(jié)果給出三個輸出。
result=self.net(x)
r1=result[:9]
r2=result[9:18]
r3=result[18:27]
return [r1,r2,r3]
到這里基本上已經(jīng)定義了自己的神經(jīng)網(wǎng)絡(luò)了,輸入為100*3=200個數(shù)據(jù)、輸出為27個數(shù)據(jù)。那么問題來了,怎么把數(shù)據(jù)輸入進(jìn)去呢?
3.Pytorch神經(jīng)網(wǎng)絡(luò)的數(shù)據(jù)格式-tensor
對于編程小白、機(jī)器學(xué)習(xí)小白的我或者大家來說,tensor的直接定義不好理解。
tensor表面上只進(jìn)行了存儲,但實(shí)際上它包含了很多中方法,直接使用tensor.Method()調(diào)用相關(guān)方法即可,而省去了自己來定義函數(shù),再操作數(shù)據(jù)結(jié)構(gòu)。并且在Pytorch進(jìn)行訓(xùn)練時(shí),也會在其內(nèi)部調(diào)用這些方法,所以就需要我們使用這些數(shù)據(jù)結(jié)構(gòu)來作為Pytorch神經(jīng)網(wǎng)的輸入,并且神經(jīng)網(wǎng)絡(luò)的輸出也是tensor形式,numpy array 和 list 和 tensor 的轉(zhuǎn)換其實(shí)就是數(shù)據(jù)相同,但集成了不同方法的數(shù)據(jù)結(jié)構(gòu)。
那么下面就是輸入數(shù)據(jù)的定義。train_data和labels都是我們使用python方法寫出的list。
#train_data、labels都是list,經(jīng)過list->ndarray->tensor的轉(zhuǎn)換過程。
train_data=torch.tensor(np.array(train_data)).to(torch.float32).to(device)
labels=torch.tensor(np.array(labels)).to(torch.float32).to(device)
4.神經(jīng)網(wǎng)絡(luò)進(jìn)行預(yù)測
使用神經(jīng)網(wǎng)絡(luò)進(jìn)行預(yù)測(前向傳播)、計(jì)算損失函數(shù)、反向傳播更新梯度
1)進(jìn)行前向傳播
#train_data[0]即為訓(xùn)練數(shù)據(jù)的第一條輸入數(shù)據(jù)。
prediction=model(train_data[0])
2)計(jì)算損失
#定義優(yōu)化器
optim=torch.optim.SGD(model.parameters(),lr=1e-2,momentum=0.9)
# 定義自己的loss
loss=(prediction[0]-labels[0]).sum()+(prediction[1]-labels[1]).sum()+(prediction[2]-labels[2]).sum()
#反向傳播
optim.zero_grad()#清除上一次的靜態(tài)梯度,防止累加。
loss.backward()#計(jì)算反向傳播梯度。
optim.step()#進(jìn)行一次權(quán)值更新。
此處的計(jì)算損失和權(quán)值依據(jù)輸入數(shù)據(jù)更新一次的結(jié)果,由此加入一個循環(huán),便可以實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò)的訓(xùn)練過程。
3) 訓(xùn)練網(wǎng)絡(luò)
在正式進(jìn)入訓(xùn)練網(wǎng)絡(luò)之前,我們還需要了解一個叫做Batch的東西,如果我們將數(shù)據(jù)一個一個送進(jìn)去訓(xùn)練,那么神經(jīng)網(wǎng)絡(luò)訓(xùn)練的速度將是十分緩慢的,因此我們每次可以丟進(jìn)去很多數(shù)據(jù)讓神經(jīng)網(wǎng)絡(luò)進(jìn)行預(yù)測,通過計(jì)算總體的損失就可以讓梯度更快地下降。但訓(xùn)練數(shù)據(jù)有時(shí)又很巨大,所以就需要將整個訓(xùn)練數(shù)據(jù)打包成一批一批的進(jìn)入訓(xùn)練,并重復(fù)若干次,每訓(xùn)練整個數(shù)據(jù)一次,會經(jīng)歷若干個batch,這一過程稱為一個epoch。
所以為了使網(wǎng)絡(luò)預(yù)測結(jié)果更快地收斂,即更快地訓(xùn)練神經(jīng)網(wǎng)絡(luò),我們需要首先對數(shù)據(jù)進(jìn)行打包。
import torch.utils.data as Data
bath=50#每批次大小
loader=Data.DataLoader(#制作數(shù)據(jù)集,只能由cpu讀取
dataset = train_data_set,
batch_size=bath,#每批次包含數(shù)據(jù)條數(shù)
shuffle=True,#是否打亂數(shù)據(jù)
num_workers=1,#多少個線程搬運(yùn)數(shù)據(jù)
)
然后,我們就可以進(jìn)行神經(jīng)網(wǎng)絡(luò)的訓(xùn)練了:
pstep=2#每個多少個批次就輸出一次結(jié)果
for epoch in range(1000):
running_loss=0.0
for step,(inps,labs) in enumerate(loader):
#取出數(shù)據(jù)并搬運(yùn)至GPU進(jìn)行計(jì)算。
labs=labs.to(device)
inps=inps.to(device)
outputs = model(inps)#將數(shù)據(jù)輸入進(jìn)去并進(jìn)行前向傳播
loss=loss_fn(outputs,labs)#損失函數(shù)的定義
optimizer.zero_grad()#清楚上一次的靜態(tài)梯度,防止累加。
loss.backward()#反向傳播更新梯度
optimizer.step()#進(jìn)行一次優(yōu)化。
running_loss += loss.item()#不加item()會造成內(nèi)存堆疊
size=len(labs)*3
correct=0
#print("outputs:",outputs.argmax(-1),"\nlabs:",labs.argmax(-1))
#逐個判斷計(jì)算準(zhǔn)確率
correct+=(outputs.argmax(-1)==labs.argmax(-1)).type(torch.float).sum().item()
if step % pstep == pstep-1: # print every 10 mini-batches
print('[%d, %5d] loss: %.3f correct:%.3f' %
(epoch + 1, step + 1, running_loss / pstep,correct/size))
if correct/size>1:#錯誤檢查
print("outputs:",outputs.argmax(-1),"\nlabs:",labs.argmax(-1),"\ncorrect:",correct,"\nSize:",size)
running_loss = 0.0
#保存模型
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")
原文鏈接:https://blog.csdn.net/weixin_52466509/article/details/127855142
相關(guān)推薦
- 2022-01-16 jQuery實(shí)現(xiàn)動畫效果和導(dǎo)航欄動態(tài)顯示
- 2022-01-19 【webpack5】webpack-dev-server 熱更新不能自動刷新瀏覽器
- 2022-09-10 Python?turtle庫(繪制螺旋正方形)_python
- 2023-02-09 go?sync?Once實(shí)現(xiàn)原理示例解析_Golang
- 2023-03-20 C#如何判斷.Net?Framework版本是否滿足軟件運(yùn)行需要的版本_C#教程
- 2022-07-14 使用react-activation實(shí)現(xiàn)keepAlive支持返回傳參_React
- 2022-09-22 提高接口并發(fā)量,防止崩潰
- 2022-06-11 利用Nginx實(shí)現(xiàn)URL重定向的簡單方法_nginx
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支