網站首頁 編程語言 正文
1. 導入庫
機器學習的任務分為兩大類:分類和回歸
分類是對一堆目標進行識別歸類,例如貓狗分類、手寫數字分類等等
回歸是對某樣事物接下來行為的預測,例如預測天氣等等
這次我們要完成的任務是邏輯回歸,雖然名字叫做回歸,其實是個二元分類的任務
首先看看我們需要的庫文件
torch.nn 是專門為神經網絡設計的接口
matplotlib 用來繪制圖像,幫助可視化任務
torch 定義張量,數據的傳輸利用張量來實現
optim 優化器的包,例如SGD等
numpy 數據處理的包
2. 定義數據集
簡單說明一下任務,想在一個正方形的區域內生成若干點,然后手工設計label,最后通過神經網絡的訓練,畫出決策邊界
假設:正方形的邊長是2,左下角的坐標為(0,0),右上角的坐標為(2,2)
然后我們手工定義分界線 y = x ,在分界線的上方定義為藍色,下方定義為紅色
2.1 生成數據
首先生成數據的代碼為
首先通過rand(0-1的均勻分布)生成200個點,并將他們擴大2倍,x1代表橫坐標,x2代表縱坐標
然后定義一下分類,這里簡單介紹一下zip函數。
zip會將這里的a,b對應打包成一對,這樣i對應的就是(1,‘a’),i[0] 對應的就是1 2 3
再回到我們的代碼,因為我們要實現的是二元分類,所以我們定義兩個不同的類型,用pos,neg存起來。然后我們知道i[1] 代表的是 x2 ,i[0] 代表的是x1 , 所以 x2 - x1 <0 也就是也就是在直線y=x的下面為pos類型。否則,為neg類型
最后,我們需要將pos,neg類型的繪制出來。因為pos里面其實是類似于(1,1)這樣的坐標,因為pos.append(i) 里面的 i 其實是(x1,x2) 的坐標形式, 所以我們將pos 里面的第一個元素x1定義為賦值給橫坐標,第二個元素x2賦值給縱坐標
然后通過scatter 繪制離散的點就可以,將pos 繪制成 red 顏色,neg 繪制成 blue 顏色,如圖
2.2 設置label
我們進行的其實是有監督學習,所以需要label
這里需要注意的是,不同于回歸任務,x1不是輸入,x2也不是輸出。應該x1,x2都是輸入的元素,也就是特征feature。所以我們應該將紅色的點集設置一個標簽,例如 1 ,藍色的點集設置一個標簽,例如 0.
實現代碼如下
很容易理解,訓練集x_data 應該是所有樣本,也就是pos和neg的所以元素。而之前介紹了x1,x2都是輸入的特征,那么x_data的shape 應該是 [200,2] 的。而y_data 只有1(pos 紅色)類別,或者 0(neg 藍色)類型,所以y_data 的shape 應該是 [200,1] 的。y_data view的原因是變成矩陣的形式而不是向量的形式
這里的意思是,假如坐標是(1.5,0.5)那么應該落在紅色區域,那么這個點的標簽就是1
3. 搭建網絡+優化器
網絡的類型很簡單,不再贅述。至于為什么要繼承nn.Module或者super那步是干啥的不用管,基本上都是這樣寫的,記住就行。
需要注意的是我們輸入的特征是(n * 2) ,所以Linear 應該是(2,1)
二元分類最后的輸出一般選用sigmoid函數
這里的損失函數我們選擇BCE,二元交叉熵損失函數。
算法為隨機梯度下降
4. 訓練
訓練的過程也比較簡單,就是將模型的預測輸出值和真實的label作比較。然后將梯度歸零,在反向傳播并且更新梯度。
5. 繪制決策邊界
這里模型訓練完成后,將w0,w1 ,b取出來,然后繪制出直線
這里要繪制的是w0 * x1+ w1 * x2 + b = 0 ,因為最開始介紹了x1代表橫坐標x,x2代表縱坐標y。通過變形可知y = (-w0 * x1 - b ) / w1,結果如圖
程序輸出的損失為
最后,w0 = 4.1911 , w1 = -4.0290 ,b = 0.0209 ,近似等于y = x,和我們剛開始定義的分界線類似
6. 代碼
import torch.nn as nn
import matplotlib.pyplot as plt
import torch
from torch import optim
import numpy as np
torch.manual_seed(1) # 保證程序隨機生成數一樣
x1 = torch.rand(200) * 2
x2 = torch.rand(200) * 2
data = zip(x1,x2)
pos = [] # 定義類型 1
neg = [] # 定義類型 2
def classification(data):
for i in data:
if(i[1] - i[0] < 0):
pos.append(i)
else:
neg.append(i)
classification(data)
pos_x = [i[0] for i in pos]
pos_y = [i[1] for i in pos]
neg_x = [i[0] for i in neg]
neg_y = [i[1] for i in neg]
plt.scatter(pos_x,pos_y,c='r')
plt.scatter(neg_x,neg_y,c='b')
plt.show()
x_data = [[i[0],i[1]] for i in pos]
x_data.extend([[i[0],i[1]] for i in neg])
x_data = torch.Tensor(x_data) # 輸入數據 feature
y_data = [1 for i in range(len(pos))]
y_data.extend([0 for i in range(len(neg))])
y_data = torch.Tensor(y_data).view(-1,1) # 對應的標簽
class LogisticRegressionModel(nn.Module): # 定義網絡
def __init__(self):
super(LogisticRegressionModel,self).__init__()
self.linear = nn.Linear(2,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.linear(x)
x = self.sigmoid(x)
return x
model = LogisticRegressionModel()
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(),lr =0.01)
for epoch in range(10000):
y_pred = model(x_data)
loss = criterion(y_pred,y_data) # 計算損失值
if epoch % 1000 == 0:
print(epoch,loss.item()) # 打印損失值
optimizer.zero_grad() # 梯度清零
loss.backward() # 反向傳播
optimizer.step() # 梯度更新
w = model.linear.weight[0] # 取出訓練完成的結果
w0 = w[0]
w1 = w[1]
b = model.linear.bias.item()
with torch.no_grad(): # 繪制決策邊界,這里不需要計算梯度
x= torch.arange(0,3).view(-1,1)
y = (- w0 * x - b) / w1
plt.plot(x.numpy(),y.numpy())
plt.scatter(pos_x,pos_y,c='r')
plt.scatter(neg_x,neg_y,c='b')
plt.xlim(0,2)
plt.ylim(0,2)
plt.show()
程序結果
原文鏈接:https://blog.csdn.net/qq_44886601/article/details/127284028
相關推薦
- 2022-06-06 Element UI詳解el-scrollbar 滾動條組件 —— 監聽滾動條的滾動,跟隨頁面一起滾
- 2022-06-28 nginx使用內置模塊配置限速限流的方法實例_nginx
- 2022-10-03 Docker啟動失敗報錯Failed?to?start?Docker?Application?Con
- 2022-03-24 C語言指針的圖文詳解_C 語言
- 2022-07-07 關于C++智能指針shared_ptr和unique_ptr能否互轉問題_C 語言
- 2022-09-24 ASP.NET?MVC下拉框中顯示枚舉項_實用技巧
- 2022-12-27 Swift?Error重構優化詳解_Swift
- 2022-07-18 Linux 文件內容瀏覽;cut命令;uniq命令;sort命令;tr命令;
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支