網站首頁 編程語言 正文
導讀
MNIST
手寫數字數據集作為深度學習入門的數據集是我們經常都會使用到的,包含了0~9共10個數字類別的圖片,每張圖片的大小為28X28,一共包含了60000張
訓練集圖片和10000張
測試集圖片。
使用PaddlePadlle進行手寫數字識別
- 導包
import paddle
from paddle.vision.transforms import Normalize
- 加載MNIST數據集
#數據的歸一化處理
transform = Normalize(mean=[127.5],std=[127.5],data_format="CHW")
#加載MNIST的訓練數據集
train_mnist_dataset = paddle.vision.datasets.MNIST(mode="train",transform=transform)
#加載MNIST的測試數據集
test_mnist_dataset = paddle.vision.datasets.MNIST(mode="test",transform=transform)
- 展示手寫數字的圖片
import numpy as np
import matplotlib.pyplot as plt
#獲取MNIST數據的圖片和對應的標簽
mnist_image,mnist_label = train_mnist_dataset[0][0],train_mnist_dataset[0][1]
#調整MNIST圖片的大小
mnist_image = mnist_image.reshape((28,28))
plt.figure(figsize=(2,2))
plt.imshow(mnist_image,cmap=plt.cm.binary)
- 使用paddlepaddle定義神經網絡模型
這里我們先使用一個比較簡單的3層感知機來構建一個模型,第一層全連接層的輸出是256,第二層全連接層的輸出是128,第三層全連接層的輸出是10,正好對應10個不同的數字類別
class PerceptronMNIST(paddle.nn.Layer):
def __init__(self,in_features,classes_num):
super(PerceptronMNIST,self).__init__()
#將輸出數據的shape由(-1,1,28,28)變為(-1,784)
self.flatten = paddle.nn.Flatten()
#感知機的第一層全連接層
self.fc1 = paddle.nn.Linear(in_features=in_features,out_features=256)
#激活函數
self.relu1 = paddle.nn.ReLU()
#感知機的第二層全連接層
self.fc2 = paddle.nn.Linear(in_features=256,out_features=128)
#激活函數
self.relu2 = paddle.nn.ReLU()
#感知機的第三層全連接層
self.fc3 = paddle.nn.Linear(in_features=128,out_features=classes_num)
def forward(self,x):
x = self.flatten(x)
x = self.fc1(x)
x = self.relu1(x)
x = self.fc2(x)
x = self.relu2(x)
x = self.fc3(x)
return x
- 打印網絡模型的結構
#使用PaddlePaddle封裝模型
model = paddle.Model(PerceptronMNIST(in_features=28*28,classes_num=10))
#輸出網絡的結構
model.summary((-1,1,28,28))
- 定義優化算法和Loss
#配置模型
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),#使用Adam優化算法
paddle.nn.CrossEntropyLoss(),#使用CrossEntropyLoss作為損失函數
paddle.metric.Accuracy())#使用Accuracy計算精度
- 訓練模型
model.fit(train_mnist_dataset,#設置訓練數據
epochs=10, #定義訓練的epochs
batch_size=128, #設置batch_size
verbose=1) #設置日志的輸出格式
- 評估模型
model.evaluate(test_mnist_dataset,verbose=1)
- 模型預測
results = model.predict(test_mnist_dataset)
#獲取概率最大的label
pred_label = np.argsort(results)
print("圖片的預測標簽為:{}".format(pred_label[0][0][-1][0]))
原文鏈接:https://blog.csdn.net/sinat_29957455/article/details/122849841
相關推薦
- 2023-03-01 shell?sleep睡眠命令的具體使用_linux shell
- 2022-06-25 pytorch中permute()函數用法實例詳解_python
- 2023-02-25 go-micro微服務domain層開發示例詳解_Golang
- 2022-06-06 基于VSTS的Xamarin.Android持續集成步驟詳解_Android
- 2022-08-16 python+pytest接口自動化參數關聯_python
- 2022-04-14 Go語言context?test源碼分析詳情_Golang
- 2022-06-06 uniApp、API ‘offCompassChange‘ is not yet implement
- 2022-03-16 部署.NET6項目到IIS_實用技巧
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支