網站首頁 編程語言 正文
本文實例為大家分享了pytorch使用nn.Moudle實現邏輯回歸的具體代碼,供大家參考,具體內容如下
內容
pytorch使用nn.Moudle實現邏輯回歸
問題
loss下降不明顯
解決方法
#源代碼 out的數據接收方式
? ? ?if torch.cuda.is_available():
? ? ? ? ?x_data=Variable(x).cuda()
? ? ? ? ?y_data=Variable(y).cuda()
? ? ?else:
? ? ? ? ?x_data=Variable(x)
? ? ? ? ?y_data=Variable(y)
? ??
? ? out=logistic_model(x_data) ?#根據邏輯回歸模型擬合出的y值
? ? loss=criterion(out.squeeze(),y_data) ?#計算損失函數
#源代碼 out的數據有拼裝數據直接輸入
# ? ? if torch.cuda.is_available():
# ? ? ? ? x_data=Variable(x).cuda()
# ? ? ? ? y_data=Variable(y).cuda()
# ? ? else:
# ? ? ? ? x_data=Variable(x)
# ? ? ? ? y_data=Variable(y)
? ??
? ? out=logistic_model(x_data) ?#根據邏輯回歸模型擬合出的y值
? ? loss=criterion(out.squeeze(),y_data) ?#計算損失函數
? ? print_loss=loss.data.item() ?#得出損失函數值
源代碼
import torch
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
#生成數據
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums, 2)
x0 = torch.normal(mean_value * n_data, 1) + bias ? ? ?# 類別0 數據 shape=(100, 2)
y0 = torch.zeros(sample_nums) ? ? ? ? ? ? ? ? ? ? ? ? # 類別0 標簽 shape=(100, 1)
x1 = torch.normal(-mean_value * n_data, 1) + bias ? ? # 類別1 數據 shape=(100, 2)
y1 = torch.ones(sample_nums) ? ? ? ? ? ? ? ? ? ? ? ? ?# 類別1 標簽 shape=(100, 1)
x_data = torch.cat((x0, x1), 0) ?#按維數0行拼接
y_data = torch.cat((y0, y1), 0)
#畫圖
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()
# 利用torch.nn實現邏輯回歸
class LogisticRegression(nn.Module):
? ? def __init__(self):
? ? ? ? super(LogisticRegression, self).__init__()
? ? ? ? self.lr = nn.Linear(2, 1)
? ? ? ? self.sm = nn.Sigmoid()
? ? def forward(self, x):
? ? ? ? x = self.lr(x)
? ? ? ? x = self.sm(x)
? ? ? ? return x
? ??
logistic_model = LogisticRegression()
# if torch.cuda.is_available():
# ? ? logistic_model.cuda()
#loss函數和優化
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(logistic_model.parameters(), lr=0.01, momentum=0.9)
#開始訓練
#訓練10000次
for epoch in range(10000):
# ? ? if torch.cuda.is_available():
# ? ? ? ? x_data=Variable(x).cuda()
# ? ? ? ? y_data=Variable(y).cuda()
# ? ? else:
# ? ? ? ? x_data=Variable(x)
# ? ? ? ? y_data=Variable(y)
? ??
? ? out=logistic_model(x_data) ?#根據邏輯回歸模型擬合出的y值
? ? loss=criterion(out.squeeze(),y_data) ?#計算損失函數
? ? print_loss=loss.data.item() ?#得出損失函數值
? ? #反向傳播
? ? loss.backward()
? ? optimizer.step()
? ? optimizer.zero_grad()
? ??
? ? mask=out.ge(0.5).float() ?#以0.5為閾值進行分類
? ? correct=(mask==y_data).sum().squeeze() ?#計算正確預測的樣本個數
? ? acc=correct.item()/x_data.size(0) ?#計算精度
? ? #每隔20輪打印一下當前的誤差和精度
? ? if (epoch+1)%100==0:
? ? ? ? print('*'*10)
? ? ? ? print('epoch {}'.format(epoch+1)) ?#誤差
? ? ? ? print('loss is {:.4f}'.format(print_loss))
? ? ? ? print('acc is {:.4f}'.format(acc)) ?#精度
? ? ? ??
? ? ? ??
w0, w1 = logistic_model.lr.weight[0]
w0 = float(w0.item())
w1 = float(w1.item())
b = float(logistic_model.lr.bias.item())
plot_x = np.arange(-7, 7, 0.1)
plot_y = (-w0 * plot_x - b) / w1
plt.xlim(-5, 7)
plt.ylim(-7, 7)
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=logistic_model(x_data)[:,0].cpu().data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.plot(plot_x, plot_y)
plt.show()
輸出結果
原文鏈接:https://blog.csdn.net/zeroooo000oo/article/details/107885489
相關推薦
- 2023-01-23 C#實現懸浮窗口的方法詳解_C#教程
- 2021-12-06 Gin框架之參數綁定的實現_Golang
- 2022-07-24 C語言超詳細i講解雙向鏈表_C 語言
- 2022-08-05 C語言示例講解for循環的用法_C 語言
- 2022-12-23 iOS之異常與信號使用場景分析_IOS
- 2022-04-12 安裝zsh&oh-my-zsh(沒有root權限)
- 2022-02-25 Oracle工具PL/SQL的基本語法_oracle
- 2023-12-12 TCP通信的實現-優化點對點聊天
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支