網(wǎng)站首頁 編程語言 正文
1. 加載數(shù)據(jù)集
這次我們搭建一個小小的多層線性網(wǎng)絡(luò)對糖尿病的病例進行分類
首先先導入需要的庫文件
先來看看我們的數(shù)據(jù)集
觀察可以發(fā)現(xiàn),前八列是我們的feature ,根據(jù)這八個特征可以判斷出病人是否得了糖尿病。所以最后一列是1,0 的一個二分類問題
我們使用numpy 去導入數(shù)據(jù)集,delimiter 是定義分隔符,這里我們用逗號(,)分割
將前八列的特征放到我們的x_data里面,作為特征輸入,最后一列放到y(tǒng)_data作為label
Tip :這里y_data 里面的 [-1] 中括號不可以省略,否則y_data會變成向量的形式
如果不習慣這種寫法,可以用view改變一下形狀就行
y_data = torch.from_numpy(xy[:,-1]).view(-1,1) #將y_data 的代碼改成這樣就可以了
下面是xy , x_data , y_data 打印出前兩行的結(jié)果
2. 搭建網(wǎng)絡(luò)+優(yōu)化器
搭建網(wǎng)絡(luò)的時候,要保證兩層網(wǎng)絡(luò)之間的維數(shù)能對應(yīng)上
首先第一層的時候,因為前八列作為我們的x_data ,也就是說我們輸入的特征是 8 維度的,那么由于 y = x * wT + b ,因為輸入數(shù)據(jù)的x是(n * 8) 的,而我們定義的y維度是(n * 6) ,所以wT的維度應(yīng)該是(8,6)
這里不需要知道啥時候轉(zhuǎn)置,啥時候不轉(zhuǎn)置之類的,只要滿足線性的方程y = w*x+b,并且維度一致就行了。因為不管是轉(zhuǎn)置,或者w和x誰在前,只是為了保證滿足矩陣相乘而已
一個小的技巧就是:只需要看輸入特征是多少,然后保證第一層第一個參數(shù)對應(yīng)就行了,然后第一層第二個參數(shù)是想輸出的維度。其次是第二層的第一個參數(shù)對應(yīng)第一層第二個參數(shù),以此類推....
我們采用的激活函數(shù)是ReLU , 由于是二元分類,最后一個網(wǎng)絡(luò)的輸出我們采用sigmoid輸出
接下來,搭建實例化我們的網(wǎng)絡(luò),然后建立優(yōu)化器
這里我們選擇SGD隨機梯度下降算法,學習率設(shè)置為0.01
3. 訓練網(wǎng)絡(luò)
訓練網(wǎng)絡(luò)的過程較為簡單,大概的過程為
1. 計算預(yù)測值
2. 計算損失函數(shù)
3. 反向傳播,之前要進行梯度清零
4. 梯度更新
5. 重復(fù)這個過程,epoch 為所有樣本計算一次的周期,這次讓epoch 迭代1000次
4. 代碼
import torch.nn as nn # 神經(jīng)網(wǎng)絡(luò)庫
import matplotlib.pyplot as plt # 繪圖
import torch # 張量
from torch import optim # 優(yōu)化器庫
import numpy as np # 數(shù)據(jù)處理
xy = np.loadtxt('./diabetes.csv.gz',delimiter=',',dtype=np.float32) # 加載數(shù)據(jù)集
x_data = torch.from_numpy(xy[:,:-1]) # 所有行,除了最后一列的元素
y_data = torch.from_numpy(xy[:,-1]).view(-1,1) # -1也能拿出來是向量,但是[-1]會保證拿出來的是個矩陣
epoch_list =[]
loss_list = []
class Model(nn.Module):
def __init__(self):
super(Model,self).__init__()
self.linear1 = nn.Linear(8,6)
self.linear2 = nn.Linear(6,3)
self.linear3 = nn.Linear(3,1)
self.sigmoid = nn.Sigmoid()
self.relu = nn.ReLU()
def forward(self,x):
x = self.relu(self.linear1(x))
x = self.relu(self.linear2(x))
x = self.sigmoid(self.linear3(x))
return x
model = Model()
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(),lr =0.01)
for epoch in range(1000):
y_pred = model(x_data)
loss = criterion(y_pred,y_data) # 計算損失
if epoch % 100 ==0: # 每隔100次打印一下
print(epoch,loss.item())
#back propagation
optimizer.zero_grad() # 梯度清零
loss.backward() # 反向傳播
optimizer.step() # 梯度更新
epoch_list.append(epoch)
loss_list.append(loss.item())
plt.plot(epoch_list,loss_list)
plt.show()
輸出結(jié)果為:
原文鏈接:https://blog.csdn.net/qq_44886601/article/details/127347389
相關(guān)推薦
- 2022-11-25 Python利用memory_profiler實現(xiàn)內(nèi)存分析_python
- 2023-05-29 scipy稀疏數(shù)組coo_array的實現(xiàn)_python
- 2022-07-04 C#使用System.Buffer以字節(jié)數(shù)組Byte[]操作基元類型數(shù)據(jù)_C#教程
- 2023-02-17 python連接讀寫操作redis的完整代碼實例_python
- 2022-08-22 C#使用MSTest進行單元測試_C#教程
- 2022-07-25 C++文件的操作及小實驗示例代碼詳解_C 語言
- 2022-11-22 python枚舉類型定義與使用講解_python
- 2022-09-03 Docker?Buildx構(gòu)建多平臺鏡像的實現(xiàn)_docker
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細win安裝深度學習環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支