網站首頁 編程語言 正文
本文實例為大家分享了Pytorch實現邏輯回歸分類的具體代碼,供大家參考,具體內容如下
1、代碼實現
步驟:
1.獲得數據
2.建立邏輯回歸模型
3.定義損失函數
4.計算損失函數
5.求解梯度
6.梯度更新
7.預測測試集
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
input_size = 784 ?# 輸入到邏輯回歸模型中的輸入大小
num_classes = 10 ?# 分類的類別個數
num_epochs = 10 ?# 迭代次數
batch_size = 50 ?# 批量訓練個數
learning_rate = 0.01 ?# 學習率
# 下載訓練數據和測試數據
train_dataset = dataset.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dataset.MNIST(root='./data',train=False, transform=transforms.ToTensor)
# 使用DataLoader形成批處理文件
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
# 創建邏輯回歸類模型 ?(sigmoid(wx+b))
class LogisticRegression(nn.Module):
? ? def __init__(self,input_size,num_classes):
? ? ? ? super(LogisticRegression,self).__init__()
? ? ? ? self.linear = nn.Linear(input_size,num_classes)
? ? ? ? self.sigmoid = nn.Sigmoid()
? ? def forward(self, x):
? ? ? ? out = self.linear(x)
? ? ? ? out = self.sigmoid(out)
? ? ? ? return out
# 設定模型參數
model = LogisticRegression(input_size, num_classes)
# 定義損失函數,分類任務,使用交叉熵
criterion = nn.CrossEntropyLoss()
# 優化算法,隨機梯度下降,lr為學習率,獲得模型需要更新的參數值
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# 使用訓練數據訓練模型
for epoch in range(num_epochs):
?? ?# 批量數據進行模型訓練
? ? for i, (images, labels) in enumerate(train_loader):
? ? ?? ?# 需要將數據轉換為張量Variable
? ? ? ? images = Variable(images.view(-1, 28*28))
? ? ? ? labels = Variable(labels)
?? ??? ?
?? ??? ?# 梯度更新前需要進行梯度清零
? ? ? ? optimizer.zero_grad()
?? ??? ?# 獲得模型的訓練數據結果
? ? ? ? outputs = model(images)
?? ??? ?
?? ??? ?# 計算損失函數用于計算梯度
? ? ? ? loss = criterion(outputs, labels)
?? ??? ?# 計算梯度
? ? ? ? loss.backward()
?? ?
?? ??? ?# 進行梯度更新
? ? ? ? optimizer.step()
?? ??? ?# 每隔一段時間輸出一個訓練結果
? ? ? ? if (i+1) % 100 == 0:
? ? ? ? ? ? print('Epoch:[%d %d], Step:[%d/%d], Loss: %.4f' % (epoch+1,num_epochs,i+1,len(train_dataset)//batch_size,loss.item()))
# 訓練好的模型預測測試數據集
correct = 0
total = 0
for images, labels in test_loader:
? ? images = Variable(images.view(-1, 28*28)) ?# 形式為(batch_size,28*28)
? ? outputs = model(images)
? ? _,predicts = torch.max(outputs.data,1) ?# _輸出的是最大概率的值,predicts輸出的是最大概率值所在位置,max()函數中的1表示維度,意思是計算某一行的最大值
? ? total += labels.size(0)
? ? correct += (predicts==labels).sum()
print('Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
2、踩過的坑
1.在代碼中下載訓練數據和測試數據的時候,兩段代碼是有區別的:
train_dataset = dataset.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dataset.MNIST(root='./data',train=False, transform=transforms.ToTensor)
第一段代碼中多了一個download=True,這個的作用是,如果為True,則從Internet下載數據集并將其存放在根目錄中。如果數據已經下載,則不會再次下載。
在第二段代碼中沒有加download=True,加了的話在使用測試數據進行預測的時候會報錯。
代碼中transform=transforms.ToTensor()的作用是將PIL圖像轉換為Tensor,同時已經進行歸一化處理。
2.代碼中設置損失函數:
criterion = nn.CrossEntropyLoss()
loss = criterion(outputs, labels)
一開始的時候直接使用:
loss = nn.CrossEntropyLoss()
loss = loss(outputs, labels)
這樣也會報錯,因此需要將loss改為criterion。
原文鏈接:https://blog.csdn.net/qq_37388085/article/details/105737190
相關推薦
- 2022-12-03 Nginx部署SSL證書的過程_nginx
- 2023-03-03 PostgreSQL實時查看數據庫實例正在執行的SQL語句實例詳解_PostgreSQL
- 2022-11-30 Android實現一鍵鎖屏功能_Android
- 2022-03-17 正確使用dotnet-*工具的方法_實用技巧
- 2022-09-10 Python自動打印被調用函數變量名及對應值?_python
- 2022-09-03 python四則運算表達式求值示例詳解_python
- 2022-04-06 詳解pandas中缺失數據處理的函數_python
- 2022-12-06 Python基礎面向對象之繼承與派生詳解_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支