網站首頁 編程語言 正文
學習前言
我發現不僅有很多的Keras模型,還有很多的PyTorch模型,還是學學Pytorch吧,我也想了解以下tensor到底是個啥。
PyTorch中的重要基礎函數
1、class Net(torch.nn.Module)神經網絡的構建:
PyTorch中神經網絡的構建和Tensorflow的不一樣,它需要用一個類來進行構建(后面還可以用與Keras類似的Sequential模型構建),當然基礎還是用類構建,這個類需要繼承PyTorch中的神經網絡模型,torch.nn.Module,具體構建方式如下:
# 繼承torch.nn.Module模型
class Net(torch.nn.Module):
# 重載初始化函數(我忘了這個是不是叫重載)
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
# Applies a linear transformation to the incoming data: :math:y = xA^T + b
# 全連接層,公式為y = xA^T + b
# 在初始化的同時構建兩個全連接層(也就是一個隱含層)
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
# forward函數用于構建前向傳遞的過程
def forward(self, x):
# 隱含層的輸出
hidden_layer = functional.relu(self.hidden(x))
# 實際的輸出
output_layer = self.predict(hidden_layer)
return output_layer
該部分構建了一個含有一層隱含層的神經網絡,隱含層神經元個數為n_hidden。
在建立了上述的類后,就可以通過如下函數建立神經網絡:
net = Net(n_feature=1, n_hidden=10, n_output=1)
2、optimizer優化器
optimizer用于構建模型的優化器,與tensorflow中優化器的意義相同,PyTorch的優化器在前綴為torch.optim的庫中。
優化器需要傳入net網絡的參數。
具體使用方式如下:
# torch.optim是優化器模塊
# Adam可以改成其它優化器,如SGD、RMSprop等
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
3、loss損失函數定義
loss用于定義神經網絡訓練的損失函數,常用的損失函數是均方差損失函數(回歸)和交叉熵損失函數(分類)。
具體使用方式如下:
# 均方差lossloss_func = torch.nn.MSELoss()
4、訓練過程
訓練過程分為三個步驟:
1、利用網絡預測結果。
prediction = net(x)
2、利用預測的結果與真實值對比生成loss。
loss = loss_func(prediction, y)
3、進行反向傳遞(該部分有三步)。
# 均方差loss
# 反向傳遞步驟
# 1、初始化梯度
optimizer.zero_grad()
# 2、計算梯度
loss.backward()
# 3、進行optimizer優化
optimizer.step()
全部代碼
這是一個簡單的回歸預測模型。
import torch
from torch.autograd import Variable
import torch.nn.functional as functional
import matplotlib.pyplot as plt
import numpy as np
# x的shape為(100,1)
x = torch.from_numpy(np.linspace(-1,1,100).reshape([100,1])).type(torch.FloatTensor)
# y的shape為(100,1)
y = torch.sin(x) + 0.2*torch.rand(x.size())
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
# Applies a linear transformation to the incoming data: :math:y = xA^T + b
# 全連接層,公式為y = xA^T + b
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
def forward(self, x):
# 隱含層的輸出
hidden_layer = functional.relu(self.hidden(x))
output_layer = self.predict(hidden_layer)
return output_layer
# 類的建立
net = Net(n_feature=1, n_hidden=10, n_output=1)
# torch.optim是優化器模塊
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
# 均方差loss
loss_func = torch.nn.MSELoss()
for t in range(1000):
prediction = net(x)
loss = loss_func(prediction, y)
# 反向傳遞步驟
# 1、初始化梯度
optimizer.zero_grad()
# 2、計算梯度
loss.backward()
# 3、進行optimizer優化
optimizer.step()
if t & 50 == 0:
print("The loss is",loss.data.numpy())
運行結果為:
The loss is 0.27913737
The loss is 0.2773982
The loss is 0.27224126
…………
The loss is 0.0035993527
The loss is 0.0035974088
The loss is 0.0035967692
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/101790418
相關推薦
- 2022-10-22 python?中的?super詳解_python
- 2023-01-28 Flutter?Widget移動UI框架使用Material和密匙Key實戰_Android
- 2022-05-12 小程序自定義日期組件,不顯示今日之后的日期
- 2022-04-03 用Python實現控制電腦鼠標_python
- 2022-01-16 ES6新增聲明格式、變量解構賦值及模板字符串
- 2022-03-03 實現不需要手動浮空瀏覽器緩存,程序可以獲取最新版本
- 2022-09-04 react?表單數據形式配置化設計_React
- 2022-08-10 詳細聊一聊algorithm中的排序算法_C 語言
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支