網(wǎng)站首頁 編程語言 正文
1.引言
我們之前已經(jīng)介紹了神經(jīng)網(wǎng)絡(luò)的基本知識,神經(jīng)網(wǎng)絡(luò)的主要作用就是預(yù)測與分類,現(xiàn)在讓我們來搭建第一個用于擬合回歸的神經(jīng)網(wǎng)絡(luò)吧。
2.神經(jīng)網(wǎng)絡(luò)搭建
2.1 準(zhǔn)備工作
要搭建擬合神經(jīng)網(wǎng)絡(luò)并繪圖我們需要使用python的幾個庫。
import torch import torch.nn.functional as F import matplotlib.pyplot as plt x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1) y = x.pow(3) + 0.2 * torch.rand(x.size())
?既然是擬合,我們當(dāng)然需要一些數(shù)據(jù)啦,我選取了在區(qū)間??內(nèi)的100個等間距點(diǎn),并將它們排列成三次函數(shù)的圖像。
2.2 搭建網(wǎng)絡(luò)
我們定義一個類,繼承了封裝在torch中的一個模塊,我們先分別確定輸入層、隱藏層、輸出層的神經(jīng)元數(shù)目,繼承父類后再使用torch中的.nn.Linear()函數(shù)進(jìn)行輸入層到隱藏層的線性變換,隱藏層也進(jìn)行線性變換后傳入輸出層predict,接下來定義前向傳播的函數(shù)forward(),使用relu()作為激活函數(shù),最后輸出predict()結(jié)果即可。
class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) self.predict = torch.nn.Linear(n_hidden, n_output) def forward(self, x): x = F.relu(self.hidden(x)) return self.predict(x) net = Net(1, 20, 1) print(net) optimizer = torch.optim.Adam(net.parameters(), lr=0.2) loss_func = torch.nn.MSELoss()
網(wǎng)絡(luò)的框架搭建完了,然后我們傳入三層對應(yīng)的神經(jīng)元數(shù)目再定義優(yōu)化器,這里我選取了Adam而隨機(jī)梯度下降(SGD),因?yàn)樗荢GD的優(yōu)化版本,效果在大部分情況下比SGD好,我們要傳入這個神經(jīng)網(wǎng)絡(luò)的參數(shù)(parameters),并定義學(xué)習(xí)率(learning rate),學(xué)習(xí)率通常選取小于1的數(shù),需要憑借經(jīng)驗(yàn)并不斷調(diào)試。最后我們選取均方差法(MSE)來計(jì)算損失(loss)。
2.3 訓(xùn)練網(wǎng)絡(luò)
接下來我們要對我們搭建好的神經(jīng)網(wǎng)絡(luò)進(jìn)行訓(xùn)練,我訓(xùn)練了2000輪(epoch),先更新結(jié)果prediction再計(jì)算損失,接著清零梯度,然后根據(jù)loss反向傳播(backward),最后進(jìn)行優(yōu)化,找出最優(yōu)的擬合曲線。
for t in range(2000): prediction = net(x) loss = loss_func(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step()
3.效果
使用如下繪圖的代碼展示效果。
for t in range(2000): prediction = net(x) loss = loss_func(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step() if t % 5 == 0: plt.cla() plt.scatter(x.data.numpy(), y.data.numpy(), s=10) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2) plt.text(2, -100, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 10, 'color': 'red'}) plt.pause(0.1) plt.ioff() plt.show()
最后的結(jié)果:?
4. 完整代碼
import torch import torch.nn.functional as F import matplotlib.pyplot as plt x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1) y = x.pow(3) + 0.2 * torch.rand(x.size()) class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) self.predict = torch.nn.Linear(n_hidden, n_output) def forward(self, x): x = F.relu(self.hidden(x)) return self.predict(x) net = Net(1, 20, 1) print(net) optimizer = torch.optim.Adam(net.parameters(), lr=0.2) loss_func = torch.nn.MSELoss() plt.ion() for t in range(2000): prediction = net(x) loss = loss_func(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step() if t % 5 == 0: plt.cla() plt.scatter(x.data.numpy(), y.data.numpy(), s=10) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2) plt.text(2, -100, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 10, 'color': 'red'}) plt.pause(0.1) plt.ioff() plt.show()
原文鏈接:https://blog.csdn.net/ZDDWLIG/article/details/123488056
相關(guān)推薦
- 2022-07-22 mybatis一級緩存和二級緩存理解與區(qū)別
- 2023-02-17 Go語言Gin處理響應(yīng)方式詳解_Golang
- 2022-05-19 徹底解決No?module?named?‘torch_geometric‘報(bào)錯的辦法_python
- 2022-09-03 C#實(shí)現(xiàn)工廠方法模式_C#教程
- 2022-08-19 存儲引擎的應(yīng)用場景
- 2022-05-04 分享3個非常實(shí)用的?Python?模塊_python
- 2022-05-25 pytorch?hook?鉤子函數(shù)的用法_python
- 2022-09-12 在CMD窗口中調(diào)用python函數(shù)的實(shí)現(xiàn)_python
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支