網站首頁 編程語言 正文
前言
目前機器學習框架有兩大方向,Pytorch和Tensorflow 2。對于機器學習的小白的我來說,直觀的感受是Tensorflow的框架更加傻瓜式,在這個框架下只需要定義神經網絡的結構、輸入和輸出,然后直接使用其框架下的各種框架函數即可。而對于Pytorch來說,則使用者能操作、定義的細節更多,但與此同時使用難度也會更高。
通過各種資料也顯示,在學術研究范圍內,越來越多的人使用Pytorch,其實Tensorflow也不錯,但對于普通小白來說入手更快,應用也更快。本著全面發展,多嘗試的心態,開始Pytorch學習。
小編將從自身的理解習慣開始不斷更新這篇博文:
1.Pytorch簡介
Pytorch就是一個神經網絡框架,使用Pytorch可以跳過很多不必要的底層工作,很多通用的方法、數據結構都已經實現供我們調用,從而可以讓我們將精力集中在改進數據質量、網絡結構和評估方法上去。
使用和訓練神經網絡從思考順序上來說無非就三個階段:
1)構思神經網絡的輸入、輸出和網絡結構,其中輸入輸出非常關鍵。
2)訓練數據集(粗糙的原始數據)。
3)如何將訓練數據集轉換成神經網絡能夠接受并且能夠正確輸出的結構。
4)訓練神經網絡并進行預測。
2.Pytorch定義神經網絡的輸入輸出和結構
使用Pytorch定義神經網絡非常通用的格式:
class NN(nn.Module):
def __init__(self):
super(NN,self).__init__()#繼承tOrch中已經寫好的類,包含神經網絡其余所有通用必要方法函數。
self.flatten=nn.Flatten()#加入展平函數。
self.net=nn.Sequential(#調用Sequential方法定義神經網絡。
nn.Linear(100*3,100*3),
nn.ReLU(),
nn.Linear(100*3,100*3),
nn.ReLU(),
nn.Linear(100*3,27)
)
def forward(self,x):#自定義神經網絡的前向傳播函數,本文使用了正常的前向傳播函數,但最終的結果給出三個輸出。
result=self.net(x)
r1=result[:9]
r2=result[9:18]
r3=result[18:27]
return [r1,r2,r3]
到這里基本上已經定義了自己的神經網絡了,輸入為100*3=200個數據、輸出為27個數據。那么問題來了,怎么把數據輸入進去呢?
3.Pytorch神經網絡的數據格式-tensor
對于編程小白、機器學習小白的我或者大家來說,tensor的直接定義不好理解。
tensor表面上只進行了存儲,但實際上它包含了很多中方法,直接使用tensor.Method()調用相關方法即可,而省去了自己來定義函數,再操作數據結構。并且在Pytorch進行訓練時,也會在其內部調用這些方法,所以就需要我們使用這些數據結構來作為Pytorch神經網的輸入,并且神經網絡的輸出也是tensor形式,numpy array 和 list 和 tensor 的轉換其實就是數據相同,但集成了不同方法的數據結構。
那么下面就是輸入數據的定義。train_data和labels都是我們使用python方法寫出的list。
#train_data、labels都是list,經過list->ndarray->tensor的轉換過程。
train_data=torch.tensor(np.array(train_data)).to(torch.float32).to(device)
labels=torch.tensor(np.array(labels)).to(torch.float32).to(device)
4.神經網絡進行預測
使用神經網絡進行預測(前向傳播)、計算損失函數、反向傳播更新梯度
1)進行前向傳播
#train_data[0]即為訓練數據的第一條輸入數據。
prediction=model(train_data[0])
2)計算損失
#定義優化器
optim=torch.optim.SGD(model.parameters(),lr=1e-2,momentum=0.9)
# 定義自己的loss
loss=(prediction[0]-labels[0]).sum()+(prediction[1]-labels[1]).sum()+(prediction[2]-labels[2]).sum()
#反向傳播
optim.zero_grad()#清除上一次的靜態梯度,防止累加。
loss.backward()#計算反向傳播梯度。
optim.step()#進行一次權值更新。
此處的計算損失和權值依據輸入數據更新一次的結果,由此加入一個循環,便可以實現神經網絡的訓練過程。
3) 訓練網絡
在正式進入訓練網絡之前,我們還需要了解一個叫做Batch的東西,如果我們將數據一個一個送進去訓練,那么神經網絡訓練的速度將是十分緩慢的,因此我們每次可以丟進去很多數據讓神經網絡進行預測,通過計算總體的損失就可以讓梯度更快地下降。但訓練數據有時又很巨大,所以就需要將整個訓練數據打包成一批一批的進入訓練,并重復若干次,每訓練整個數據一次,會經歷若干個batch,這一過程稱為一個epoch。
所以為了使網絡預測結果更快地收斂,即更快地訓練神經網絡,我們需要首先對數據進行打包。
import torch.utils.data as Data
bath=50#每批次大小
loader=Data.DataLoader(#制作數據集,只能由cpu讀取
dataset = train_data_set,
batch_size=bath,#每批次包含數據條數
shuffle=True,#是否打亂數據
num_workers=1,#多少個線程搬運數據
)
然后,我們就可以進行神經網絡的訓練了:
pstep=2#每個多少個批次就輸出一次結果
for epoch in range(1000):
running_loss=0.0
for step,(inps,labs) in enumerate(loader):
#取出數據并搬運至GPU進行計算。
labs=labs.to(device)
inps=inps.to(device)
outputs = model(inps)#將數據輸入進去并進行前向傳播
loss=loss_fn(outputs,labs)#損失函數的定義
optimizer.zero_grad()#清楚上一次的靜態梯度,防止累加。
loss.backward()#反向傳播更新梯度
optimizer.step()#進行一次優化。
running_loss += loss.item()#不加item()會造成內存堆疊
size=len(labs)*3
correct=0
#print("outputs:",outputs.argmax(-1),"\nlabs:",labs.argmax(-1))
#逐個判斷計算準確率
correct+=(outputs.argmax(-1)==labs.argmax(-1)).type(torch.float).sum().item()
if step % pstep == pstep-1: # print every 10 mini-batches
print('[%d, %5d] loss: %.3f correct:%.3f' %
(epoch + 1, step + 1, running_loss / pstep,correct/size))
if correct/size>1:#錯誤檢查
print("outputs:",outputs.argmax(-1),"\nlabs:",labs.argmax(-1),"\ncorrect:",correct,"\nSize:",size)
running_loss = 0.0
#保存模型
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")
原文鏈接:https://blog.csdn.net/weixin_52466509/article/details/127855142
相關推薦
- 2022-04-30 Winform項目中TextBox控件DataBindings屬性_C#教程
- 2023-02-23 Android中URLEncoder空格被轉碼為"+"號的處理辦法_Android
- 2022-06-18 Go語言學習之時間函數使用詳解_Golang
- 2023-05-22 Pytorch怎樣保存訓練好的模型_python
- 2023-04-24 Android布局控件View?ViewRootImpl?WindowManagerService關
- 2023-02-03 TypeScript?中?as?const使用介紹_其它
- 2022-10-02 React構建組件的幾種方式及區別_React
- 2022-06-19 python繪制散點圖和折線圖的方法_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支