網站首頁 編程語言 正文
本節內容學習幫助大家梳理神經網絡訓練的架構。
一般我們訓練神經網絡有以下步驟:
- 導入庫
- 設置訓練參數的初始值
- 導入數據集并制作數據集
- 定義神經網絡架構
- 定義訓練流程
- 訓練模型
推薦文章:
python實現可視化大屏
分享4款 Python 自動數據分析神器
以下,我就將上述步驟使用代碼進行注釋講解:
1 導入庫
import torch from torch import nn from torch.nn import functional as F from torch import optim from torch.utils.data import DataLoader, DataLoader import torchvision import torchvision.transforms as transforms
2 設置初始值
# 學習率 lr = 0.15 # 優化算法參數 gamma = 0.8 # 每次小批次訓練個數 bs = 128 # 整體數據循環次數 epochs = 10
3 導入并制作數據集
本次我們使用FashionMNIST
圖像數據集,每個圖像是一個28*28的像素數組,共有10個衣物類別,比如連衣裙、運動鞋、包等。
注:初次運行下載需要等待較長時間。
# 導入數據集 mnist = torchvision.datasets.FashionMNIST( ? ? root = './Datastes' ? ? , train = True ? ? , download = True ? ? , transform = transforms.ToTensor()) ? ?? # 制作數據集 batchdata = DataLoader(mnist ? ? ? ? ? ? ? ? ? ? ? ?, batch_size = bs ? ? ? ? ? ? ? ? ? ? ? ?, shuffle = True ? ? ? ? ? ? ? ? ? ? ? ?, drop_last = False)
我們可以對數據進行檢查:
for x, y in batchdata: ? ? print(x.shape) ? ? print(y.shape) ? ? break # torch.Size([128, 1, 28, 28]) # torch.Size([128])
可以看到一個batch
中有128個樣本,每個樣本的維度是1*28*28。
之后我們確定模型的輸入維度與輸出維度:
# 輸入的維度 input_ = mnist.data[0].numel() # 784 # 輸出的維度 output_ = len(mnist.targets.unique()) # 10
4 定義神經網絡架構
先使用一個128個神經元的全連接層,然后用relu激活函數,再將其結果映射到標簽的維度,并使用softmax
進行激活。
# 定義神經網絡架構 class Model(nn.Module): ? ? def __init__(self, in_features, out_features): ? ? ? ? super().__init__() ? ? ? ? self.linear1 = nn.Linear(in_features, 128, bias = True) ? ? ? ? self.output = nn.Linear(128, out_features, bias = True) ? ?? ? ? def forward(self, x): ? ? ? ? x = x.view(-1, 28*28) ? ? ? ? sigma1 = torch.relu(self.linear1(x)) ? ? ? ? sigma2 = F.log_softmax(self.output(sigma1), dim = -1) ? ? ? ? return sigma2
5 定義訓練流程
在實際應用中,我們一般會將訓練模型部分封裝成一個函數,而這個函數可以繼續細分為以下幾步:
- 定義損失函數與優化器
- 完成向前傳播
- 計算損失
- 反向傳播
- 梯度更新
- 梯度清零
在此六步核心操作的基礎上,我們通常還需要對模型的訓練進度、損失值與準確度進行監視。
注釋代碼如下:
# 封裝訓練模型的函數 def fit(net, batchdata, lr, gamma, epochs): # 參數:模型架構、數據、學習率、優化算法參數、遍歷數據次數 ? ? # 5.1 定義損失函數 ? ? criterion = nn.NLLLoss() ? ? # 5.1 定義優化算法 ? ? opt = optim.SGD(net.parameters(), lr = lr, momentum = gamma) ? ?? ? ? # 監視進度:循環之前,一個樣本都沒有看過 ? ? samples = 0 ? ? # 監視準確度:循環之前,預測正確的個數為0 ? ? corrects = 0 ? ?? ? ? # 全數據訓練幾次 ? ? for epoch in range(epochs): ? ? ? ? # 對每個batch進行訓練 ? ? ? ? for batch_idx, (x, y) in enumerate(batchdata): ? ? ? ? ? ? # 保險起見,將標簽轉為1維,與樣本對齊 ? ? ? ? ? ? y = y.view(x.shape[0]) ? ? ? ? ? ?? ? ? ? ? ? ? # 5.2 正向傳播 ? ? ? ? ? ? sigma = net.forward(x) ? ? ? ? ? ? # 5.3 計算損失 ? ? ? ? ? ? loss = criterion(sigma, y) ? ? ? ? ? ? # 5.4 反向傳播 ? ? ? ? ? ? loss.backward() ? ? ? ? ? ? # 5.5 更新梯度 ? ? ? ? ? ? opt.step() ? ? ? ? ? ? # 5.6 梯度清零 ? ? ? ? ? ? opt.zero_grad() ? ? ? ? ? ?? ? ? ? ? ? ? # 監視進度:每訓練一個batch,模型見過的數據就會增加x.shape[0] ? ? ? ? ? ? samples += x.shape[0] ? ? ? ? ? ?? ? ? ? ? ? ? # 求解準確度:全部判斷正確的樣本量/已經看過的總樣本量 ? ? ? ? ? ? # 得到預測標簽 ? ? ? ? ? ? yhat = torch.max(sigma, -1)[1] ? ? ? ? ? ? # 將正確的加起來 ? ? ? ? ? ? corrects += torch.sum(yhat == y) ? ? ? ? ? ?? ? ? ? ? ? ? # 每200個batch和最后結束時,打印模型的進度 ? ? ? ? ? ? if (batch_idx + 1) % 200 == 0 or batch_idx == (len(batchdata) - 1): ? ? ? ? ? ? ? ? # 監督模型進度 ? ? ? ? ? ? ? ? print("Epoch{}:[{}/{} {: .0f}%], Loss:{:.6f}, Accuracy:{:.6f}".format( ? ? ? ? ? ? ? ? ? ? epoch + 1 ? ? ? ? ? ? ? ? ? ? , samples ? ? ? ? ? ? ? ? ? ? , epochs*len(batchdata.dataset) ? ? ? ? ? ? ? ? ? ? , 100*samples/(epochs*len(batchdata.dataset)) ? ? ? ? ? ? ? ? ? ? , loss.data.item() ? ? ? ? ? ? ? ? ? ? , float(100.0*corrects/samples)))
6 訓練模型
# 設置隨機種子 torch.manual_seed(51) # 實例化模型 net = Model(input_, output_) # 訓練模型 fit(net, batchdata, lr, gamma, epochs) # Epoch1:[25600/600000 ?4%], Loss:0.524430, Accuracy:69.570312 # Epoch1:[51200/600000 ?9%], Loss:0.363422, Accuracy:74.984375 # ...... # Epoch10:[600000/600000 ?100%], Loss:0.284664, Accuracy:85.771835
現在我們已經用Pytorch
訓練了最基礎的神經網絡,并且可以查看其訓練成果。大家可以將代碼復制進行運行!
雖然沒有用到復雜的模型,但是我們在每次建模時的基本思想都是一致的
原文鏈接:https://blog.csdn.net/weixin_38037405/article/details/123157702
- 上一篇:C#設計模式之工廠模式_C#教程
- 下一篇:C#設計模式之單例模式_C#教程
相關推薦
- 2022-10-11 tidb-系統內核調優及對比
- 2022-04-30 DataGridView實現點擊列頭升序和降序排序_C#教程
- 2024-01-15 jquery獲取dom元素身上的綁定事件
- 2022-06-02 Python配置文件yaml的用法詳解_python
- 2022-06-04 如何通過一篇文章了解Python中的生成器_python
- 2022-07-13 ELK 日志分析系統的部署
- 2023-01-21 C++實現逆波蘭表達式的例題詳解_C 語言
- 2022-09-01 Docker鏡像的遷移與備份及Dockerflie?使用方法詳解_docker
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支