網(wǎng)站首頁 編程語言 正文
導(dǎo)讀
MNIST
手寫數(shù)字?jǐn)?shù)據(jù)集作為深度學(xué)習(xí)入門的數(shù)據(jù)集是我們經(jīng)常都會(huì)使用到的,包含了0~9共10個(gè)數(shù)字類別的圖片,每張圖片的大小為28X28,一共包含了60000張
訓(xùn)練集圖片和10000張
測試集圖片。
使用PaddlePadlle進(jìn)行手寫數(shù)字識(shí)別
- 導(dǎo)包
import paddle
from paddle.vision.transforms import Normalize
- 加載MNIST數(shù)據(jù)集
#數(shù)據(jù)的歸一化處理
transform = Normalize(mean=[127.5],std=[127.5],data_format="CHW")
#加載MNIST的訓(xùn)練數(shù)據(jù)集
train_mnist_dataset = paddle.vision.datasets.MNIST(mode="train",transform=transform)
#加載MNIST的測試數(shù)據(jù)集
test_mnist_dataset = paddle.vision.datasets.MNIST(mode="test",transform=transform)
- 展示手寫數(shù)字的圖片
import numpy as np
import matplotlib.pyplot as plt
#獲取MNIST數(shù)據(jù)的圖片和對應(yīng)的標(biāo)簽
mnist_image,mnist_label = train_mnist_dataset[0][0],train_mnist_dataset[0][1]
#調(diào)整MNIST圖片的大小
mnist_image = mnist_image.reshape((28,28))
plt.figure(figsize=(2,2))
plt.imshow(mnist_image,cmap=plt.cm.binary)
- 使用paddlepaddle定義神經(jīng)網(wǎng)絡(luò)模型
這里我們先使用一個(gè)比較簡單的3層感知機(jī)來構(gòu)建一個(gè)模型,第一層全連接層的輸出是256,第二層全連接層的輸出是128,第三層全連接層的輸出是10,正好對應(yīng)10個(gè)不同的數(shù)字類別
class PerceptronMNIST(paddle.nn.Layer):
def __init__(self,in_features,classes_num):
super(PerceptronMNIST,self).__init__()
#將輸出數(shù)據(jù)的shape由(-1,1,28,28)變?yōu)?-1,784)
self.flatten = paddle.nn.Flatten()
#感知機(jī)的第一層全連接層
self.fc1 = paddle.nn.Linear(in_features=in_features,out_features=256)
#激活函數(shù)
self.relu1 = paddle.nn.ReLU()
#感知機(jī)的第二層全連接層
self.fc2 = paddle.nn.Linear(in_features=256,out_features=128)
#激活函數(shù)
self.relu2 = paddle.nn.ReLU()
#感知機(jī)的第三層全連接層
self.fc3 = paddle.nn.Linear(in_features=128,out_features=classes_num)
def forward(self,x):
x = self.flatten(x)
x = self.fc1(x)
x = self.relu1(x)
x = self.fc2(x)
x = self.relu2(x)
x = self.fc3(x)
return x
- 打印網(wǎng)絡(luò)模型的結(jié)構(gòu)
#使用PaddlePaddle封裝模型
model = paddle.Model(PerceptronMNIST(in_features=28*28,classes_num=10))
#輸出網(wǎng)絡(luò)的結(jié)構(gòu)
model.summary((-1,1,28,28))
- 定義優(yōu)化算法和Loss
#配置模型
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),#使用Adam優(yōu)化算法
paddle.nn.CrossEntropyLoss(),#使用CrossEntropyLoss作為損失函數(shù)
paddle.metric.Accuracy())#使用Accuracy計(jì)算精度
- 訓(xùn)練模型
model.fit(train_mnist_dataset,#設(shè)置訓(xùn)練數(shù)據(jù)
epochs=10, #定義訓(xùn)練的epochs
batch_size=128, #設(shè)置batch_size
verbose=1) #設(shè)置日志的輸出格式
- 評(píng)估模型
model.evaluate(test_mnist_dataset,verbose=1)
- 模型預(yù)測
results = model.predict(test_mnist_dataset)
#獲取概率最大的label
pred_label = np.argsort(results)
print("圖片的預(yù)測標(biāo)簽為:{}".format(pred_label[0][0][-1][0]))
原文鏈接:https://blog.csdn.net/sinat_29957455/article/details/122849841
相關(guān)推薦
- 2022-05-22 C#中深拷貝和淺拷貝的介紹與用法_C#教程
- 2022-04-01 Python實(shí)現(xiàn)隨機(jī)生成圖片驗(yàn)證碼詳解_python
- 2023-01-19 Oracle查詢表空間大小及每個(gè)表所占空間的大小語句示例_oracle
- 2022-11-14 解讀python正則表達(dá)式括號(hào)問題_python
- 2022-09-16 go?goroutine實(shí)現(xiàn)素?cái)?shù)統(tǒng)計(jì)的示例_Golang
- 2023-10-15 AddressSanitizer 查找內(nèi)存問題
- 2022-12-23 C++?Boost?Conversion超詳細(xì)講解_C 語言
- 2022-04-23 Android自定義View實(shí)現(xiàn)數(shù)字雨效果的全過程_Android
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支