網站首頁 編程語言 正文
一、實現過程
1、準備數據
本文數據采取文獻[1]給出的數據集,該數據集前8列為特征,最后1列為標簽(0/1)。本模型使用pandas處理該數據集,需要注意的是,原始數據集沒有特征名稱,需要自己在第一行添加上去,否則,pandas會把第一行的數據當成特征名稱處理,從而影響最后的分類效果。
代碼如下:
# 1、準備數據 import torch import pandas as pd import numpy as np xy = pd.read_csv('G:/datasets/diabetes/diabetes.csv',dtype=np.float32)?? ?# 文件路徑 x_data = torch.from_numpy(xy.values[:,:-1]) y_data = torch.from_numpy(xy.values[:,[-1]])
2、設計模型
本文采取文獻[1]的思路,激活函數使用ReLU,最后一層使用Sigmoid
函數,
代碼如下:
class Model(torch.nn.Module): ? ? def __init__(self): ? ? ? ? super(Model,self).__init__() ? ? ? ? self.linear1 = torch.nn.Linear(8,6) ? ? ? ? self.linear2 = torch.nn.Linear(6,4) ? ? ? ? self.linear3 = torch.nn.Linear(4,1) ? ? ? ? self.activate = torch.nn.ReLU() ? ?? ? ? def forward(self, x): ? ? ? ? x = self.activate(self.linear1(x)) ? ? ? ? x = self.activate(self.linear2(x)) ? ? ? ? x = torch.sigmoid(self.linear3(x)) ? ? ? ? return x model = Model()
將模型和數據加載到GPU上,代碼如下:
### 將模型和訓練數據加載到GPU上 # 模型加載到GPU上 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model.to(device) # 數據加載到GPU上 x = x_data.to(device) y = y_data.to(device)
3、構造損失函數和優化器 criterion = torch.nn.BCELoss(reduction='mean') optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
4、訓練過程
epoch_list = [] loss_list = [] epochs = 10000 for epoch in range(epochs): ? ? # Forward ? ? y_pred = model(x) ? ? loss = criterion(y_pred, y) ? ? print(epoch, loss) ? ? epoch_list.append(epoch) ? ? loss_list.append(loss.data.item()) ? ? # Backward ? ? optimizer.zero_grad() ? ? loss.backward() ? ? # Update ? ? optimizer.step()
5、結果展示
查看各個層的權重和偏置:
model.linear1.weight,model.linear1.bias model.linear2.weight,model.linear2.bias model.linear3.weight,model.linear3.bias
損失值隨迭代次數的變化曲線:
# 繪圖展示 plt.plot(epoch_list,loss_list,'b') plt.xlabel('epoch') plt.ylabel('loss') plt.grid() plt.show()
最終的損失和準確率:
# 準確率 y_pred_label = torch.where(y_pred.data.cpu() >= 0.5,torch.tensor([1.0]),torch.tensor([0.0])) acc = torch.eq(y_pred_label, y_data).sum().item()/y_data.size(0) print("loss = ",loss.item(), "acc = ",acc) loss = ?0.4232381284236908 acc = ?0.7931488801054019
二、參考文獻
- [1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=7
- [2] https://blog.csdn.net/bit452/article/details/109682078
原文鏈接:https://blog.csdn.net/weixin_43821559/article/details/123314829
相關推薦
- 2021-12-16 .NET中的狀態機庫Stateless的操作流程_實用技巧
- 2022-05-28 Entity?Framework?Core表名映射_實用技巧
- 2023-02-17 react生命周期(類組件/函數組件)操作代碼_React
- 2022-06-25 iOS自定義滑桿效果_IOS
- 2022-08-07 C#中struct與class的區別詳解_C#教程
- 2022-06-16 利用Jetpack?Compose實現繪制五角星效果_Android
- 2022-05-05 基于Redis分布式BitMap的應用分析_Redis
- 2022-03-31 C語言類的雙向鏈表詳解_C 語言
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支